Split up compiler.cpp (#15355)
authorZachary DeVito <zdevito@fb.com>
Wed, 19 Dec 2018 03:41:00 +0000 (19:41 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 03:43:35 +0000 (19:43 -0800)
Summary:
This separates the different parts of compiler.cpp to make their relationship more clear. In particular it adds:

* sugared_value.{h,cpp} - all the public SugaredValues that the compiler defines and a few that were inside compiler.cpp
* type_parser.{h, cpp} - Turns TreeRef's defining types into TypePtr
* schema_matching.{h, cpp} - infrastructure for matching arguments against overloaded schema and emitting builtin operators with a particular schema.
Retains:
* compiler.{h, cpp} - now responsible simply for the `defineMethodsInModule` infra structure.

Some utility functions like inlineCallTo have moved to ir.h.

Only thing that is not a move is some changes in module.h/cpp that remove multiple returns from `Method::emit_call_to`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15355

Reviewed By: suo, wanchaol

Differential Revision: D13507524

Pulled By: zdevito

fbshipit-source-id: 69ec936a9ff1a383c12a883616346b219c72e393

18 files changed:
tools/build_pytorch_libs.sh
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/schema_matching.cpp [new file with mode: 0644]
torch/csrc/jit/script/schema_matching.h [new file with mode: 0644]
torch/csrc/jit/script/sugared_value.cpp [new file with mode: 0644]
torch/csrc/jit/script/sugared_value.h [new file with mode: 0644]
torch/csrc/jit/script/type_parser.cpp [new file with mode: 0644]
torch/csrc/jit/script/type_parser.h [new file with mode: 0644]

index 4b69ac6..5b593f9 100755 (executable)
@@ -8,8 +8,9 @@
 #
 # TODO: Replace this with the root-level CMakeLists.txt
 
+set -e
 if [[ $VERBOSE_SCRIPT == '1' ]]; then
-  set -ex
+  set -x
   report() {
     echo "$@"
   }
index f9b71fe..3324221 100644 (file)
@@ -85,6 +85,9 @@ torch_sources_no_python_default = [
     "torch/csrc/jit/register_special_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
+    "torch/csrc/jit/script/type_parser.cpp",
+    "torch/csrc/jit/script/sugared_value.cpp",
+    "torch/csrc/jit/script/schema_matching.cpp",
     "torch/csrc/jit/script/parser.cpp",
     "torch/csrc/jit/import_method.cpp",
     "torch/csrc/jit/hooks_for_testing.cpp",
index 5b3e832..8d44c20 100644 (file)
@@ -189,6 +189,9 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
index 9d40546..ab0ab40 100644 (file)
@@ -530,7 +530,7 @@ private:
     auto local_graph = this->graph->copy();
     setInputTypes(*local_graph, spec);
     PropagateInputShapes(local_graph);
-    auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values);
+    auto output_values = inlineCallTo(*state->graph, *local_graph, input_values);
 
     auto outputs = last(stack, num_outputs);
     for (size_t i = 0; i < outputs.size(); ++i) {
index 32bd8b3..650d5ef 100644 (file)
@@ -5,7 +5,7 @@
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
@@ -1475,6 +1475,37 @@ void Graph::freeBlock(Block * b) {
   all_blocks.erase(it);
 }
 
+at::ArrayRef<Value*> createTupleUnpack(Value* v) {
+  // small peephole optimization to ensure IntList attributes can still turn
+  // into constants e.g. in x.expand([3, 4])
+  if(v->node()->kind() == prim::TupleConstruct)
+    return v->node()->inputs();
+  auto & g = *v->owningGraph();
+  return g.insertNode(g.createTupleUnpack(v))->outputs();
+}
+
+std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+  std::unordered_map<Value*, Value*> value_map;
+  auto value_map_func = [&](Value* v) { return value_map.at(v); };
+  JIT_ASSERT(callee.inputs().size() == inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    value_map[callee.inputs()[i]] = inputs[i];
+  }
+  for (auto* node : callee.nodes()) {
+    auto* new_node =
+        g.insertNode(g.createClone(node, value_map_func));
+    for (size_t i = 0; i < node->outputs().size(); ++i) {
+      value_map[node->outputs()[i]] = new_node->outputs()[i];
+    }
+  }
+
+  std::vector<Value*> outputs;
+  for (auto* output : callee.outputs()) {
+    outputs.push_back(value_map_func(output));
+  }
+  return outputs;
+}
+
 PythonOp* defaultAllocPythonOp(Graph*g) {
   throw std::runtime_error("Trying to allocate a Python object without python bindings loaded");
 }
@@ -1488,4 +1519,5 @@ void setAllocPythonOp(PythonOp* (*v)(Graph* g)) {
   alloc_python_op.store(v);
 }
 
+
 }} // namespace torch::jit
index 4823067..2a6c9cf 100644 (file)
@@ -1103,4 +1103,8 @@ inline Node* Graph::createPythonOp(
 
 TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
 
+TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
+TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+
+
 }} // namespace torch::jit
index 37992b0..0411569 100644 (file)
@@ -21,9 +21,9 @@ std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_
 }
 
 std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-  auto outputs = script::inlineCallTo(g, callee, inputs);
+  auto outputs = inlineCallTo(g, callee, inputs);
   if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-    auto tc = script::createTupleUnpack(outputs.at(0));
+    auto tc = createTupleUnpack(outputs.at(0));
     outputs = std::vector<Value*>(tc.begin(), tc.end());
   }
   return outputs;
@@ -527,7 +527,7 @@ std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
   // lower the tuple before the pass
   if (graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
     graph = graph->copy();
-    auto outs = script::createTupleUnpack(graph->outputs().at(0));
+    auto outs = createTupleUnpack(graph->outputs().at(0));
     graph->eraseOutput(0);
     for(auto o : outs)
       graph->registerOutput(o);
index ad6c186..58817ba 100644 (file)
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/operator.h>
@@ -8,7 +9,6 @@
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/script/builtin_functions.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 
 #include <torch/csrc/jit/constants.h>
@@ -28,69 +28,6 @@ using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
 using AttributeMap = std::unordered_map<std::string, Const>;
 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
 
-struct NoneValue : SugaredValue {
-  NoneValue() = default;
-  std::string kind() const override {
-    return "None";
-  }
-};
-
-struct PrintValue : public SugaredValue {
-  std::string kind() const override {
-    return "print";
-  }
-  std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    at::ArrayRef<NamedValue> inputs,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) override {
-      auto& g = *m.graph();
-      if (!attributes.empty())
-        throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
-
-      //temporary hack to allow print statements to work in python 2, where
-      //print(a, b) is treated as a (a, b) tuple input.
-
-      std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
-      if(lowered_inputs.size() == 1 && lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
-        auto input = lowered_inputs[0];
-        for(size_t j = 0; j < input->node()->inputs().size(); ++j) {
-          lowered_inputs.insert(lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
-        }
-        lowered_inputs.erase(lowered_inputs.begin());
-      }
-      g.insertNode(g.create(prim::Print, lowered_inputs, 0)
-                       ->setSourceLocation(std::make_shared<SourceRange>(loc)));
-      return std::make_shared<NoneValue>();
-  }
-};
-
-// expressions like int(x)
-// these are the same as call prim::Int or equivalent except it
-// is a noop when the input is a subtype of 'type'
-struct CastValue : public BuiltinFunction {
-  CastValue(TypePtr type, c10::Symbol method)
-  : BuiltinFunction(method, c10::nullopt)
-  , type_(std::move(type)) {}
-  std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    at::ArrayRef<NamedValue> inputs,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) override {
-      if(inputs.size() == 1 && attributes.size() == 0) {
-        auto v = inputs[0].value(*m.graph());
-        if (v->type()->isSubtypeOf(type_)) {
-          return std::make_shared<SimpleValue>(v);
-        }
-      }
-      return BuiltinFunction::call(loc, m , inputs, attributes, n_binders);
-  }
-private:
-  TypePtr type_;
-};
-
 static Value* asSimple(const SugaredValuePtr& value) {
   if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
     return sv->getValue();
@@ -386,194 +323,6 @@ private:
   ValueTable value_table;
 };
 
-Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
-  if(values.size() == 1) {
-    return values[0];
-  }
-  return g.insertNode(g.createTuple(values))->output();
-}
-
-at::ArrayRef<Value*> createTupleUnpack(Value* v) {
-  // small peephole optimization to ensure IntList attributes can still turn
-  // into constants e.g. in x.expand([3, 4])
-  if(v->node()->kind() == prim::TupleConstruct)
-    return v->node()->inputs();
-  auto & g = *v->owningGraph();
-  return g.insertNode(g.createTupleUnpack(v))->outputs();
-}
-
-inline TypePtr unwrapOptional(TypePtr opt_type) {
-  if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
-    return unwrap_list_type->getElementType();
-  }
-  return opt_type;
-}
-
-static inline bool isIntOrFloatUsedAsList(
-    const Value* value,
-    const Argument& arg) {
-  // Look for int[N] or float[N]
-  const auto& v_type = value->type();
-  if (v_type != FloatType::get() && v_type != IntType::get())
-    return false;
-  auto arg_type = unwrapOptional(arg.type());
-  auto list_type = arg_type->cast<ListType>();
-  return list_type && list_type->getElementType() == v_type && arg.N();
-}
-
-inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
-  auto list_type = list_type_->cast<ListType>();
-  if(!list_type) {
-    return false;
-  }
-  if(type->isSubtypeOf(list_type_)) {
-    return true;
-  }
-  if(auto tuple = type->cast<TupleType>()) {
-    return std::all_of(
-        tuple->elements().begin(),
-        tuple->elements().end(),
-        [&](const TypePtr& t) {
-          return t->isSubtypeOf(list_type->getElementType());
-        });
-  }
-  return false;
-}
-
-// applies implict conversion from value trying to turn it into type concrete_type
-// it succeeds if the return_value->isSubclassOf(concrete_type)
-Value* tryConvertToType(
-    const SourceRange& loc,
-    Graph& graph,
-    const TypePtr& concrete_type,
-    Value* value,
-    bool allow_conversions) {
-
-  if (auto value_tuple = value->type()->cast<TupleType>()) {
-    // Allow homogeneous tuples to be casted implicitly to lists of appropriate
-    // types
-    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
-      auto unpacked = createTupleUnpack(value);
-      auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
-      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
-    }
-    // inductively apply implicit conversions to tuples
-    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
-      if (!value_tuple->isSubtypeOf(concrete_tuple) &&
-          concrete_tuple->elements().size() == value_tuple->elements().size()) {
-        auto unpacked = createTupleUnpack(value);
-        std::vector<Value*> converted;
-        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
-          converted.emplace_back(tryConvertToType(
-              loc,
-              graph,
-              concrete_tuple->elements().at(i),
-              unpacked.at(i),
-              allow_conversions));
-        }
-        value = graph.insertNode(graph.createTuple(converted))->output();
-      }
-    }
-  }
-
-  if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
-    if (concrete_type->isSubtypeOf(GeneratorType::get())) {
-      value = graph.insertNode(graph.createNoneGenerator())->output();
-    } else if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
-      // create undefined tensor when None pass to a optional[tensor] formal arg
-      value = graph.insertNode(graph.createUndefined())->output();
-    } else if (auto optional_type = concrete_type->cast<OptionalType>()) {
-      value = graph.insertNode(graph.createNone(optional_type->getElementType()))->output();
-    }
-  }
-
-  //implicit conversions
-  if(allow_conversions) {
-     if(concrete_type->isSubtypeOf(NumberType::get())
-      && value->type()->isSubtypeOf(DynamicType::get())) {
-      auto n = graph.createImplicitTensorToNum(concrete_type, value);
-      value = graph.insertNode(n)
-        ->setSourceLocation(std::make_shared<SourceRange>(loc))
-        ->output();
-    }
-    if (value->type()->isSubtypeOf(StringType::get()) &&
-        DeviceObjType::get()->isSubtypeOf(concrete_type))  {
-      return graph.insert(aten::device, { value }, {}, loc);
-    }
-  }
-
-  return value;
-}
-
-Value* tryMatchArgument(
-    const Argument& arg,
-    Graph& graph,
-    const SourceRange& loc,
-    const NamedValue& named_value,
-    const std::function<std::ostream&()>& err,
-    bool allow_conversions,
-    TypeEnv & type_env) {
-  Value* value = named_value.value(graph);
-
-  // some functions that take lists of integers or floats for fixed size arrays
-  // also allow single ints/floats to be passed in their place.
-  // the single int/float is then repeated to the length of the list
-  if (isIntOrFloatUsedAsList(value, arg)) {
-    std::vector<Value*> repeated(*arg.N(), value);
-    value = graph.insertNode(graph.createList(value->type(), repeated))->output();
-  }
-
-  const MatchTypeReturn matched_type =
-      matchTypeVariables(arg.type(), value->type(), type_env);
-  if (!matched_type.type) {
-    err() << "could not match type " << value->type()->str() << " to "
-          << arg.type()->str() << " in argument '" << arg.name()
-          << "': " << matched_type.errMsg << "\n"
-          << named_value.locOr(loc);
-    return nullptr;
-  }
-  const auto concrete_type = *matched_type.type;
-
-  value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
-
-  if(!value->type()->isSubtypeOf(concrete_type)) {
-    err() << "expected a value of type " << concrete_type->str() << " for argument '" << arg.name() << "' but found "
-          << value->type()->str() << "\n"
-          << named_value.locOr(loc);
-    return nullptr;
-  }
-  return value;
-}
-
-c10::optional<size_t> findInputWithName(
-    const std::string& name,
-    at::ArrayRef<NamedValue> kwargs) {
-  for(size_t i = 0; i < kwargs.size(); ++i) {
-    if(kwargs[i].name() == name)
-      return i;
-  }
-  return c10::nullopt;
-}
-
-Value* tryCreateList(
-    const TypePtr& elem_type,
-    Graph& graph,
-    const SourceRange& loc,
-    at::ArrayRef<NamedValue> varargs,
-    const std::function<std::ostream&()>& err,
-    bool convert_tensor_to_num,
-    TypeEnv & type_env) {
-  Argument elem_arg("<varargs>", elem_type);
-  std::vector<Value*> list_ctor;
-  for(const auto& a : varargs) {
-    Value* av = tryMatchArgument(elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
-    if(!av)
-      return nullptr;
-    list_ctor.push_back(av);
-  }
-  return graph.insertNode(graph.createList(elem_type, list_ctor))->output();
-}
-
 template<class T>
 static Value* materializeConstant(T val, Graph& graph,
     const SourceRange& r, std::unordered_map<T, Value*>& map) {
@@ -589,222 +338,6 @@ static Value* materializeConstant(T val, Graph& graph,
   return new_constant;
 }
 
-c10::optional<MatchedSchema> tryMatchSchema(
-    const FunctionSchema& schema,
-    const SourceRange& loc,
-    Graph& graph,
-    c10::optional<NamedValue> self,
-    at::ArrayRef<NamedValue> args,
-    at::ArrayRef<NamedValue> kwargs,
-    std::ostream& failure_messages,
-    bool allow_conversions) {
-  auto err = [&]() -> std::ostream& {
-    failure_messages << "\nfor operator " << schema << ":\n";
-    return failure_messages;
-  };
-
-  TypeEnv type_env;
-  std::vector<Value*> positional_inputs;
-  std::vector<bool> used_kwarg(kwargs.size(), false);
-
-  // if we finish the loop will we have consumed all arguments?
-  size_t used_args = 0;
-  for (size_t schema_i = 0; schema_i < schema.arguments().size(); ++schema_i) {
-    const auto& arg = schema.arguments()[schema_i];
-    c10::optional<NamedValue> v;
-    if (arg.name() == "self" && self) {
-      v = self;
-      self = c10::nullopt;
-    } else if (!arg.kwarg_only() && used_args < args.size()) {
-      // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
-      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
-          !arg.N() && // it must not be a broadcasting list like int[3], otherwise
-                    // a single int is a valid input
-          (schema_i + 1 == schema.arguments().size() ||
-           schema.arguments()[schema_i + 1]
-               .kwarg_only())) { // must be the last position argument
-        auto actual_type = args[used_args].value(graph)->type();
-        if (actual_type->kind() != TypeKind::ListType &&
-            !convertibleToList(
-                actual_type,
-                unwrapOptional(arg.type()))) { // and the actual should not be a list already
-          auto elem_type = unwrapOptional(arg.type())->expect<ListType>()->getElementType();
-          Value* list = tryCreateList(
-              elem_type,
-              graph,
-              loc,
-              at::ArrayRef<NamedValue>(args).slice(used_args),
-              err,
-              allow_conversions,
-              type_env);
-          if (!list)
-            return c10::nullopt;
-          used_args = args.size();
-          positional_inputs.push_back(list);
-          continue;
-        }
-      }
-
-      v = args[used_args];
-      used_args++;
-    } else if (auto idx = findInputWithName(arg.name(), kwargs)) {
-      const NamedValue& nv = kwargs[*idx];
-      if (used_kwarg[*idx]) {
-        err() << "argument " << nv.name()
-              << " specified twice in schema, submit a bug report!\n"
-              << nv.locOr(loc);
-        return c10::nullopt;
-      }
-      used_kwarg[*idx] = true;
-      v = nv;
-    } else if (arg.default_value()) {
-      v = NamedValue(*arg.default_value());
-    } else {
-      err() << "argument " << schema.arguments()[schema_i].name()
-            << " not provided.\n"
-            << loc;
-      return c10::nullopt;
-    }
-    Value* positional = tryMatchArgument(
-        arg, graph, loc, *v, err, allow_conversions, type_env);
-    if (!positional)
-      return c10::nullopt;
-    positional_inputs.push_back(positional);
-  }
-  // check for unused self argument
-  if(self != c10::nullopt) {
-    err() << "provided self argument not used in schema\n";
-  }
-
-  if (schema.is_vararg()) {
-    for(;used_args < args.size(); ++used_args) {
-      positional_inputs.push_back(args[used_args].value(graph));
-    }
-  }
-
-  // check for unused positional arguments
-  if (used_args < args.size()) {
-    err() << "expected at most " << used_args << " arguments "
-          << "but found " << args.size() << " positional arguments.\n"
-          << loc << "\n";
-    return c10::nullopt;
-  }
-  // check for unused kwargs
-  for (size_t i = 0; i < kwargs.size(); ++i) {
-    const auto& nv = kwargs[i];
-    if (!used_kwarg[i]) {
-      if (!schema.argumentIndexWithName(nv.name())) {
-        err() << "keyword argument " << nv.name() << " unknown\n";
-      } else {
-        err() << "keyword argument " << nv.name() << " specified twice\n";
-      }
-      return c10::nullopt;
-    }
-  }
-  auto return_types = fmap(schema.returns(), [&](const Argument& r) {
-    return evalTypeVariables(r.type(), type_env);
-  });
-  return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
-}
-
-static std::string prefixLine(const std::string& str, const std::string& prefix) {
-  std::stringstream ss;
-  bool was_newline = true;
-  for(auto c : str) {
-    if(was_newline)
-      ss << prefix;
-    ss.put(c);
-    was_newline = c == '\n';
-  }
-  return ss.str();
-}
-
-// Given a successful match between operator schema and symbol, emit a node
-// with the appropriate inputs and outputs.
-static Value* emitBuiltinNode(
-    const MatchedSchema& matched_schema,
-    const SourceRange& loc,
-    Graph& graph,
-    Symbol name) {
-  auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
-                ->setSourceLocation(std::make_shared<SourceRange>(loc));
-
-  for(auto & ret : matched_schema.return_types) {
-    n->addOutput()->setType(ret);
-  }
-
-  // assert that we did indeed create an op that has implementation
-  // otherwise schema and dispatch are not in sync
-  getOperation(n);
-
-  return packOutputs(graph, n->outputs());
-}
-
-// Search for operators matching the provided symbol name and input types.
-// If one is found, emit a node to the graph for that operator.
-Value* emitBuiltinCall(
-  const SourceRange& loc,
-  Graph& graph,
-  Symbol name,
-  const c10::optional<NamedValue>& self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
-  // otherwise it will return nullptr if the builtin is not found.
-  bool required) {
-
-
-  const auto& variants = getAllOperatorsFor(name);
-  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
-
-  std::stringstream failure_messages;
-  //first we try to match the schema without any conversion
-  //if no schema matches then insert ImplicitTensorToNum
-  for (bool allow_conversions : {false, true}) {
-    // clear previous error messages
-    failure_messages.str("");
-    for (const std::shared_ptr<Operator>& op : variants) {
-      const auto matched_schema = tryMatchSchema(
-          op->schema(),
-          loc,
-          graph,
-          self,
-          inputs,
-          attributes,
-          failure_messages,
-          allow_conversions);
-      if (matched_schema) {
-        return emitBuiltinNode(*matched_schema, loc, graph, name);
-      }
-    }
-    for (Method* method : builtin_functions) {
-      if (auto result = try_emit_call_to(
-              graph,
-              loc,
-              *method,
-              self,
-              inputs,
-              attributes,
-              failure_messages,
-              nullptr,
-              allow_conversions)) {
-        return packOutputs(graph, *result);
-      }
-    }
-  }
-
-  // none of the options worked
-  if (!required) {
-    return nullptr;
-  }
-  if(variants.size() == 0) {
-    throw ErrorReport(loc) << "unknown builtin op";
-  }
-  throw ErrorReport(loc) << "arguments for call are not valid:\n"
-                         << prefixLine(failure_messages.str(), "  ")
-                         << "for call at";
-}
-
 static Value* ensureInt(const SourceRange& range, Value* v) {
   if(!v->type()->isSubtypeOf(IntType::get())) {
     throw ErrorReport(range) << "expected a int but found a "
@@ -2097,8 +1630,8 @@ private:
     }
   }
 
-  Value* emitExpr(const Expr& tree, TypePtr type_hint = nullptr) {
-    return emitSugaredExpr(tree, 1, std::move(type_hint))->asValue(tree.range(), method);
+  Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
+    return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
   }
 
   NodeKind reverseComparision(NodeKind kind) {
@@ -2121,7 +1654,7 @@ private:
   // or a = torch.jit.annotate(List[int], [])
   // the caller is responsible for checking that the result matches type_hint
   // emitSugaredExpr is free to ignore it.
-  std::shared_ptr<SugaredValue> emitSugaredExpr(const Expr& tree, size_t n_binders, TypePtr type_hint=nullptr) {
+  std::shared_ptr<SugaredValue> emitSugaredExpr(const Expr& tree, size_t n_binders, const TypePtr& type_hint=nullptr) {
     switch(tree.kind()) {
       case TK_VAR:
         return environment_stack->getSugaredVar(Var(tree).name());
@@ -2135,7 +1668,7 @@ private:
         return emitApplyExpr(apply, n_binders);
       } break;
       default:
-        return std::make_shared<SimpleValue>(emitSimpleExpr(tree, std::move(type_hint)));
+        return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
     }
   }
 
@@ -2587,72 +2120,6 @@ private:
   }
 };
 
-static const std::unordered_map<std::string, std::string> &builtin_cast_methods() {
-  static std::unordered_map<std::string, std::string> builtin_cast_methods = {
-    {"byte", "_cast_Byte"},
-    {"char", "_cast_Char"},
-    {"double", "_cast_Double"},
-    {"float", "_cast_Float"},
-    {"int", "_cast_Int"},
-    {"long", "_cast_Long"},
-    {"short", "_cast_Short"},
-    {"half", "_cast_Half"}
-  };
-  return builtin_cast_methods;
-}
-
-// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
-// callable value that will resolve to foo(x, y, z) when called.
-std::shared_ptr<SugaredValue> SimpleValue::attr(const SourceRange& loc, Method & m, const std::string& field) {
-  // Allow method-style casts on Tensor types. e.g. x.int()
-  if (value->type()->isSubtypeOf(DynamicType::get())) {
-    if (builtin_cast_methods().count(field)) {
-      return std::make_shared<BuiltinFunction>(
-          Symbol::aten(builtin_cast_methods().at(field)),
-          NamedValue(loc, "self", value));
-    }
-    // functions that are just direct property lookups on tensor
-    // must be registered as prim::<name>(Tensor t) -> <return_type>
-    static const std::unordered_set<std::string> fields = {
-      "dtype",
-      "device",
-      "shape",
-      "is_cuda",
-      "requires_grad",
-    };
-    if (fields.count(field)) {
-      auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value});
-      return std::make_shared<SimpleValue>(r);
-    }
-  }
-  if (getValue()->type()->isSubtypeOf(NumberType::get())) {
-    throw ErrorReport(loc) << "Cannot call methods on numbers";
-  }
-  return std::make_shared<BuiltinFunction>(
-      Symbol::aten(field), NamedValue(loc, "self", value));
-}
-
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-  std::unordered_map<Value*, Value*> value_map;
-  auto value_map_func = [&](Value* v) { return value_map.at(v); };
-  JIT_ASSERT(callee.inputs().size() == inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    value_map[callee.inputs()[i]] = inputs[i];
-  }
-  for (auto* node : callee.nodes()) {
-    auto* new_node =
-        g.insertNode(g.createClone(node, value_map_func));
-    for (size_t i = 0; i < node->outputs().size(); ++i) {
-      value_map[node->outputs()[i]] = new_node->outputs()[i];
-    }
-  }
-
-  std::vector<Value*> outputs;
-  for (auto* output : callee.outputs()) {
-    outputs.push_back(value_map_func(output));
-  }
-  return outputs;
-}
 
 void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, const SugaredValuePtr& self) {
   JIT_ASSERT(definitions.size() == resolvers.size());
@@ -2692,153 +2159,7 @@ void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<D
   didFinishEmitModule(m);
 }
 
-const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
-  static std::unordered_map<std::string, TypePtr> map = {
-    {"Tensor", DynamicType::get()},
-    {"int", IntType::get()},
-    {"float", FloatType::get()},
-    {"bool", BoolType::get()},
-    {"str", StringType::get()},
-    {"Device", DeviceObjType::get()},
-    // technically this is not a python type but we need it when
-    // parsing serialized methods that use implicit converions to Scalar
-    {"number", NumberType::get()},
-    {"None", NoneType::get()},
-  };
-  return map;
-}
-
-const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscript_to_type_fns() {
-  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>> map = {
-    {"Tuple", [](Subscript subscript) -> TypePtr {
-      std::vector<TypePtr> subscript_expr_types;
-      for (auto expr : subscript.subscript_exprs()) {
-        subscript_expr_types.push_back(parseTypeFromExpr(expr));
-      }
-      return TupleType::create(subscript_expr_types);
-    }},
-    {"List", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return ListType::create(elem_type);
-    }},
-    {"Optional", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return OptionalType::create(elem_type);
-    }},
-    {"Future", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return FutureType::create(elem_type);
-    }},
-  };
-  return map;
-}
-
-bool isTorch(const Expr& expr) {
-  return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
-}
-
-// gets the base type name given namespaces where the types live
-// turns torch.Tensor -> Tensor, X -> X
-c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
-  switch (expr.kind()) {
-    case TK_VAR: {
-      return Var(expr).name().name();
-    }
-    case TK_NONE: {
-      return "None";
-    }
-    case '.': {
-      auto select = Select(expr);
-      const std::string& name = select.selector().name();
-      if (isTorch(select.value()) && name == "Tensor")
-        return "Tensor";
-    } break;
-  }
-  return at::nullopt;
-}
-
-TypePtr parseTypeFromExpr(const Expr& expr) {
-  if (expr.kind() == TK_SUBSCRIPT) {
-    auto subscript = Subscript(expr);
-    auto value_name = parseBaseTypeName(subscript.value());
-    if (!value_name) {
-      throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
-    }
-    if (!subscript_to_type_fns().count(*value_name)) {
-      throw ErrorReport(subscript.range()) << "Unknown type constructor " << *value_name;
-    }
-    return subscript_to_type_fns().at(*value_name)(subscript);
-  } else if (auto name = parseBaseTypeName(expr)) {
-    auto itr = ident_to_type_lut().find(*name);
-    if (itr != ident_to_type_lut().end()) {
-      return itr->second;
-    }
-    throw ErrorReport(expr) << "Unknown type name " << *name;
-  }
-  throw ErrorReport(expr.range()) << "Expression of type " << kindToString(expr.kind())
-                                  << " cannot be used in a type expression";
-}
-
-c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr) {
-  if (expr.kind() != TK_SUBSCRIPT)
-    return c10::nullopt;
-  auto subscript = Subscript(expr);
-  if (subscript.value().kind() != TK_VAR)
-    return c10::nullopt;
-  auto var = Var(subscript.value());
-  auto subscript_exprs = subscript.subscript_exprs();
-
-  // handle the case where the BroadcastingList is wrapped in a Optional type
-  if(var.name().name() == "Optional") {
-    auto broadcast_list = handleBroadcastList(subscript_exprs[0]);
-    if (broadcast_list) {
-      TypePtr opt_type = OptionalType::create(broadcast_list->first);
-      return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
-    } else {
-      return c10::nullopt;
-    }
-  } else if (var.name().name().find("BroadcastingList") != 0) {
-    return c10::nullopt;
-  }
-
-  if (subscript_exprs.size() != 1)
-    throw ErrorReport(subscript.subscript_exprs().range())
-      << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
-
-  auto typ = subscript_exprs[0];
-  auto len = var.name().name().substr(strlen("BroadcastingList"));
-
-  if (typ.kind() != TK_VAR)
-    throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
-
-  auto value_name = Var(typ).name().name();
-  if (value_name != "float" && value_name != "int")
-    throw ErrorReport(subscript.value().range()) << "Broadcastable lists only supported for int or float";
-
-  auto elem_ptr = ident_to_type_lut().find(value_name);
-  JIT_ASSERT(elem_ptr != ident_to_type_lut().end());
-  TypePtr list_ptr = ListType::create(elem_ptr->second);
-
-  const char* len_c = len.c_str();
-  char* end;
-  size_t len_v = strtoull(len_c, &end, 10);
-  if (end != len_c + len.size()) {
-    throw ErrorReport(subscript.subscript_exprs().range())
-        << "subscript of Broadcastable list must be a positive integer";
-  }
-  return std::pair<TypePtr, int32_t>(list_ptr, len_v);
-}
-
-void defineMethodsInModule(std::shared_ptr<Module> m, const std::string& source, const Resolver& resolver, const SugaredValuePtr& self) {
+void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const SugaredValuePtr& self) {
   Parser p(source);
   std::vector<Def> definitions;
   std::vector<Resolver> resolvers;
@@ -2847,29 +2168,9 @@ void defineMethodsInModule(std::shared_ptr<Module> m, const std::string& source,
     definitions.push_back(def);
     resolvers.push_back(resolver);
   }
-  defineMethodsInModule(std::move(m), definitions, resolvers, self);
+  defineMethodsInModule(m, definitions, resolvers, self);
 }
 
-std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
-    const SourceRange& loc,
-    Method& m,
-    const c10::optional<size_t>& size_hint) {
-  static const auto make_simple_value = [](Value* v) -> std::shared_ptr<SugaredValue> {
-    return std::make_shared<SimpleValue>(v);
-  };
-  if(value->type()->kind() == TypeKind::TupleType) {
-    auto outputs = createTupleUnpack(value);
-    return fmap(outputs, make_simple_value);
-  } else if (value->type()->kind() == TypeKind::ListType) {
-    if (!size_hint) {
-      throw ErrorReport(loc) << "cannot statically infer the expected size of a list in this context";
-    }
-    auto graph = value->owningGraph();
-    Node *unpack = graph->insertNode(graph->createListUnpack(value, *size_hint));
-    return fmap(unpack->outputs(), make_simple_value);
-  }
-  throw ErrorReport(loc) << value->type()->str() << " cannot be used as a tuple";
-}
 
 } // namespace script
 } // namespace jit
index 1730a88..4f4c42f 100644 (file)
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/tree_views.h>
 #include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/script/sugared_value.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
-static inline std::vector<Value*> toValues(Graph& g, at::ArrayRef<NamedValue> nvs) {
-  return fmap(nvs, [&](const NamedValue& v) {
-    return v.value(g);
-  });
-}
-
-// The AST can contain nodes like `self`, `self.b` or `python_fn` that
-// are not first-class values in the graph representation, but instead
-// will be desugared based on how they are used in the AST.
-
-// SugaredValue is used to temporarily represent these values in a way
-// that separates their behavior from the AST -> IR converter itself.
-// This allows us to keep dependencies on python minimal.
-
-enum NoneStatus {
- ALWAYS,
- MAYBE,
- NEVER
-};
-
-struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
-  // what is this node? for error reporting (e.g. Module, python function)
-  virtual std::string kind() const = 0;
-
-  // what can we do with this thing?
-  // use it as a value e.g.  `this + 4`
-  virtual Value * asValue(const SourceRange& loc, Method & m) {
-    throw ErrorReport(loc) << kind() << " cannot be used as a value";
-  }
-
-  // select an attribute on it, e.g. `this.field`
-  virtual std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) {
-    throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
-  }
-  virtual NoneStatus isNone() {
-    return NEVER;
-  }
-
-  // use it as a vector of values, e.g. a tuple of values as return value from
-  // a method invocation
-  virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
-      const SourceRange& loc,
-      Method& m,
-      const c10::optional<size_t>& size_hint = {}) {
-    throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
-  }
-
-  // call it like a function, e.g. `outputs = this(inputs)`
-  virtual std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    // note: names for args will be 'argument 0', 'argument 1', etc..
-    at::ArrayRef<NamedValue> inputs_,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) {
-// n_binders is always set to the number of variables an expression is
-// syntactically bound to:
-//     a = foo() # 1 binder (note in this case the single binder might be a tuple)
-//     a, * b = foo() # 1 binder
-//     a, b = foo() # 2 binders
-//     foo() # 0 binders
-//
-// In subexpressions, like bar() in foo(bar()), n_binders is always set to
-// 1. n_binders is used as a hint to subexpressions to determine how many
-// values they should return when that number is ambiguous statically. In
-// particular it is currently used to decide how many tensors a call to a
-// python function will return. It is only a hint, functions do not have to
-// check that n_binders match the number of things they are returning, the
-// assignment logic will do that anyway.
-
-    throw ErrorReport(loc) << "cannot call a " << kind();
-  }
-
-  virtual ~SugaredValue() = default;
-};
-
-// most things in the environment are just simple value types
-// and not special python syntax sugar types
-struct TORCH_API SimpleValue : public SugaredValue {
-  SimpleValue(Value * value)
-  : value(value) {}
-  std::string kind() const override {
-    return "value";
-  }
-  Value * asValue(const SourceRange& range, Method & m) override {
-    return value;
-  }
-  NoneStatus isNone() override {
-    if (value->mustBeNone())
-      return ALWAYS;
-    else if (value->type()->cast<OptionalType>())
-      return MAYBE;
-    else
-      return NEVER;
-  }
-  std::vector<std::shared_ptr<SugaredValue>> asTuple(
-      const SourceRange& loc,
-      Method& m,
-      const c10::optional<size_t>& size_hint = {}) override;
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override;
-  Value* getValue() const {
-    return value;
-  }
-private:
-  Value* value;
-};
-
-struct TORCH_API BuiltinFunction : public SugaredValue {
-  BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self)
-      : symbol(symbol), self(std::move(self)) {}
-
-  // The symbol of the function (e.g. `aten::relu`).
-  Symbol symbol;
-
-  // if this is method, then this is the self argument.
-  c10::optional<NamedValue> self;
-
-  std::string kind() const override {
-    return "builtin";
-  }
-  std::shared_ptr<SugaredValue> call(
-      const SourceRange& loc,
-      Method& m,
-      at::ArrayRef<NamedValue> attributes,
-      at::ArrayRef<NamedValue> inputs,
-      size_t n_binders) override;
-};
-
-struct TORCH_API BuiltinModule : public SugaredValue {
-  BuiltinModule(std::string name,
-                c10::optional<int64_t> version = at::nullopt)
-    : name(std::move(name))
-    , version(std::move(version)) {}
-
-  std::string kind() const override {
-    return "builtin module";
-  }
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
-    return std::make_shared<BuiltinFunction>(Symbol::fromQualString(name+"::"+field), c10::nullopt);
-  }
-
-private:
-  std::string name;
-  // when we add operator versioning, emit this op as it exising at 'version'
-  // if not set, use the latest version
-  c10::optional<int64_t> version;
-};
-
-// These SugaredValues have special handling in the compiler because they
-// change the normal evalution order of the expression they participate in.
-// They are exposed here so that the python frontend can inject them
-// when it sees the equivalent thing in python
-struct TORCH_API ForkValue : public SugaredValue {
-  ForkValue() = default;
-  std::string kind() const override {
-    return "fork";
-  }
-};
-struct TORCH_API AnnotateValue : public SugaredValue {
-  AnnotateValue() = default;
-  std::string kind() const override {
-    return "annotate";
-  }
-};
-
-// matched against for special handling of getattr expressions
-struct TORCH_API GetAttrValue : SugaredValue {
-  GetAttrValue() = default;
-  std::string kind() const override {
-    return "getattr";
-  }
-};
-
-// matched against for special handling of isinstance expressions
-struct TORCH_API IsInstanceValue : SugaredValue {
-  IsInstanceValue() = default;
-  std::string kind() const override {
-    return "isinstance";
-  }
-};
-
 using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
 
 inline std::shared_ptr<SugaredValue> nativeResolver(const std::string& name, Method& m, const SourceRange& loc){
@@ -210,74 +30,7 @@ TORCH_API void defineMethodsInModule(
 );
 
 // same as above but parse the definitions from source
-TORCH_API void defineMethodsInModule(std::shared_ptr<Module> m, const std::string& source, const Resolver& resolver, const std::shared_ptr<SugaredValue>& self);
-
-// pack outputs of a function following python rules. If there is a single value return
-// a SimpleValue, otherwise pack all the values into a Tuple.
-TORCH_API Value* packOutputs(Graph& g, at::ArrayRef<Value*> values);
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
-
-// defines how a method obtained from a module behaves in script
-struct MethodValue : public SugaredValue {
-  MethodValue(std::shared_ptr<Module> module, Method& method)
-  : module(std::move(module)) //insurance that method stays alive
-  , method(method) {}
-  std::string kind() const override {
-    return "method";
-  }
-  std::shared_ptr<SugaredValue> call(
-      const SourceRange& loc,
-      Method& caller,
-      at::ArrayRef<NamedValue> inputs,
-      at::ArrayRef<NamedValue> attributes,
-      size_t n_binders) override {
-    return std::make_shared<SimpleValue>(packOutputs(
-        *caller.graph(), caller.emit_call_to(loc, method, inputs, attributes)));
-  }
-
- private:
-  std::shared_ptr<Module> module;
-  Method& method;
-
-};
-
-// try to match a list if inputs and keyword 'attributes' to this schema,
-// if it works return the flat list of positional inputs to the call
-// if it returns nullopt, then failure_messages contains a good error report
-// set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted to
-// match the schema
-
-struct MatchedSchema {
-  std::vector<Value*> inputs;
-  std::vector<TypePtr> return_types;
-};
-
-TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
-  const FunctionSchema& schema,
-  const SourceRange& loc,
-  Graph& graph,
-  c10::optional<NamedValue> self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  std::ostream& failure_messages,
-  bool allow_conversions);
-
-TORCH_API Value* emitBuiltinCall(
-  const SourceRange& loc,
-  Graph& graph,
-  Symbol name,
-  const c10::optional<NamedValue>& self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
-  // otherwise it will return nullptr if the builtin is not found.
-  bool required);
-
-TORCH_API c10::optional<size_t> findInputWithName(
-  const std::string& name,
-  at::ArrayRef<NamedValue> kwargs);
-
-TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
+TORCH_API void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const std::shared_ptr<SugaredValue>& self);
 
 } // namespace script
 } // namespace jit
index e0d51bf..6ccd194 100644 (file)
@@ -5,6 +5,7 @@
 #include <torch/csrc/Layout.h>
 #include <torch/csrc/jit/import.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 
 #include <torch/csrc/jit/python_tracer.h>
 #include <torch/csrc/jit/pybind_utils.h>
@@ -124,11 +125,9 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     for(auto &i : matched_schema->inputs)
       new_node->addInput(i);
 
-    std::vector<Value*> outputs;
-    for(auto & ret_arg : matched_schema->return_types) {
-      outputs.push_back(new_node->addOutput()->setType(ret_arg));
-    }
-    return std::make_shared<SimpleValue>(packOutputs(*m.graph(), outputs));
+    JIT_ASSERT(matched_schema->return_types.size() == 1);
+    Value* output = new_node->addOutput()->setType(matched_schema->return_types.at(0));
+    return std::make_shared<SimpleValue>(output);
   }
 
   std::string kind() const override {
index 5adf873..a4275f5 100644 (file)
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/script/module.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/operator.h>
@@ -13,7 +14,7 @@ void placeholderCreator(Method&) {
   throw RecursiveMethodCallError();
 }
 
-c10::optional<std::vector<Value*>> try_emit_call_to(
+Value* try_emit_call_to(
     Graph& graph,
     const SourceRange& loc,
     Method& callee,
@@ -35,7 +36,7 @@ c10::optional<std::vector<Value*>> try_emit_call_to(
     callee.getSchema(),
     loc, graph, std::move(self), args, kwargs, failure_messages, conv_tensors_to_nums);
   if(!matched_schema)
-    return c10::nullopt;
+    return nullptr;
 
   // parameters to callee method (which become parameters to _this_ method
   // if they were not already)
@@ -45,10 +46,11 @@ c10::optional<std::vector<Value*>> try_emit_call_to(
     }
     matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
   }
-  return inlineCallTo(graph, *callee.graph(), matched_schema->inputs);
+  callee.check_single_output();
+  return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
 }
 
-std::vector<Value*> Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs) {
+Value* Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs) {
   JIT_ASSERT(!executor);
   std::stringstream failure_messages;
   if (auto result = try_emit_call_to(
@@ -61,7 +63,7 @@ std::vector<Value*> Method::emit_call_to(const SourceRange& loc, Method & callee
           failure_messages,
           this,
           /*conv_tensors_to_nums=*/true)) {
-    return *result;
+    return result;
   }
   throw ErrorReport(loc) << failure_messages.str();
 }
index aeed2d8..48684b6 100644 (file)
@@ -89,7 +89,7 @@ struct Method {
   // adding any extra parameters necessary to do this call
 
   // defined here to keep details of member_input handling confined to this class
-  std::vector<Value*> emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs);
+  Value* emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs);
 
   // if this isn't yet defined, run its method_creator function
   TORCH_API void ensure_defined();
@@ -197,6 +197,11 @@ struct Method {
     return *owner_;
   }
 
+  void check_single_output() {
+    AT_CHECK(
+        graph()->outputs().size() == 1,
+        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
+  }
 private:
 
   static FunctionSchema defaultSchemaFor(const Method& method) {
@@ -217,9 +222,7 @@ private:
 
   GraphExecutor& get_executor() {
     std::call_once(executor_init, [&] {
-      AT_CHECK(
-          graph()->outputs().size() == 1,
-          "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
+      check_single_output();
       executor = GraphExecutor(graph(), optimize);
     });
     return executor;
@@ -506,9 +509,9 @@ struct Module {
   bool optimize;
 };
 
-// returns c10::nullopt and fills in failure_messages if the callee does not
+// returns nullptr and fills in failure_messages if the callee does not
 // match the functions schema
-c10::optional<std::vector<Value*>> try_emit_call_to(
+Value* try_emit_call_to(
     Graph& graph,
     const SourceRange& loc,
     Method& callee,
diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp
new file mode 100644 (file)
index 0000000..e27f5ea
--- /dev/null
@@ -0,0 +1,412 @@
+#include <torch/csrc/jit/script/schema_matching.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/script/builtin_functions.h>
+#include <torch/csrc/jit/script/error_report.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+inline TypePtr unwrapOptional(TypePtr opt_type) {
+  if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
+    return unwrap_list_type->getElementType();
+  }
+  return opt_type;
+}
+
+static inline bool isIntOrFloatUsedAsList(
+    const Value* value,
+    const Argument& arg) {
+  // Look for int[N] or float[N]
+  const auto& v_type = value->type();
+  if (v_type != FloatType::get() && v_type != IntType::get())
+    return false;
+  auto arg_type = unwrapOptional(arg.type());
+  auto list_type = arg_type->cast<ListType>();
+  return list_type && list_type->getElementType() == v_type && arg.N();
+}
+
+inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
+  auto list_type = list_type_->cast<ListType>();
+  if(!list_type) {
+    return false;
+  }
+  if(type->isSubtypeOf(list_type_)) {
+    return true;
+  }
+  if(auto tuple = type->cast<TupleType>()) {
+    return std::all_of(
+        tuple->elements().begin(),
+        tuple->elements().end(),
+        [&](const TypePtr& t) {
+          return t->isSubtypeOf(list_type->getElementType());
+        });
+  }
+  return false;
+}
+
+// applies implict conversion from value trying to turn it into type concrete_type
+// it succeeds if the return_value->isSubclassOf(concrete_type)
+Value* tryConvertToType(
+    const SourceRange& loc,
+    Graph& graph,
+    const TypePtr& concrete_type,
+    Value* value,
+    bool allow_conversions) {
+
+  if (auto value_tuple = value->type()->cast<TupleType>()) {
+    // Allow homogeneous tuples to be casted implicitly to lists of appropriate
+    // types
+    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
+      auto unpacked = createTupleUnpack(value);
+      auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
+    }
+    // inductively apply implicit conversions to tuples
+    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
+      if (!value_tuple->isSubtypeOf(concrete_tuple) &&
+          concrete_tuple->elements().size() == value_tuple->elements().size()) {
+        auto unpacked = createTupleUnpack(value);
+        std::vector<Value*> converted;
+        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
+          converted.emplace_back(tryConvertToType(
+              loc,
+              graph,
+              concrete_tuple->elements().at(i),
+              unpacked.at(i),
+              allow_conversions));
+        }
+        value = graph.insertNode(graph.createTuple(converted))->output();
+      }
+    }
+  }
+
+  if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
+    if (concrete_type->isSubtypeOf(GeneratorType::get())) {
+      value = graph.insertNode(graph.createNoneGenerator())->output();
+    } else if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
+      // create undefined tensor when None pass to a optional[tensor] formal arg
+      value = graph.insertNode(graph.createUndefined())->output();
+    } else if (auto optional_type = concrete_type->cast<OptionalType>()) {
+      value = graph.insertNode(graph.createNone(optional_type->getElementType()))->output();
+    }
+  }
+
+  //implicit conversions
+  if(allow_conversions) {
+     if(concrete_type->isSubtypeOf(NumberType::get())
+      && value->type()->isSubtypeOf(DynamicType::get())) {
+      auto n = graph.createImplicitTensorToNum(concrete_type, value);
+      value = graph.insertNode(n)
+        ->setSourceLocation(std::make_shared<SourceRange>(loc))
+        ->output();
+    }
+    if (value->type()->isSubtypeOf(StringType::get()) &&
+        DeviceObjType::get()->isSubtypeOf(concrete_type))  {
+      return graph.insert(aten::device, { value }, {}, loc);
+    }
+  }
+
+  return value;
+}
+
+Value* tryMatchArgument(
+    const Argument& arg,
+    Graph& graph,
+    const SourceRange& loc,
+    const NamedValue& named_value,
+    const std::function<std::ostream&()>& err,
+    bool allow_conversions,
+    TypeEnv & type_env) {
+  Value* value = named_value.value(graph);
+
+  // some functions that take lists of integers or floats for fixed size arrays
+  // also allow single ints/floats to be passed in their place.
+  // the single int/float is then repeated to the length of the list
+  if (isIntOrFloatUsedAsList(value, arg)) {
+    std::vector<Value*> repeated(*arg.N(), value);
+    value = graph.insertNode(graph.createList(value->type(), repeated))->output();
+  }
+
+  const MatchTypeReturn matched_type =
+      matchTypeVariables(arg.type(), value->type(), type_env);
+  if (!matched_type.type) {
+    err() << "could not match type " << value->type()->str() << " to "
+          << arg.type()->str() << " in argument '" << arg.name()
+          << "': " << matched_type.errMsg << "\n"
+          << named_value.locOr(loc);
+    return nullptr;
+  }
+  const auto concrete_type = *matched_type.type;
+
+  value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
+
+  if(!value->type()->isSubtypeOf(concrete_type)) {
+    err() << "expected a value of type " << concrete_type->str() << " for argument '" << arg.name() << "' but found "
+          << value->type()->str() << "\n"
+          << named_value.locOr(loc);
+    return nullptr;
+  }
+  return value;
+}
+
+c10::optional<size_t> findInputWithName(
+    const std::string& name,
+    at::ArrayRef<NamedValue> kwargs) {
+  for(size_t i = 0; i < kwargs.size(); ++i) {
+    if(kwargs[i].name() == name)
+      return i;
+  }
+  return c10::nullopt;
+}
+
+Value* tryCreateList(
+    const TypePtr& elem_type,
+    Graph& graph,
+    const SourceRange& loc,
+    at::ArrayRef<NamedValue> varargs,
+    const std::function<std::ostream&()>& err,
+    bool convert_tensor_to_num,
+    TypeEnv & type_env) {
+  Argument elem_arg("<varargs>", elem_type);
+  std::vector<Value*> list_ctor;
+  for(const auto& a : varargs) {
+    Value* av = tryMatchArgument(elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
+    if(!av)
+      return nullptr;
+    list_ctor.push_back(av);
+  }
+  return graph.insertNode(graph.createList(elem_type, list_ctor))->output();
+}
+
+c10::optional<MatchedSchema> tryMatchSchema(
+    const FunctionSchema& schema,
+    const SourceRange& loc,
+    Graph& graph,
+    c10::optional<NamedValue> self,
+    at::ArrayRef<NamedValue> args,
+    at::ArrayRef<NamedValue> kwargs,
+    std::ostream& failure_messages,
+    bool allow_conversions) {
+  auto err = [&]() -> std::ostream& {
+    failure_messages << "\nfor operator " << schema << ":\n";
+    return failure_messages;
+  };
+
+  TypeEnv type_env;
+  std::vector<Value*> positional_inputs;
+  std::vector<bool> used_kwarg(kwargs.size(), false);
+
+  // if we finish the loop will we have consumed all arguments?
+  size_t used_args = 0;
+  for (size_t schema_i = 0; schema_i < schema.arguments().size(); ++schema_i) {
+    const auto& arg = schema.arguments()[schema_i];
+    c10::optional<NamedValue> v;
+    if (arg.name() == "self" && self) {
+      v = self;
+      self = c10::nullopt;
+    } else if (!arg.kwarg_only() && used_args < args.size()) {
+      // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
+      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
+          !arg.N() && // it must not be a broadcasting list like int[3], otherwise
+                    // a single int is a valid input
+          (schema_i + 1 == schema.arguments().size() ||
+           schema.arguments()[schema_i + 1]
+               .kwarg_only())) { // must be the last position argument
+        auto actual_type = args[used_args].value(graph)->type();
+        if (actual_type->kind() != TypeKind::ListType &&
+            !convertibleToList(
+                actual_type,
+                unwrapOptional(arg.type()))) { // and the actual should not be a list already
+          auto elem_type = unwrapOptional(arg.type())->expect<ListType>()->getElementType();
+          Value* list = tryCreateList(
+              elem_type,
+              graph,
+              loc,
+              at::ArrayRef<NamedValue>(args).slice(used_args),
+              err,
+              allow_conversions,
+              type_env);
+          if (!list)
+            return c10::nullopt;
+          used_args = args.size();
+          positional_inputs.push_back(list);
+          continue;
+        }
+      }
+
+      v = args[used_args];
+      used_args++;
+    } else if (auto idx = findInputWithName(arg.name(), kwargs)) {
+      const NamedValue& nv = kwargs[*idx];
+      if (used_kwarg[*idx]) {
+        err() << "argument " << nv.name()
+              << " specified twice in schema, submit a bug report!\n"
+              << nv.locOr(loc);
+        return c10::nullopt;
+      }
+      used_kwarg[*idx] = true;
+      v = nv;
+    } else if (arg.default_value()) {
+      v = NamedValue(*arg.default_value());
+    } else {
+      err() << "argument " << schema.arguments()[schema_i].name()
+            << " not provided.\n"
+            << loc;
+      return c10::nullopt;
+    }
+    Value* positional = tryMatchArgument(
+        arg, graph, loc, *v, err, allow_conversions, type_env);
+    if (!positional)
+      return c10::nullopt;
+    positional_inputs.push_back(positional);
+  }
+  // check for unused self argument
+  if(self != c10::nullopt) {
+    err() << "provided self argument not used in schema\n";
+  }
+
+  if (schema.is_vararg()) {
+    for(;used_args < args.size(); ++used_args) {
+      positional_inputs.push_back(args[used_args].value(graph));
+    }
+  }
+
+  // check for unused positional arguments
+  if (used_args < args.size()) {
+    err() << "expected at most " << used_args << " arguments "
+          << "but found " << args.size() << " positional arguments.\n"
+          << loc << "\n";
+    return c10::nullopt;
+  }
+  // check for unused kwargs
+  for (size_t i = 0; i < kwargs.size(); ++i) {
+    const auto& nv = kwargs[i];
+    if (!used_kwarg[i]) {
+      if (!schema.argumentIndexWithName(nv.name())) {
+        err() << "keyword argument " << nv.name() << " unknown\n";
+      } else {
+        err() << "keyword argument " << nv.name() << " specified twice\n";
+      }
+      return c10::nullopt;
+    }
+  }
+  auto return_types = fmap(schema.returns(), [&](const Argument& r) {
+    return evalTypeVariables(r.type(), type_env);
+  });
+  return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
+}
+
+
+// pack outputs of a function following python rules. If there is a single value return
+// a SimpleValue, otherwise pack all the values into a Tuple.
+Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
+  if(values.size() == 1) {
+    return values[0];
+  }
+  return g.insertNode(g.createTuple(values))->output();
+}
+
+// Given a successful match between operator schema and symbol, emit a node
+// with the appropriate inputs and outputs.
+static Value* emitBuiltinNode(
+    const MatchedSchema& matched_schema,
+    const SourceRange& loc,
+    Graph& graph,
+    Symbol name) {
+  auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
+                ->setSourceLocation(std::make_shared<SourceRange>(loc));
+
+  for(auto & ret : matched_schema.return_types) {
+    n->addOutput()->setType(ret);
+  }
+
+  // assert that we did indeed create an op that has implementation
+  // otherwise schema and dispatch are not in sync
+  getOperation(n);
+
+  return packOutputs(graph, n->outputs());
+}
+
+static std::string prefixLine(const std::string& str, const std::string& prefix) {
+  std::stringstream ss;
+  bool was_newline = true;
+  for(auto c : str) {
+    if(was_newline)
+      ss << prefix;
+    ss.put(c);
+    was_newline = c == '\n';
+  }
+  return ss.str();
+}
+
+// Search for operators matching the provided symbol name and input types.
+// If one is found, emit a node to the graph for that operator.
+Value* emitBuiltinCall(
+  const SourceRange& loc,
+  Graph& graph,
+  Symbol name,
+  const c10::optional<NamedValue>& self,
+  at::ArrayRef<NamedValue> inputs,
+  at::ArrayRef<NamedValue> attributes,
+  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
+  // otherwise it will return nullptr if the builtin is not found.
+  bool required) {
+
+
+  const auto& variants = getAllOperatorsFor(name);
+  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
+
+  std::stringstream failure_messages;
+  //first we try to match the schema without any conversion
+  //if no schema matches then insert ImplicitTensorToNum
+  for (bool allow_conversions : {false, true}) {
+    // clear previous error messages
+    failure_messages.str("");
+    for (const std::shared_ptr<Operator>& op : variants) {
+      const auto matched_schema = tryMatchSchema(
+          op->schema(),
+          loc,
+          graph,
+          self,
+          inputs,
+          attributes,
+          failure_messages,
+          allow_conversions);
+      if (matched_schema) {
+        return emitBuiltinNode(*matched_schema, loc, graph, name);
+      }
+    }
+    for (Method* method : builtin_functions) {
+      if (auto result = try_emit_call_to(
+              graph,
+              loc,
+              *method,
+              self,
+              inputs,
+              attributes,
+              failure_messages,
+              nullptr,
+              allow_conversions)) {
+        return result;
+      }
+    }
+  }
+
+  // none of the options worked
+  if (!required) {
+    return nullptr;
+  }
+  if(variants.size() == 0) {
+    throw ErrorReport(loc) << "unknown builtin op";
+  }
+  throw ErrorReport(loc) << "arguments for call are not valid:\n"
+                         << prefixLine(failure_messages.str(), "  ")
+                         << "for call at";
+}
+
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/schema_matching.h b/torch/csrc/jit/script/schema_matching.h
new file mode 100644 (file)
index 0000000..506a474
--- /dev/null
@@ -0,0 +1,58 @@
+#pragma once
+#include <torch/csrc/jit/type.h>
+#include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/function_schema.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+  // try to match a list if inputs and keyword 'attributes' to this schema,
+  // if it works return the flat list of positional inputs to the call
+  // if it returns nullopt, then failure_messages contains a good error report
+  // set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted to
+  // match the schema
+
+struct MatchedSchema {
+  std::vector<Value*> inputs;
+  std::vector<TypePtr> return_types;
+};
+
+TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
+  const FunctionSchema& schema,
+  const SourceRange& loc,
+  Graph& graph,
+  c10::optional<NamedValue> self,
+  at::ArrayRef<NamedValue> inputs,
+  at::ArrayRef<NamedValue> attributes,
+  std::ostream& failure_messages,
+  bool allow_conversions);
+
+TORCH_API Value* emitBuiltinCall(
+  const SourceRange& loc,
+  Graph& graph,
+  Symbol name,
+  const c10::optional<NamedValue>& self,
+  at::ArrayRef<NamedValue> inputs,
+  at::ArrayRef<NamedValue> attributes,
+  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
+  // otherwise it will return nullptr if the builtin is not found.
+  bool required);
+
+TORCH_API c10::optional<size_t> findInputWithName(
+  const std::string& name,
+  at::ArrayRef<NamedValue> kwargs);
+
+// applies implict conversion from value trying to turn it into type concrete_type
+// it succeeds if the return_value->isSubclassOf(concrete_type)
+TORCH_API Value* tryConvertToType(
+    const SourceRange& loc,
+    Graph& graph,
+    const TypePtr& concrete_type,
+    Value* value,
+    bool allow_conversions);
+
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp
new file mode 100644 (file)
index 0000000..418685b
--- /dev/null
@@ -0,0 +1,111 @@
+#include <torch/csrc/jit/script/type_parser.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/sugared_value.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+struct NoneValue : SugaredValue {
+  NoneValue() = default;
+  std::string kind() const override {
+    return "None";
+  }
+};
+
+std::shared_ptr<SugaredValue> PrintValue::call(
+  const SourceRange& loc,
+  Method & m,
+  at::ArrayRef<NamedValue> inputs,
+  at::ArrayRef<NamedValue> attributes,
+  size_t n_binders) {
+    auto& g = *m.graph();
+    if (!attributes.empty())
+      throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
+
+    //temporary hack to allow print statements to work in python 2, where
+    //print(a, b) is treated as a (a, b) tuple input.
+
+    std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
+    if(lowered_inputs.size() == 1 && lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
+      auto input = lowered_inputs[0];
+      for(size_t j = 0; j < input->node()->inputs().size(); ++j) {
+        lowered_inputs.insert(lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
+      }
+      lowered_inputs.erase(lowered_inputs.begin());
+    }
+    g.insertNode(g.create(prim::Print, lowered_inputs, 0)
+                     ->setSourceLocation(std::make_shared<SourceRange>(loc)));
+    return std::make_shared<NoneValue>();
+}
+
+static const std::unordered_map<std::string, std::string> &builtin_cast_methods() {
+  static std::unordered_map<std::string, std::string> builtin_cast_methods = {
+    {"byte", "_cast_Byte"},
+    {"char", "_cast_Char"},
+    {"double", "_cast_Double"},
+    {"float", "_cast_Float"},
+    {"int", "_cast_Int"},
+    {"long", "_cast_Long"},
+    {"short", "_cast_Short"},
+    {"half", "_cast_Half"}
+  };
+  return builtin_cast_methods;
+}
+
+// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
+// callable value that will resolve to foo(x, y, z) when called.
+std::shared_ptr<SugaredValue> SimpleValue::attr(const SourceRange& loc, Method & m, const std::string& field) {
+  // Allow method-style casts on Tensor types. e.g. x.int()
+  if (value->type()->isSubtypeOf(DynamicType::get())) {
+    if (builtin_cast_methods().count(field)) {
+      return std::make_shared<BuiltinFunction>(
+          Symbol::aten(builtin_cast_methods().at(field)),
+          NamedValue(loc, "self", value));
+    }
+    // functions that are just direct property lookups on tensor
+    // must be registered as prim::<name>(Tensor t) -> <return_type>
+    static const std::unordered_set<std::string> fields = {
+      "dtype",
+      "device",
+      "shape",
+      "is_cuda",
+      "requires_grad",
+    };
+    if (fields.count(field)) {
+      auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value});
+      return std::make_shared<SimpleValue>(r);
+    }
+  }
+  if (getValue()->type()->isSubtypeOf(NumberType::get())) {
+    throw ErrorReport(loc) << "Cannot call methods on numbers";
+  }
+  return std::make_shared<BuiltinFunction>(
+      Symbol::aten(field), NamedValue(loc, "self", value));
+}
+
+std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
+    const SourceRange& loc,
+    Method& m,
+    const c10::optional<size_t>& size_hint) {
+  static const auto make_simple_value = [](Value* v) -> std::shared_ptr<SugaredValue> {
+    return std::make_shared<SimpleValue>(v);
+  };
+  if(value->type()->kind() == TypeKind::TupleType) {
+    auto outputs = createTupleUnpack(value);
+    return fmap(outputs, make_simple_value);
+  } else if (value->type()->kind() == TypeKind::ListType) {
+    if (!size_hint) {
+      throw ErrorReport(loc) << "cannot statically infer the expected size of a list in this context";
+    }
+    auto graph = value->owningGraph();
+    Node *unpack = graph->insertNode(graph->createListUnpack(value, *size_hint));
+    return fmap(unpack->outputs(), make_simple_value);
+  }
+  throw ErrorReport(loc) << value->type()->str() << " cannot be used as a tuple";
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h
new file mode 100644 (file)
index 0000000..916792e
--- /dev/null
@@ -0,0 +1,260 @@
+#pragma once
+#include <functional>
+#include <memory>
+#include <string>
+
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/module.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// The AST can contain nodes like `self`, `self.b` or `python_fn` that
+// are not first-class values in the graph representation, but instead
+// will be desugared based on how they are used in the AST.
+
+// SugaredValue is used to temporarily represent these values in a way
+// that separates their behavior from the AST -> IR converter itself.
+// This allows us to keep dependencies on python minimal.
+
+enum NoneStatus {
+ ALWAYS,
+ MAYBE,
+ NEVER
+};
+
+struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
+  // what is this node? for error reporting (e.g. Module, python function)
+  virtual std::string kind() const = 0;
+
+  // what can we do with this thing?
+  // use it as a value e.g.  `this + 4`
+  virtual Value * asValue(const SourceRange& loc, Method & m) {
+    throw ErrorReport(loc) << kind() << " cannot be used as a value";
+  }
+
+  // select an attribute on it, e.g. `this.field`
+  virtual std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) {
+    throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
+  }
+  virtual NoneStatus isNone() {
+    return NEVER;
+  }
+
+  // use it as a vector of values, e.g. a tuple of values as return value from
+  // a method invocation
+  virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
+      const SourceRange& loc,
+      Method& m,
+      const c10::optional<size_t>& size_hint = {}) {
+    throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
+  }
+
+  // call it like a function, e.g. `outputs = this(inputs)`
+  virtual std::shared_ptr<SugaredValue> call(
+    const SourceRange& loc,
+    Method & m,
+    // note: names for args will be 'argument 0', 'argument 1', etc..
+    at::ArrayRef<NamedValue> inputs_,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+// n_binders is always set to the number of variables an expression is
+// syntactically bound to:
+//     a = foo() # 1 binder (note in this case the single binder might be a tuple)
+//     a, * b = foo() # 1 binder
+//     a, b = foo() # 2 binders
+//     foo() # 0 binders
+//
+// In subexpressions, like bar() in foo(bar()), n_binders is always set to
+// 1. n_binders is used as a hint to subexpressions to determine how many
+// values they should return when that number is ambiguous statically. In
+// particular it is currently used to decide how many tensors a call to a
+// python function will return. It is only a hint, functions do not have to
+// check that n_binders match the number of things they are returning, the
+// assignment logic will do that anyway.
+
+    throw ErrorReport(loc) << "cannot call a " << kind();
+  }
+
+  virtual ~SugaredValue() = default;
+};
+
+// most things in the environment are just simple value types
+// and not special python syntax sugar types
+struct TORCH_API SimpleValue : public SugaredValue {
+  SimpleValue(Value * value)
+  : value(value) {}
+  std::string kind() const override {
+    return "value";
+  }
+  Value * asValue(const SourceRange& range, Method & m) override {
+    return value;
+  }
+  NoneStatus isNone() override {
+    if (value->mustBeNone())
+      return ALWAYS;
+    else if (value->type()->cast<OptionalType>())
+      return MAYBE;
+    else
+      return NEVER;
+  }
+  std::vector<std::shared_ptr<SugaredValue>> asTuple(
+      const SourceRange& loc,
+      Method& m,
+      const c10::optional<size_t>& size_hint = {}) override;
+  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override;
+  Value* getValue() const {
+    return value;
+  }
+private:
+  Value* value;
+};
+
+struct TORCH_API BuiltinFunction : public SugaredValue {
+  BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self)
+      : symbol(symbol), self(std::move(self)) {}
+
+  // The symbol of the function (e.g. `aten::relu`).
+  Symbol symbol;
+
+  // if this is method, then this is the self argument.
+  c10::optional<NamedValue> self;
+
+  std::string kind() const override {
+    return "builtin";
+  }
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& m,
+      at::ArrayRef<NamedValue> attributes,
+      at::ArrayRef<NamedValue> inputs,
+      size_t n_binders) override;
+};
+
+struct TORCH_API BuiltinModule : public SugaredValue {
+  BuiltinModule(std::string name,
+                c10::optional<int64_t> version = at::nullopt)
+    : name(std::move(name))
+    , version(std::move(version)) {}
+
+  std::string kind() const override {
+    return "builtin module";
+  }
+  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
+    return std::make_shared<BuiltinFunction>(Symbol::fromQualString(name+"::"+field), c10::nullopt);
+  }
+
+private:
+  std::string name;
+  // when we add operator versioning, emit this op as it exising at 'version'
+  // if not set, use the latest version
+  c10::optional<int64_t> version;
+};
+
+// defines how a method obtained from a module behaves in script
+struct MethodValue : public SugaredValue {
+  MethodValue(std::shared_ptr<Module> module, Method& method)
+  : module(std::move(module)) //insurance that method stays alive
+  , method(method) {}
+  std::string kind() const override {
+    return "method";
+  }
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& caller,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
+    return std::make_shared<SimpleValue>(caller.emit_call_to(loc, method, inputs, attributes));
+  }
+
+ private:
+  std::shared_ptr<Module> module;
+  Method& method;
+
+};
+
+struct TORCH_API PrintValue : public SugaredValue {
+  std::string kind() const override {
+    return "print";
+  }
+  std::shared_ptr<SugaredValue> call(
+    const SourceRange& loc,
+    Method & m,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) override;
+};
+
+// expressions like int(x)
+// these are the same as call prim::Int or equivalent except it
+// is a noop when the input is a subtype of 'type'
+struct TORCH_API CastValue : public BuiltinFunction {
+  CastValue(TypePtr type, c10::Symbol method)
+  : BuiltinFunction(method, c10::nullopt)
+  , type_(std::move(type)) {}
+  std::shared_ptr<SugaredValue> call(
+    const SourceRange& loc,
+    Method & m,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) override {
+      if(inputs.size() == 1 && attributes.size() == 0) {
+        auto v = inputs[0].value(*m.graph());
+        if (v->type()->isSubtypeOf(type_)) {
+          return std::make_shared<SimpleValue>(v);
+        }
+      }
+      return BuiltinFunction::call(loc, m , inputs, attributes, n_binders);
+  }
+private:
+  TypePtr type_;
+};
+
+
+// These SugaredValues have special handling in the compiler because they
+// change the normal evalution order of the expression they participate in.
+// They are exposed here so that the python frontend can inject them
+// when it sees the equivalent thing in python
+
+struct TORCH_API ForkValue : public SugaredValue {
+  ForkValue() = default;
+  std::string kind() const override {
+    return "fork";
+  }
+};
+struct TORCH_API AnnotateValue : public SugaredValue {
+  AnnotateValue() = default;
+  std::string kind() const override {
+    return "annotate";
+  }
+};
+
+// matched against for special handling of getattr expressions
+struct TORCH_API GetAttrValue : SugaredValue {
+  GetAttrValue() = default;
+  std::string kind() const override {
+    return "getattr";
+  }
+};
+
+// matched against for special handling of isinstance expressions
+struct TORCH_API IsInstanceValue : SugaredValue {
+  IsInstanceValue() = default;
+  std::string kind() const override {
+    return "isinstance";
+  }
+};
+
+static inline std::vector<Value*> toValues(Graph& g, at::ArrayRef<NamedValue> nvs) {
+  return fmap(nvs, [&](const NamedValue& v) {
+    return v.value(g);
+  });
+}
+
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/type_parser.cpp b/torch/csrc/jit/script/type_parser.cpp
new file mode 100644 (file)
index 0000000..5b4ce12
--- /dev/null
@@ -0,0 +1,158 @@
+#include <torch/csrc/jit/script/type_parser.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/tree_views.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
+  static std::unordered_map<std::string, TypePtr> map = {
+    {"Tensor", DynamicType::get()},
+    {"int", IntType::get()},
+    {"float", FloatType::get()},
+    {"bool", BoolType::get()},
+    {"str", StringType::get()},
+    {"Device", DeviceObjType::get()},
+    // technically this is not a python type but we need it when
+    // parsing serialized methods that use implicit converions to Scalar
+    {"number", NumberType::get()},
+    {"None", NoneType::get()},
+  };
+  return map;
+}
+
+const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscript_to_type_fns() {
+  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>> map = {
+    {"Tuple", [](Subscript subscript) -> TypePtr {
+      std::vector<TypePtr> subscript_expr_types;
+      for (auto expr : subscript.subscript_exprs()) {
+        subscript_expr_types.push_back(parseTypeFromExpr(expr));
+      }
+      return TupleType::create(subscript_expr_types);
+    }},
+    {"List", [](Subscript subscript) -> TypePtr {
+      if (subscript.subscript_exprs().size() != 1) {
+        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
+      }
+      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+      return ListType::create(elem_type);
+    }},
+    {"Optional", [](Subscript subscript) -> TypePtr {
+      if (subscript.subscript_exprs().size() != 1) {
+        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
+      }
+      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+      return OptionalType::create(elem_type);
+    }},
+    {"Future", [](Subscript subscript) -> TypePtr {
+      if (subscript.subscript_exprs().size() != 1) {
+        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
+      }
+      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+      return FutureType::create(elem_type);
+    }},
+  };
+  return map;
+}
+
+bool isTorch(const Expr& expr) {
+  return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
+}
+
+
+
+c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr) {
+  if (expr.kind() != TK_SUBSCRIPT)
+    return c10::nullopt;
+  auto subscript = Subscript(expr);
+  if (subscript.value().kind() != TK_VAR)
+    return c10::nullopt;
+  auto var = Var(subscript.value());
+  auto subscript_exprs = subscript.subscript_exprs();
+
+  // handle the case where the BroadcastingList is wrapped in a Optional type
+  if(var.name().name() == "Optional") {
+    auto broadcast_list = handleBroadcastList(subscript_exprs[0]);
+    if (broadcast_list) {
+      TypePtr opt_type = OptionalType::create(broadcast_list->first);
+      return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
+    } else {
+      return c10::nullopt;
+    }
+  } else if (var.name().name().find("BroadcastingList") != 0) {
+    return c10::nullopt;
+  }
+
+  if (subscript_exprs.size() != 1)
+    throw ErrorReport(subscript.subscript_exprs().range())
+      << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
+
+  auto typ = subscript_exprs[0];
+  auto len = var.name().name().substr(strlen("BroadcastingList"));
+
+  if (typ.kind() != TK_VAR)
+    throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+
+  auto value_name = Var(typ).name().name();
+  if (value_name != "float" && value_name != "int")
+    throw ErrorReport(subscript.value().range()) << "Broadcastable lists only supported for int or float";
+
+  auto elem_ptr = ident_to_type_lut().find(value_name);
+  JIT_ASSERT(elem_ptr != ident_to_type_lut().end());
+  TypePtr list_ptr = ListType::create(elem_ptr->second);
+
+  const char* len_c = len.c_str();
+  char* end;
+  size_t len_v = strtoull(len_c, &end, 10);
+  if (end != len_c + len.size()) {
+    throw ErrorReport(subscript.subscript_exprs().range())
+        << "subscript of Broadcastable list must be a positive integer";
+  }
+  return std::pair<TypePtr, int32_t>(list_ptr, len_v);
+}
+
+// gets the base type name given namespaces where the types live
+// turns torch.Tensor -> Tensor, X -> X
+c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
+  switch (expr.kind()) {
+    case TK_VAR: {
+      return Var(expr).name().name();
+    }
+    case TK_NONE: {
+      return "None";
+    }
+    case '.': {
+      auto select = Select(expr);
+      const std::string& name = select.selector().name();
+      if (isTorch(select.value()) && name == "Tensor")
+        return "Tensor";
+    } break;
+  }
+  return at::nullopt;
+}
+
+TypePtr parseTypeFromExpr(const Expr& expr) {
+  if (expr.kind() == TK_SUBSCRIPT) {
+    auto subscript = Subscript(expr);
+    auto value_name = parseBaseTypeName(subscript.value());
+    if (!value_name) {
+      throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+    }
+    if (!subscript_to_type_fns().count(*value_name)) {
+      throw ErrorReport(subscript.range()) << "Unknown type constructor " << *value_name;
+    }
+    return subscript_to_type_fns().at(*value_name)(subscript);
+  } else if (auto name = parseBaseTypeName(expr)) {
+    auto itr = ident_to_type_lut().find(*name);
+    if (itr != ident_to_type_lut().end()) {
+      return itr->second;
+    }
+    throw ErrorReport(expr) << "Unknown type name " << *name;
+  }
+  throw ErrorReport(expr.range()) << "Expression of type " << kindToString(expr.kind())
+                                  << " cannot be used in a type expression";
+}
+} // namespace script
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/type_parser.h b/torch/csrc/jit/script/type_parser.h
new file mode 100644 (file)
index 0000000..b583405
--- /dev/null
@@ -0,0 +1,12 @@
+#pragma once
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+namespace torch {
+namespace jit {
+namespace script {
+struct Expr;
+TORCH_API c10::optional<std::string> parseBaseTypeName(const Expr& expr);
+TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr);
+}
+} // namespace jit
+} // namespace torch