Revert D14842057: Compiler uses first-class modules**
authorZachary DeVito <zdevito@fb.com>
Thu, 11 Apr 2019 13:14:21 +0000 (06:14 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 11 Apr 2019 13:17:01 +0000 (06:17 -0700)
Differential Revision:
D14842057

Original commit changeset: ca6e7b5a4380

fbshipit-source-id: e8f1862a59bf20d5f78648b2fdc53a8b3750ead3

32 files changed:
aten/src/ATen/core/function_schema.h
aten/src/ATen/core/jit_type.h
aten/src/ATen/core/type.cpp
test/cpp/jit/test.cpp
test/cpp/jit/test_misc.h
test/expect/TestScript.test_onnx_export_script_inline_params.expect
test/expect/TestScript.test_onnx_export_speculate-f2.expect
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/api/include/torch/jit.h
torch/csrc/api/src/jit.cpp
torch/csrc/jit/import_source.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/script/builtin_functions.cpp
torch/csrc/jit/script/builtin_functions.h
torch/csrc/jit/script/class_type.cpp
torch/csrc/jit/script/compilation_unit.h [deleted file]
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/schema_matching.cpp
torch/csrc/jit/script/slot.h
torch/csrc/jit/script/sugared_value.cpp
torch/csrc/jit/script/sugared_value.h
torch/csrc/jit/symbolic_script.cpp
torch/jit/__init__.py

index 888cfcb..c5185c1 100644 (file)
@@ -164,18 +164,6 @@ public:
     }
     return c10::nullopt;
   }
-  FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
-    return FunctionSchema(
-        name(),
-        overload_name(),
-        std::move(new_arguments),
-        returns(),
-        is_vararg(),
-        is_varret());
-  }
-  // Check that inputs have the correct types and appends any missing default
-  // values.
-  void checkAndNormalizeInputs(std::vector<IValue>& inputs) const;
 };
 
 inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
@@ -239,46 +227,4 @@ inline std::string toString(const FunctionSchema& schema) {
   return str.str();
 }
 
-inline void FunctionSchema::checkAndNormalizeInputs(std::vector<IValue>& inputs) const {
-  // Do we have more inputs than the schema accepts?
-  AT_CHECK(
-      inputs.size() <= arguments().size(),
-      "Expected at most ",
-      arguments().size(),
-      " argument(s) for operator '",
-      name(),
-      "', but received ",
-      inputs.size(),
-      " argument(s). Declaration: ",
-      *this);
-
-  for (size_t pos = 0; pos < arguments().size(); ++pos) {
-    const auto& argument = arguments()[pos];
-    if (pos < inputs.size()) {
-      if (!isSubvalueOf(inputs[pos], argument.type())) {
-        AT_ERROR(
-            "Expected value of type ",
-            *argument.type(),
-            " for argument '",
-            argument.name(),
-            "' in position ",
-            pos,
-            ", but instead got value of type ",
-            attemptToRecoverType(inputs[pos])->str(),
-            ". Declaration: ",
-            *this);
-      }
-    } else if (argument.default_value()) {
-      inputs.push_back(*argument.default_value());
-    } else {
-      AT_ERROR(
-          name(),
-          "() is missing value for argument '",
-          argument.name(),
-          "'. Declaration: ",
-          *this);
-    }
-  }
-}
-
 } // namespace c10
index 2399dcc..3656251 100644 (file)
@@ -17,8 +17,8 @@
 namespace torch {
 namespace jit {
 namespace script {
-struct CompilationUnit;
-struct Function;
+struct Module;
+struct Method;
 }
 } // namespace jit
 } // namespace torch
@@ -1100,19 +1100,19 @@ CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
 
 struct ClassType;
 using ClassTypePtr = std::shared_ptr<ClassType>;
-using ::torch::jit::script::CompilationUnit;
-using ::torch::jit::script::Function;
+using ::torch::jit::script::Module;
+using ::torch::jit::script::Method;
 
 // This represents a class in TorchScript.
 struct CAFFE2_API ClassType : public Type {
   // Create a user type and register it globally.
   static ClassTypePtr create(
       const std::string& name,
-      std::shared_ptr<CompilationUnit> module);
+      std::shared_ptr<Module> module);
 
   // Create a type representing a Module,
   // These do not have methods, and are not globally registered
-  static ClassTypePtr createModuleType(std::shared_ptr<CompilationUnit> module);
+  static ClassTypePtr createModuleType();
 
   // returns nullptr if there is no type with that name
   static ClassTypePtr get(const std::string& name);
@@ -1168,11 +1168,8 @@ struct CAFFE2_API ClassType : public Type {
     return attributeNames_[slot];
   }
 
-  Function* getMethod(const std::string& name) const;
-  CompilationUnit& compilation_unit();
-  const CompilationUnit& compilation_unit() const;
-  std::vector<Function*> methods() const;
-
+  Method* getMethod(const std::string& name) const;
+  std::vector<Method*> methods() const;
 
   const std::string& name() const {
     return typename_;
@@ -1229,10 +1226,10 @@ struct CAFFE2_API ClassType : public Type {
   static const TypeKind Kind = TypeKind::ClassType;
 
  private:
-  ClassType(std::string name, std::shared_ptr<CompilationUnit> cu)
+  ClassType(std::string name, std::shared_ptr<Module> module)
       : Type(TypeKind::ClassType),
         typename_(std::move(name)),
-        compilation_unit_(std::move(cu)) {}
+        module_(std::move(module)) {}
 
   // Name of type (note that this has to be globally unique).
   std::string typename_;
@@ -1246,7 +1243,7 @@ struct CAFFE2_API ClassType : public Type {
   std::vector<std::string> attributeNames_;
   std::vector<TypePtr> attributeTypes_;
   // Holds method attributes
-  std::shared_ptr<CompilationUnit> compilation_unit_;
+  std::shared_ptr<Module> module_;
 
 };
 } // namespace c10
index 534c9f3..e6c3628 100644 (file)
@@ -472,18 +472,18 @@ ClassTypeRegistry& getRegistry() {
 
 ClassTypePtr ClassType::create(
     const std::string& name,
-    std::shared_ptr<CompilationUnit> cu) {
-  auto ptr = ClassTypePtr(new ClassType(name, std::move(cu)));
+    std::shared_ptr<Module> module) {
+  auto ptr = ClassTypePtr(new ClassType(name, std::move(module)));
   getRegistry().registerType(name, ptr);
   return ptr;
 }
 
-ClassTypePtr ClassType::createModuleType(std::shared_ptr<CompilationUnit> cu) {
-  return ClassTypePtr(new ClassType("Module", std::move(cu)));
+ClassTypePtr ClassType::createModuleType() {
+  return ClassTypePtr(new ClassType("Module", nullptr));
 }
 
 ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
-  auto ptr = ClassTypePtr(new ClassType(typename_, compilation_unit_));
+  auto ptr = ClassTypePtr(new ClassType(typename_, module_));
   AT_ASSERT(numAttributes() == refined_slots.size());
   for(size_t i = 0; i < attributeNames_.size(); ++i) {
     AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
index 90b38f3..7145c1a 100644 (file)
@@ -65,8 +65,7 @@ namespace jit {
   _(NoneSchemaMatch)               \
   _(ClassParser)                   \
   _(PeepholeOptimize)              \
-  _(RecordFunction)                \
-  _(ModuleDefine)
+  _(RecordFunction)
 
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec)               \
index ea20266..8b93c21 100644 (file)
@@ -40,7 +40,6 @@
 #include "ATen/core/ivalue.h"
 #include "torch/csrc/jit/script/compiler.h"
 #include "torch/csrc/jit/script/module.h"
-#include "torch/jit.h"
 
 #include "onnx/onnx_pb.h"
 
@@ -370,10 +369,11 @@ static const auto cf_examples = R"JIT(
     return a
 )JIT";
 void testControlFlow() {
-  auto cu = compile(cf_examples);
-
+  auto cu = std::make_shared<script::Module>();
+  script::defineMethodsInModule(
+      cu, cf_examples, script::nativeResolver, c10::nullopt);
   auto run = [&](const std::string& name, std::vector<IValue> stack) {
-    auto graph = cu->get_function(name).graph();
+    auto graph = cu->get_method(name).graph();
     Code code(graph);
     InterpreterState interp(code);
     interp.run(stack);
@@ -576,11 +576,12 @@ void testTopologicalIndex() {
 }
 
 void invokeTestRecordFunction(at::Tensor& t) {
-  autograd::profiler::GetPackedInputsCallback inputs_cb = [t]() {
-    Stack st;
-    pack(st, t);
-    return st;
-  };
+  autograd::profiler::GetPackedInputsCallback inputs_cb =
+    [t]() {
+      Stack st;
+      pack(st, t);
+      return st;
+    };
   autograd::profiler::RecordFunction guard("test", inputs_cb);
   t.add_(torch::ones_like(t));
 }
@@ -604,15 +605,15 @@ void invokeTestRecordFunctionNested() {
 
 void testRecordFunction() {
   std::vector<std::vector<int64_t>> input_sizes;
-  autograd::profiler::pushCallback(
-      [&input_sizes](const autograd::profiler::RecordFunction& fn) {
-        for (const auto& input : fn.inputs()) {
-          if (input.isTensor()) {
-            std::vector<int64_t> t = input.toTensor().sizes().vec();
-            input_sizes.push_back(t);
-          }
-        }
-      });
+  autograd::profiler::pushCallback([&input_sizes](
+      const autograd::profiler::RecordFunction& fn) {
+    for (const auto& input : fn.inputs()) {
+      if (input.isTensor()) {
+        std::vector<int64_t> t = input.toTensor().sizes().vec();
+        input_sizes.push_back(t);
+      }
+    }
+  });
 
   auto t = torch::randn({1, 2, 3}, at::kCPU);
   invokeTestRecordFunction(t);
@@ -624,15 +625,14 @@ void testRecordFunction() {
 
   // test nested RecordFunctions
   std::vector<std::string> nested_names;
-  autograd::profiler::pushCallback(
-      [&nested_names](const autograd::profiler::RecordFunction& fn) {
-        nested_names.push_back(getFullName(&fn));
-      });
+  autograd::profiler::pushCallback([&nested_names](
+      const autograd::profiler::RecordFunction& fn) {
+    nested_names.push_back(getFullName(&fn));
+  });
 
   {
     autograd::profiler::RecordFunction guard("outer");
-    invokeTestRecordFunctionNested();
-    ;
+    invokeTestRecordFunctionNested();;
   }
 
   autograd::profiler::popCallback();
@@ -709,18 +709,6 @@ void testNoneSchemaMatch() {
   // checking that constant propagation ran wo/failure
   AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
 }
-
-void testModuleDefine() {
-  auto m = std::make_shared<script::Module>();
-  m->register_parameter("foo", torch::ones({}), false);
-  m->define(R"(
-    def add_it(self, x, b : int = 4):
-      return self.foo + x + b
-  )");
-  auto result = m->run_method("add_it", torch::ones({}));
-  AT_ASSERT(result.toTensor().item<float>() == 6)
-}
-
 } // namespace test
 } // namespace jit
 } // namespace torch
index 1cb1092..ffa284a 100644 (file)
@@ -5,14 +5,14 @@ ModelProto {
   graph:
     GraphProto {
       name: "torch-jit-export"
-      inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 4},{name: "2", type:Tensor dims: 3 3}]
+      inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}]
       outputs: [{name: "6", type:Tensor dims: 2 4}]
-      initializers: [TensorProto shape: [3 4],TensorProto shape: [3 3]]
+      initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]]
       nodes: [
         Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-        Node {type: "Gemm", inputs: [x,2,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]},
+        Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]},
         Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-        Node {type: "Gemm", inputs: [4,1,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}
+        Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}
       ]
     }
   opset_import: [OperatorSetIdProto { domain: }],
index 29ce206..3126f1d 100644 (file)
@@ -5,9 +5,9 @@ ModelProto {
   graph:
     GraphProto {
       name: "torch-jit-export"
-      inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20},{name: "2", type:Tensor dims: 20 10}]
+      inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
       outputs: [{name: "8", type:Tensor dims: 1 20}]
-      initializers: [TensorProto shape: [20],TensorProto shape: [20 10]]
+      initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
       nodes: [
         Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
         Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
@@ -28,7 +28,7 @@ ModelProto {
                       outputs: [{name: "10", type:Tensor dims: 1 20}]
                       initializers: []
                       nodes: [
-                        Node {type: "Gemm", inputs: [3,2,1], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                        Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
                       ]
                     }
 
@@ -39,7 +39,7 @@ ModelProto {
                       outputs: [{name: "11", type:Tensor dims: 1 20}]
                       initializers: []
                       nodes: [
-                        Node {type: "Gemm", inputs: [3,2,1], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                        Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
                       ]
                     }
 
@@ -54,7 +54,7 @@ ModelProto {
               outputs: [{name: "12", type:Tensor dims: 1 20}]
               initializers: []
               nodes: [
-                Node {type: "Gemm", inputs: [3,2,1], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
               ]
             }
 
index 18ca724..787dee4 100644 (file)
@@ -7478,7 +7478,7 @@ a")
             def foo(self, input):
                 self.call_foo(input)
 
-        with self.assertRaisesRegex(RuntimeError, 'called recursively'):
+        with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
             M()
 
     def test_script_kwargs_fn_call(self):
index ff27ce3..89a5ed8 100644 (file)
@@ -95,7 +95,6 @@ libtorch_sources = [
     "torch/csrc/jit/register_quantized_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
-    "torch/csrc/api/src/jit.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
     "torch/csrc/jit/script/logging.cpp",
     "torch/csrc/jit/script/final_returns.cpp",
index 4b2281b..60f883a 100644 (file)
@@ -175,7 +175,6 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
-  ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
   ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
@@ -237,6 +236,7 @@ if (NOT NO_API)
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp
+    ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
@@ -528,6 +528,7 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/jit/python_tracer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/init.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp
     ${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp
     ${TORCH_SRC_DIR}/csrc/nn/THNN.cpp
index 7e2e4c9..9814ead 100644 (file)
@@ -32,7 +32,7 @@ namespace jit {
 ///   )JIT");
 ///   IValue output = module->run_method("relu_script", a, b);
 /// \endrst
-TORCH_API std::shared_ptr<script::CompilationUnit> compile(const std::string& source);
+TORCH_API std::shared_ptr<script::Module> compile(const std::string& source);
 
 } // namespace jit
 } // namespace torch
index a66e947..29ea39f 100644 (file)
@@ -9,9 +9,10 @@
 namespace torch {
 namespace jit {
 
-std::shared_ptr<script::CompilationUnit> compile(const std::string& source) {
-  auto module = std::make_shared<script::CompilationUnit>();
-  module->define(source, script::nativeResolver, nullptr);
+std::shared_ptr<script::Module> compile(const std::string& source) {
+  auto module = std::make_shared<script::Module>();
+  defineMethodsInModule(
+      module, source, script::nativeResolver, /*self=*/c10::nullopt);
   return module;
 }
 
index e27988d..6a74810 100644 (file)
@@ -6,6 +6,39 @@ namespace torch {
 namespace jit {
 namespace script {
 
+// this is a much simpler accessor that only handles modules, parameters, and
+// and methods. It does not depend on python to work.
+struct ModuleAccessorValue : public SugaredValue {
+  ModuleAccessorValue(std::shared_ptr<Module> module)
+      : module(std::move(module)) {}
+  std::string kind() const override {
+    return "module";
+  }
+  // select an attribute on it, e.g. `this.field`
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) override {
+    if (std::shared_ptr<Module> v = module->find_module(field)) {
+      return std::make_shared<ModuleAccessorValue>(std::move(v));
+    } else if (script::Slot* v = module->find_parameter(field)) {
+      return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+    } else if (script::Slot* v = module->find_buffer(field)) {
+      return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+    } else if (script::Slot* v = module->find_attribute(field)) {
+      return std::make_shared<script::SimpleValue>(
+          m.get_or_add_attribute(*v));
+    } else if (Method* m = module->find_method(field)) {
+      return std::make_shared<MethodValue>(shared_from_this(), *m);
+    } else {
+      throw ErrorReport(loc) << "unknown attr: " << field;
+    }
+  }
+
+ private:
+  std::shared_ptr<Module> module;
+};
+
 struct OpsValue : public SugaredValue {
   OpsValue(size_t version) : version_(version) {}
   std::string kind() const override {
@@ -13,7 +46,7 @@ struct OpsValue : public SugaredValue {
   }
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     return std::make_shared<BuiltinModule>(field, version_);
   }
@@ -26,7 +59,7 @@ struct ConstantValue : public SugaredValue {
   std::string kind() const override {
     return "constant";
   }
-  Value* asValue(const SourceRange& loc, Function& m) override {
+  Value* asValue(const SourceRange& loc, Method& m) override {
     return m.graph()->insertConstant(value_);
   }
 };
@@ -42,7 +75,7 @@ struct ConstantTableValue : public SugaredValue {
   // select an attribute on it, e.g. `this.field`
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     const char* field_s = field.c_str();
     char* end;
@@ -84,7 +117,7 @@ struct SourceImporter {
     };
 
     resolver_ = [&](const std::string& name,
-                    Function& m,
+                    Method& m,
                     const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
       auto it = env_.find(name);
       if (it == env_.end()) {
@@ -100,7 +133,7 @@ struct SourceImporter {
   const std::vector<at::Tensor>& constant_table_;
   std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
   std::function<std::shared_ptr<
-      SugaredValue>(const std::string& name, Function& m, const SourceRange& loc)>
+      SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>
       resolver_;
 
   size_t parseVersionNumber() {
@@ -134,11 +167,8 @@ void import_methods(
     definitions.emplace_back(def);
     resolvers.emplace_back(importer.resolver_);
   }
-  auto self = [&](Value* v) {
-    v->setType(mod->module_object()->type());
-    return std::make_shared<SimpleValue>(v);
-  };
-  mod->module_object()->type()->compilation_unit().define(definitions, resolvers, self);
+  auto self = std::make_shared<ModuleAccessorValue>(mod);
+  defineMethodsInModule(mod, definitions, resolvers, Self(self));
 }
 
 void import_libs(
@@ -156,13 +186,9 @@ void import_libs(
       resolvers.emplace_back(importer.resolver_);
     }
 
-    auto cu = std::make_shared<CompilationUnit>();
-    auto class_type = ClassType::create(class_def.name().name(), cu);
-    auto self = [&](Value* v) {
-      v->setType(class_type);
-      return std::make_shared<SimpleValue>(v);
-    };
-    cu->define(definitions, resolvers, self);
+    auto mod = std::make_shared<Module>();
+    Self self(ClassType::create(class_def.name().name(), mod));
+    defineMethodsInModule(mod, definitions, resolvers, self);
   }
 }
 
index 42e7267..ce43b5b 100644 (file)
@@ -1074,10 +1074,6 @@ struct Graph {
       const std::string& field,
       Value* newValue);
   TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
-  TORCH_API Value* insertGetAttr(Value* obj, const std::string& field) {
-    return insertNode(createGetAttr(obj, field))->output();
-  }
-
   // Note: defined in python_ir.cpp and can be used only in python extension
   Node* createPythonOp(
       THPObjectPtr&& pyobj,
index e7d5f94..a8ae6d3 100644 (file)
@@ -267,9 +267,10 @@ struct GraphFuser {
             norm_invstd = 1 / (eps + torch.sqrt(norm_var))
             return ((input - norm_mean) * norm_invstd)
       )SCRIPT";
-          script::CompilationUnit cu;
-          cu.define(source, script::nativeResolver, nullptr);
-          *graph_ptr = cu.get_function("batch_norm").graph();
+          auto module = std::make_shared<script::Module>();
+          defineMethodsInModule(
+              module, source, script::nativeResolver, /*self=*/c10::nullopt);
+          *graph_ptr = module->get_method("batch_norm").graph();
         },
         &bn_graph);
 
index f7bfd99..b157cc6 100644 (file)
@@ -1133,16 +1133,6 @@ struct PythonPrintPass {
         [](const Argument& arg) { return arg.default_value(); });
     printFunction(graph, name, is_class, defaults, ivalue_names);
   }
-  void printFunction(
-      script::Function& method,
-      bool is_class) {
-    const std::string& name = method.name();
-    Graph& graph = *method.graph();
-    auto defaults = fmap(
-        method.getSchema().arguments(),
-        [](const Argument& arg) { return arg.default_value(); });
-    printFunction(graph, name, is_class, defaults, {});
-  }
   void printModule(script::Module& module) {
     std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
     createTensorToParameterNameMap(
@@ -1163,8 +1153,9 @@ struct PythonPrintPass {
     out << "class " << classType->name() << ":\n";
     {
       const auto guard = WithIndented();
+      std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
       for (auto& method : classType->methods()) {
-        printFunction(*method, /*is_class=*/true);
+        printMethod(*method, /*is_class=*/true, extra_ivalue_names);
       }
     }
   }
index 1a40aa4..7652dc0 100644 (file)
@@ -137,7 +137,6 @@ void ConcretePythonOp::cloneFrom(Node* other_) {
   this->cconv = other->cconv;
   Py_INCREF(other->pyobj.get());
   this->pyobj = THPObjectPtr(other->pyobj.get());
-  this->ignore_on_export = other->ignore_on_export;
   for (auto& sa : other->scalar_args) {
     Py_INCREF(sa.get());
     this->scalar_args.emplace_back(sa.get());
index 2ee7730..a1ed46c 100644 (file)
@@ -37,8 +37,8 @@ def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
 )SCRIPT");
 
 struct BuiltinFunctionRegistry {
-  const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
-    const static std::vector<Function*> empty;
+  const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
+    const static std::vector<Method*> empty;
     // when initializing the builtin function library, we will re-enter
     // getAllBuiltinFunctionsFor since it is called in the compiler to
     // lookup builtins and initializing the builtin functions calls the
@@ -62,10 +62,11 @@ struct BuiltinFunctionRegistry {
 
  private:
   void loadSource(const std::string& source) {
-    std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
-    modules.emplace_back(cu);
-    cu->define(source, script::nativeResolver, /*self=*/nullptr);
-    for (auto& method : cu->get_functions()) {
+    auto module = std::make_shared<script::Module>();
+    defineMethodsInModule(
+        module, source, script::nativeResolver, /*self=*/c10::nullopt);
+    modules.push_back(module);
+    for (auto& method : module->get_methods()) {
       builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
           .push_back(method.get());
     }
@@ -96,11 +97,11 @@ struct BuiltinFunctionRegistry {
   }
   enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
   std::recursive_mutex mutex;
-  std::vector<std::shared_ptr<CompilationUnit>> modules;
-  std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
+  std::vector<std::shared_ptr<Module>> modules;
+  std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
 };
 
-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
+TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
   static BuiltinFunctionRegistry registry;
   return registry.getAllBuiltinFunctionsFor(name);
 }
index f1a5f22..42e15e7 100644 (file)
@@ -7,7 +7,7 @@ namespace torch {
 namespace jit {
 namespace script {
 
-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
+TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name);
 
 }
 } // namespace jit
index 0841e80..f669e03 100644 (file)
@@ -5,20 +5,13 @@ namespace c10 {
 
 // This file exists because we need to reference module.h, which we can't from
 // c10. Sigh...
-Function* ClassType::getMethod(const std::string& name) const {
-  return compilation_unit_->find_function(name).get();
+Method* ClassType::getMethod(const std::string& name) const {
+  return module_? module_->find_method(name) : nullptr;
 }
 
-CompilationUnit& ClassType::compilation_unit() {
-  return *compilation_unit_;
-}
-const CompilationUnit& ClassType::compilation_unit() const {
-  return *compilation_unit_;
-}
-
-std::vector<Function*> ClassType::methods() const {
-  std::vector<Function*> ret;
-  for (const auto& pr : compilation_unit().get_functions()) {
+std::vector<Method*> ClassType::methods() const {
+  std::vector<Method*> ret;
+  for (const auto& pr : module_->get_methods()) {
     ret.push_back(pr.get());
   }
   return ret;
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
deleted file mode 100644 (file)
index 790d061..0000000
+++ /dev/null
@@ -1,285 +0,0 @@
-#pragma once
-#include <c10/util/Exception.h>
-#include <torch/csrc/jit/graph_executor.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/source_range.h>
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/memory.h>
-
-#include <ATen/core/function_schema.h>
-#include <c10/util/ArrayRef.h>
-#include <c10/util/Optional.h>
-
-#include <functional>
-#include <memory>
-#include <mutex>
-#include <ostream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-namespace torch {
-namespace jit {
-
-namespace script {
-
-struct Def;
-struct SugaredValue;
-struct Function;
-
-using Resolver = std::function<std::shared_ptr<SugaredValue>(
-    const std::string& name,
-    Function& f,
-    const SourceRange& loc)>;
-using Self = std::function<std::shared_ptr<SugaredValue>(Value*)>;
-
-// A Function is a pure Graph with no implicit `self` object bound.
-// It contains schema information, and the executor that manages the
-// execution of the function. script::Method is a wrapper around a
-// underlying Function that also provides a `self` object.
-struct TORCH_API Function {
-  Function(
-      std::string name,
-      bool optimize,
-      std::shared_ptr<Graph> graph,
-      std::function<void(Function&)> function_creator)
-      : name_(std::move(name)),
-        graph_(std::move(graph)),
-        optimize_(optimize),
-        function_creator_(std::move(function_creator)) {}
-
-  void run(Stack& stack) {
-    get_executor().run(stack);
-  }
-
-  void run(Stack&& stack) {
-    run(stack);
-  }
-
-  IValue operator()(std::vector<IValue> stack) {
-    getSchema().checkAndNormalizeInputs(stack);
-    run(stack);
-    return stack.front();
-  }
-
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    return get_executor().graphFor(inputs);
-  }
-
-  std::shared_ptr<Graph> graph() const {
-    return graph_;
-  }
-
-  const std::string& name() const {
-    return name_;
-  }
-
-  // if this isn't yet defined, run its method_creator function
-  void ensure_defined();
-
-  size_t num_inputs() const {
-    return graph()->inputs().size();
-  }
-
-  Function& setSchema(FunctionSchema schema) {
-    schema_ = make_unique<FunctionSchema>(std::move(schema));
-    return *this;
-  }
-
-  const FunctionSchema& getSchema() const {
-    if (schema_ == nullptr) {
-      schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
-    }
-    return *schema_;
-  }
-
-  std::string pretty_print_schema() const {
-    AT_ASSERT(schema_);
-    std::stringstream ss;
-    ss << *schema_;
-    return ss.str();
-  }
-
-  GraphExecutorState getDebugState() {
-    return get_executor().getDebugState();
-  }
-
-  void debugDisableAutodiffSubgraphInlining() {
-    return get_executor().debugDisableAutodiffSubgraphInlining();
-  }
-
-  bool is_optimized() const {
-    return optimize_;
-  }
-
-  void check_single_output() {
-    AT_CHECK(
-        graph()->outputs().size() == 1,
-        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
-  }
-
-  GraphExecutor& get_executor() {
-    std::call_once(executor_init_, [&] {
-      check_single_output();
-      executor_ = GraphExecutor(graph(), optimize_);
-    });
-    return executor_;
-  }
-
-  // returns nullptr and fills in failure_messages if the callee does not
-  // match the functions schema
-
-  // TODO: defined in module.cpp, move to compilation_unit.cpp
-  Value* try_emit_call(
-      Graph& graph,
-      const SourceRange& loc,
-      c10::optional<NamedValue> self,
-      ArrayRef<NamedValue> args,
-      ArrayRef<NamedValue> kwargs,
-      std::stringstream& failure_messages,
-      bool conv_tensors_to_nums);
-
-  Value* emit_call(
-      Graph& graph,
-      const SourceRange& loc,
-      ArrayRef<NamedValue> args,
-      ArrayRef<NamedValue> kwargs);
-
- private:
-  static FunctionSchema defaultSchemaFor(const Function& function) {
-    std::vector<Argument> args;
-    std::vector<Argument> returns;
-    Graph& g = *function.graph();
-    size_t num_inputs = function.num_inputs();
-    for (size_t i = 0; i < num_inputs; ++i) {
-      const Value* v = g.inputs().at(i);
-      std::string name = v->hasUniqueName() ? v->uniqueNameBase()
-                                            : ("argument_" + std::to_string(i));
-      args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
-    }
-    for (size_t i = 0; i < g.outputs().size(); ++i) {
-      returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
-    }
-    return {function.name(), "", std::move(args), std::move(returns)};
-  }
-
-  std::string name_;
-  std::shared_ptr<Graph> graph_; // for debugging and for inlining
-  bool optimize_;
-
-  GraphExecutor executor_; // for execution
-
-  std::once_flag executor_init_;
-
-  // an optional function that actually creates the method when
-  // emit_call_to(this,...) is first called. this is used by the compiler so
-  // that it can construct methods out of order
-  std::function<void(Function&)> function_creator_;
-
-  // if absent, then we generate a default schema based on the graph
-  // mutable because getSchema caches the default schema if one is requested
-  // before a call to setSchema
-  mutable std::unique_ptr<FunctionSchema> schema_;
-};
-
-
-// A CompilationUnit is a list of named script::Functions
-// with helper methods to iterate the list, or invoke the function.
-// Classes have a CompilationUnit holding the class methods
-// and Modules also have a CompilationUnit holding the Functions that
-// are used to implement their Methods
-
-struct TORCH_API CompilationUnit {
-  std::shared_ptr<Function> find_function(const std::string& name) const {
-    auto it = dict_.find(name);
-    if (it == dict_.end())
-      return nullptr;
-    return functions_[it->second];
-  }
-
-  Function& get_function(const std::string& name) const {
-    if (auto r = find_function(name))
-      return *r;
-    AT_ERROR("attempted to get undefined function ", name);
-  }
-
-  void set_optimized(bool o) {
-    optimized_ = o;
-  }
-
-  bool is_optimized() const {
-    return optimized_;
-  }
-
-  // for historic reasons, these are defined in compiler.cpp
-  void define(
-      const std::vector<Def>& definitions,
-      const std::vector<Resolver>& resolvers, /* determines how we handle free
-                                                 variables in each definition*/
-      // if non-null, the first argument to each def, is bound to this value
-      const Self& self);
-
-  // same as above but parse the definitions from source
-  void define(
-      const std::string& source,
-      const Resolver& resolver,
-      const Self& self);
-
-  void clone_function(const Function& remote) {
-    create_function(remote.name(), remote.graph()->copy());
-  }
-
-  Function& create_function(std::string name, std::shared_ptr<Graph> graph) {
-    auto fn = std::make_shared<Function>(
-        std::move(name), is_optimized(), std::move(graph), nullptr);
-    return register_function(std::move(fn));
-  }
-
-  const std::vector<std::shared_ptr<Function>>& get_functions() const {
-    return functions_;
-  }
-
-  /// Run a method from this compilation.
-  ///
-  /// For example:
-  /// @code
-  ///   IValue output = module->run("relu_script", a, b);
-  /// @endcode
-  ///
-  /// To get a compile a module from a source string, see torch::jit::compile
-  ///
-  /// @param method_name The name of the method to run
-  /// @param args Arguments to be passed to the method
-  /// @return An IValue containing the return value (or values if it is a tuple)
-  /// from the method
-  template <typename... Types>
-  IValue run_method(const std::string& method_name, Types&&... args) {
-    return get_function(method_name)({IValue(std::forward<Types>(args))...});
-  }
-
-  void drop_all_functions() {
-    dict_.clear();
-    functions_.clear();
-  }
-
- private:
-  Function& register_function(std::shared_ptr<Function> fn) {
-    AT_CHECK(
-        0 == dict_.count(fn->name()),
-        "method '",
-        fn->name(),
-        "' already defined.");
-    functions_.emplace_back(std::move(fn));
-    dict_[functions_.back()->name()] = functions_.size() - 1;
-    return *functions_.back();
-  }
-  std::vector<std::shared_ptr<Function>> functions_;
-  // for fast lookup
-  std::unordered_map<std::string, size_t> dict_;
-  bool optimized_ = true;
-};
-
-} // namespace script
-} // namespace jit
-} // namespace torch
index 6048ed6..11358e2 100644 (file)
@@ -25,7 +25,7 @@ namespace jit {
 namespace script {
 
 using SugaredValuePtr = std::shared_ptr<SugaredValue>;
-using FunctionTable = std::unordered_map<std::string, Function&>;
+using FunctionTable = std::unordered_map<std::string, Method&>;
 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
 using AttributeMap = std::unordered_map<std::string, Const>;
 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
@@ -190,7 +190,7 @@ static bool meaningfulName(const std::string& name) {
 //      delete unnecessary ones later with replaceAllusesWith().
 struct Environment {
   Environment(
-      Function& method,
+      Method& method,
       Resolver resolver,
       Block* b,
       std::shared_ptr<Environment> next = nullptr)
@@ -199,7 +199,7 @@ struct Environment {
         b(b),
         next(std::move(next)) {}
 
-  Function& method;
+  Method& method;
   Resolver resolver;
   std::vector<std::string> captured_inputs;
   std::unordered_map<std::string, std::string> error_messages;
@@ -518,8 +518,8 @@ struct to_ir {
   to_ir(
       const Def& def,
       Resolver resolver_,
-      const Self& self,
-      Function& method) // method being constructed
+      const c10::optional<Self>& self,
+      Method& method) // method being constructed
       : method(method),
         graph(method.graph()),
         resolver(std::move(resolver_)),
@@ -541,7 +541,7 @@ struct to_ir {
   }
 
  private:
-  Function& method;
+  Method& method;
   std::shared_ptr<Graph> graph;
   Resolver resolver;
   std::unordered_map<int64_t, Value*> integral_constants;
@@ -577,7 +577,7 @@ struct to_ir {
 
   FunctionSchema emitDef(
       const Def& def,
-      const Self& self,
+      const c10::optional<Self>& self,
       Block* block) {
     auto schema = extractSchemaFromDef(def, self);
     // TODO need guards on init returning none
@@ -624,16 +624,15 @@ struct to_ir {
         blank_decl,
         List<Stmt>::create(r, {ret}));
     auto m = std::make_shared<Module>();
-    CompilationUnit cu;
-    cu.define({def}, {resolver}, nullptr);
+    defineMethodsInModule(m, {def}, {resolver}, c10::nullopt);
     Stack stack;
-    cu.get_function("defaults").run(stack);
+    m->get_method("defaults").run(stack);
     return stack.at(0).toTuple()->elements();
   }
 
   std::vector<Argument> parseArgsFromDecl(
       const Decl& decl,
-      const Self& self) {
+      const c10::optional<Self>& self) {
     auto params_begin = decl.params().begin();
     auto params_end = decl.params().end();
     if (self) {
@@ -707,7 +706,7 @@ struct to_ir {
   }
   FunctionSchema extractSchemaFromDef(
       const Def& def,
-      const Self& self) {
+      const c10::optional<Self>& self) {
     const auto name = def.name().name();
     std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
     std::vector<Argument> returns = parseReturnFromDecl(def.decl());
@@ -717,10 +716,9 @@ struct to_ir {
 
   std::vector<Argument> emitFormalArguments(
       const Def& def,
-      const Self& self,
+      const c10::optional<Self>& self,
       const FunctionSchema& schema,
       Block* block) {
-
     std::vector<Argument> arguments; // for schema
     // inputs
     auto it = def.decl().params().begin();
@@ -740,9 +738,14 @@ struct to_ir {
     if (self) {
       AT_ASSERT(it != end);
       const auto& name = (*it).ident().name();
-      Value* new_input = block->addInput()->setUniqueName(name);
-      environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input));
-      arguments.emplace_back(name, new_input->type());
+      if (auto type = self->asFirstClass()) {
+        Value* new_input =
+            block->addInput()->setUniqueName(name)->setType(type);
+        environment_stack->setVar((*it).ident().range(), name, new_input);
+        arguments.emplace_back(name, type);
+      } else {
+        environment_stack->setSugaredVar(def.range(), name, self->asSugared());
+      }
       ++it;
     }
     size_t arg_annotation_idx = 0;
@@ -828,7 +831,7 @@ struct to_ir {
       pushFrame(block, /*starts_def=*/true);
       emitDef(
           def,
-          nullptr,
+          c10::nullopt,
           block); // ignore schema return, we just wont use it for now since we
                   // never create a Method for the closure
       popFrame(/*ends_def=*/true);
@@ -2260,6 +2263,7 @@ struct to_ir {
       node_output = fork_node->output()->setType(
           FutureType::create(fn_simple_output->type()));
     }
+
     // Lambda lift block(0) into attr::Subgraph
     lambdaLiftFork(fork_node);
 
@@ -2751,14 +2755,15 @@ struct to_ir {
   }
 };
 
-void CompilationUnit::define(
+void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
     const std::vector<Def>& definitions,
     const std::vector<Resolver>& resolvers,
-    const Self& self) {
+    const c10::optional<Self>& self) {
   AT_ASSERT(definitions.size() == resolvers.size());
   auto resolver_it = resolvers.begin();
-  std::vector<Function*> methods;
-  std::unordered_map<std::string, Function*> function_table;
+  std::vector<Method*> methods;
+  std::unordered_map<std::string, Method*> function_table;
   for (const Def& def : definitions) {
     const std::string& name = def.name().name();
     auto resolver = *resolver_it++;
@@ -2769,34 +2774,37 @@ void CompilationUnit::define(
       // the function table so the methods can see each other
       resolver = [resolver, &function_table](
                      const std::string& name,
-                     Function& m,
+                     Method& m,
                      const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
         auto it = function_table.find(name);
         if (it != function_table.end()) {
-          return std::make_shared<MethodValue>(c10::nullopt, *it->second);
+          return std::make_shared<MethodValue>(nullptr, *it->second);
         }
         return resolver(name, m, loc);
       };
     }
-    auto creator = [def, resolver, self](Function& method) {
+    auto creator = [def, resolver, self](Method& method) {
       AT_ASSERT(resolver);
       to_ir(def, resolver, self, method);
     };
-    std::unique_ptr<Function> fn(
-        new Function(name, is_optimized(), std::make_shared<Graph>(), creator));
-    function_table[name] = fn.get();
-    methods.push_back(fn.get());
-    register_function(std::move(fn));
+    Method& method = m->create_method(name, creator);
+    function_table[name] = &method;
+    methods.push_back(&method);
   }
-  for (Function* method : methods) {
+  for (Method* method : methods) {
     method->ensure_defined();
   }
+  if (!self || !self->asFirstClass()) {
+    // Disable module hooks if the module is only used to store a class's code.
+    didFinishEmitModule(m);
+  }
 }
 
-void CompilationUnit::define(
+void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
     const std::string& source,
     const Resolver& resolver,
-    const Self& self) {
+    const c10::optional<Self>& self) {
   Parser p(source);
   std::vector<Def> definitions;
   std::vector<Resolver> resolvers;
@@ -2805,7 +2813,7 @@ void CompilationUnit::define(
     definitions.push_back(def);
     resolvers.push_back(resolver);
   }
-  define(definitions, resolvers, self);
+  defineMethodsInModule(m, definitions, resolvers, self);
 }
 
 void lambdaLiftFork(Node* fork_node) {
@@ -2830,7 +2838,6 @@ void lambdaLiftFork(Node* fork_node) {
   fork_node->g_(attr::Subgraph, forked_graph);
   fork_node->eraseBlock(0);
 }
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 3965467..3c2bb2d 100644 (file)
@@ -13,9 +13,12 @@ namespace torch {
 namespace jit {
 namespace script {
 
+using Resolver = std::function<std::shared_ptr<
+    SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
+
 inline std::shared_ptr<SugaredValue> nativeResolver(
     const std::string& name,
-    Function& m,
+    Method& m,
     const SourceRange& loc) {
   if (name == "torch") {
     return std::make_shared<BuiltinModule>("aten");
@@ -23,6 +26,47 @@ inline std::shared_ptr<SugaredValue> nativeResolver(
   return nullptr;
 }
 
+// Represents the `self` argument to a method. This wrapper class is necessary
+// because sometimes `self` sometimes is first class and sometimes not.
+//
+// `self` is first class when it refers to a ClassType. It will be bound as a
+// graph input argument.
+// `self` is sugared when it refers to a ModuleValue.
+class Self {
+ public:
+  explicit Self(std::shared_ptr<SugaredValue> sugared)
+      : sugared_(std::move(sugared)) {}
+  explicit Self(ClassTypePtr type) : firstClass_(std::move(type)) {}
+
+  ClassTypePtr asFirstClass() const {
+    return firstClass_;
+  }
+  std::shared_ptr<SugaredValue> asSugared() const {
+    return sugared_;
+  }
+
+ private:
+  // Used when `self` is not first-class and so we don't represent it in the
+  // graph. This is only ModuleValue.
+  std::shared_ptr<SugaredValue> sugared_ = nullptr;
+  // Used when `self` is a first-class type
+  ClassTypePtr firstClass_ = nullptr;
+};
+
+TORCH_API void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
+    const std::vector<Def>& definitions,
+    const std::vector<Resolver>& resolvers, /* determines how we handle free
+                                               variables in each definition*/
+    // if non-null, the first argument to each def, is bound to this value
+    const c10::optional<Self>& self);
+
+// same as above but parse the definitions from source
+TORCH_API void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
+    const std::string& source,
+    const Resolver& resolver,
+    const c10::optional<Self>& self);
 
 TORCH_API void lambdaLiftFork(Node* fork_node);
 
index 1aeadea..bff6d85 100644 (file)
@@ -64,9 +64,10 @@ inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
 // type, *add it in this function's implementation*.
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
-    Function& m,
+    Method& m,
     SourceRange loc,
-    bool is_constant = false);
+    bool is_constant = false,
+    bool is_submodule = false);
 
 struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
   PythonValue(py::object self) : self(std::move(self)) {}
@@ -124,7 +125,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
   // call it like a function, e.g. `outputs = this(inputs)`
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       at::ArrayRef<NamedValue> inputs_,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
@@ -181,7 +182,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
 
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const c10::optional<size_t>& size_hint = {}) override {
     const std::string type_str = typeString(self);
     std::stringstream ss;
@@ -192,7 +193,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
 
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     const std::string type_str = typeString(self);
     std::stringstream ss;
@@ -218,7 +219,7 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
 
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     py::object member = getattr(loc, field);
     // note: is_constant = true because we consider that global properties
@@ -233,7 +234,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
       : PythonValue(std::move(tup)) {}
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const c10::optional<size_t>& size_hint = {}) override {
     py::tuple tup = self;
     std::vector<std::shared_ptr<SugaredValue>> result;
@@ -245,7 +246,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
     return result;
   }
 
-  Value* asValue(const SourceRange& loc, Function& m) override {
+  Value* asValue(const SourceRange& loc, Method& m) override {
     std::vector<Value*> values;
     for (const auto& sugared_item : asTuple(loc, m)) {
       values.push_back(sugared_item->asValue(loc, m));
@@ -257,65 +258,33 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
 
 // Represents all the parameters of a module as a List[Tensor]
 struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
-  ConstantParameterList(Value* the_list) : the_list_(the_list) {}
-  std::string kind() const override {
-    return "constant parameter list";
-  }
-  std::shared_ptr<SugaredValue> call(
-      const SourceRange& loc,
-      Function& caller,
-      at::ArrayRef<NamedValue> inputs,
-      at::ArrayRef<NamedValue> attributes,
-      size_t n_binders) override {
-    return toSimple(the_list_);
-  }
-
- private:
-  Value* the_list_;
-};
-
-struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
-  OverloadedFunctionValue(Value* module, std::vector<std::string> method_names)
-      : module_(module), method_names_(std::move(method_names)) {}
+  ConstantParameterList(std::shared_ptr<Module> module)
+      : module_(std::move(module)) {}
 
   std::string kind() const override {
-    return "overloaded function";
+    return "constant parameter list";
   }
 
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& caller,
+      Method& caller,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
-    std::stringstream err;
-    std::vector<NamedValue> new_inputs = inputs.vec();
-    new_inputs.insert(new_inputs.begin(), module_);
-
-    for (const std::string& method_name : method_names_) {
-      auto cls = module_->type()->expect<ClassType>();
-      Function* fn = cls->getMethod(method_name);
-      auto match = tryMatchSchema(
-          fn->getSchema(),
-          loc,
-          *caller.graph().get(),
-          c10::nullopt,
-          new_inputs,
-          attributes,
-          err,
-          true);
-      if (match) {
-        return MethodValue(module_, *fn)
-            .call(loc, caller, inputs, attributes, n_binders);
-      }
+    // Add all module parameters as inputs to the graph
+    std::vector<Value*> params;
+    const auto& param_list = module_->get_parameters();
+    for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
+      auto& param = *it;
+      params.push_back(caller.get_or_add_parameter(param));
     }
-    throw ErrorReport(loc) << "Could not find any matching overloads\n"
-                           << err.str();
+    auto list = caller.graph()->createList(TensorType::get(), params);
+    caller.graph()->insertNode(list);
+    return toSimple(list->output());
   }
 
  private:
-  Value* module_;
-  std::vector<std::string> method_names_;
+  std::shared_ptr<Module> module_;
 };
 
 // defines how modules/methods behave inside the script subset.
@@ -326,8 +295,7 @@ struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
 // holding the actual nn.Module class.
 
 struct ModuleValue : public SugaredValue {
-  ModuleValue(Value* self, std::shared_ptr<Module> module)
-      : self_(self), module_(std::move(module)) {}
+  ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}
 
   std::string kind() const override {
     return "module";
@@ -336,60 +304,45 @@ struct ModuleValue : public SugaredValue {
   // select an attribute on it, e.g. `this.field`
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     // workaround to make self.training work
     // it adds a buffer 'training' to the model if one doesn't exist
     // and then loads that parameter, casting it to bool
     if (field == "training") {
-      Slot* v = module_->find_buffer(field);
+      Slot* v = module->find_buffer(field);
       if (!v) {
-        py::object py_module = py::cast(module_);
+        py::object py_module = py::cast(module);
         bool training = py::cast<bool>(py::getattr(py_module, "training"));
         auto t =
             autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
-        module_->register_buffer("training", std::move(t));
-        v = module_->find_buffer(field);
+        module->register_buffer("training", std::move(t));
+        v = module->find_buffer(field);
       }
-      Value* the_tensor = m.graph()->insertGetAttr(self_, "training");
+      Value* the_tensor = m.get_or_add_parameter(*v);
       Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
       return std::make_shared<SimpleValue>(the_bool);
     }
 
-    if (std::shared_ptr<Module> v = module_->find_module(field)) {
-      return std::make_shared<ModuleValue>(
-          m.graph()->insertGetAttr(self_, field), v);
-    } else if (auto kind = module_->kind_of(field)) {
-      // methods, parameters, attributes, and buffers are all first class
-      return SimpleValue(self_).attr(loc, m, field);
+    if (std::shared_ptr<Module> v = module->find_module(field)) {
+      return std::make_shared<ModuleValue>(v);
+    } else if (Method* v = module->find_method(field)) {
+      return std::make_shared<MethodValue>(shared_from_this(), *v);
+    } else if (Slot* v = module->find_parameter(field)) {
+      return std::make_shared<SimpleValue>(m.get_or_add_parameter(*v));
+    } else if (Slot* v = module->find_attribute(field)) {
+      return std::make_shared<SimpleValue>(
+          m.get_or_add_attribute(*v));
     }
 
     // This can also be a call to a non-script module, or a plain
     // python method. If so return this as a python value.
-    py::object py_module = py::cast(module_);
-
-    py::object overloads =
-        py_module.attr("_overloads").attr("get")(field, py::none());
-    if (!overloads.is_none()) {
-      return std::make_shared<OverloadedFunctionValue>(
-          self_, py::cast<std::vector<std::string>>(overloads));
-    }
-
+    py::object py_module = py::cast(module);
     if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
       if (py::isinstance<py::function>(attr) &&
           py::hasattr(attr, "_is_parameter_list") &&
           py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
-        Graph& g = *m.graph();
-        // Add all module parameters as inputs to the graph
-        std::vector<Value*> params;
-        const auto& param_list = module_->get_parameters();
-        for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
-          auto& param = *it;
-          params.emplace_back(g.insertGetAttr(self_, param.name()));
-        }
-        auto list =
-            g.insertNode(g.createTuple(params))->output();
-        return std::make_shared<ConstantParameterList>(list);
+        return std::make_shared<ConstantParameterList>(module);
       }
       if (py::isinstance<py::function>(attr) ||
           py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
@@ -411,7 +364,7 @@ struct ModuleValue : public SugaredValue {
   // call module.forward
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& caller,
+      Method& caller,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
@@ -421,35 +374,28 @@ struct ModuleValue : public SugaredValue {
 
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const c10::optional<size_t>& size_hint = {}) override {
-    py::object py_module = py::cast(module_);
+    py::object py_module = py::cast(module);
     if (!py::isinstance(
             py_module,
             py::module::import("torch.jit").attr("_ConstModuleList")))
       return SugaredValue::asTuple(loc, m, size_hint);
     std::vector<std::shared_ptr<SugaredValue>> result;
-    for (py::handle py_submodule : py_module) {
-      py::object obj = py::reinterpret_borrow<py::object>(py_submodule);
-      if (py::isinstance<Module>(obj)) {
-        auto sub_module = py::cast<std::shared_ptr<Module>>(obj);
-        Value* module_v = m.graph()->insertGetAttr(self_, sub_module->name());
-        result.emplace_back(
-            std::make_shared<ModuleValue>(module_v, sub_module));
-      } else {
-        result.push_back(toSugaredValue(
-            obj,
-            m,
-            loc,
-            /*is_constant =*/false));
-      }
+    for (py::handle module : py_module) {
+      py::object obj = py::reinterpret_borrow<py::object>(module);
+      result.push_back(toSugaredValue(
+          obj,
+          m,
+          loc,
+          /*is_constant =*/false,
+          /*is_submodule =*/true));
     }
     return result;
   }
 
  private:
-  Value* self_;
-  std::shared_ptr<Module> module_;
+  std::shared_ptr<Module> module;
 };
 
 struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
@@ -462,7 +408,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
 
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& caller,
+      Method& caller,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
@@ -500,31 +446,54 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
   py::dict dispatched_fn_;
 };
 
-std::shared_ptr<MethodValue> moduleToMethod(
-    const std::shared_ptr<Module>& mod) {
-  // this path only supports calling raw script functions
-  // but because they are not distinguished from models, we have to check
-  // that they are function-like here. They must not have state, and they
-  // must have a forward method. When we expose functions to python
-  //  this will be replaced with a direct py::isinstance<Function> call.
-
-  if (mod->get_parameters().size() != 0) {
-    throw ErrorReport()
-        << "Attempted to inline a Module with parameters. "
-           "Stateful modules to be inlined must be submodules of the callee.";
+struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
+  OverloadedFunctionValue(py::list functions)
+      : possible_functions_(std::move(functions)) {}
+
+  std::string kind() const override {
+    return "overloaded function";
   }
-  Method* forward = mod->find_method("forward");
-  if (!forward) {
-    throw ErrorReport() << " expected this module to have a forward function.";
+
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& caller,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
+    std::stringstream err;
+    auto possible_functions =
+        py::cast<std::vector<py::object>>(possible_functions_);
+
+    for (const py::object& fn : possible_functions) {
+      auto& method = py::cast<Method&>(fn);
+      auto match = tryMatchSchema(
+          method.getSchema(),
+          loc,
+          *caller.graph().get(),
+          c10::nullopt,
+          inputs,
+          attributes,
+          err,
+          true);
+      if (match) {
+        return MethodValue(nullptr, method)
+            .call(loc, caller, inputs, attributes, n_binders);
+      }
+    }
+    throw ErrorReport(loc) << "Could not find any matching overloads\n"
+                           << err.str();
   }
-  return std::make_shared<MethodValue>(at::nullopt, forward->function());
-}
+
+ private:
+  py::list possible_functions_;
+};
 
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
-    Function& m,
+    Method& m,
     SourceRange loc,
-    bool is_constant) {
+    bool is_constant,
+    bool is_submodule) {
   // directly create SimpleValues when possible, because they are first-class
   // and can be re-assigned. Otherwise, this would be invalid:
   // f = python_constant
@@ -565,12 +534,17 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     obj = weak_obj;
   }
   if (py::isinstance<Module>(obj)) {
+    auto mod = py::cast<std::shared_ptr<Module>>(obj);
     // In the case that this Python object is not a submodule, inline *ONLY
     // PURE* ScriptModules. This allows us to call arbitrary @script functions
     // within a scripting context while still enforcing that parameters from
     // stateful submodules are properly accounted for.
-    auto mod = py::cast<std::shared_ptr<Module>>(obj);
-    return moduleToMethod(mod);
+    if (!is_submodule && mod->get_parameters().size() != 0) {
+      throw ErrorReport()
+          << "Attempted to inline a Module with parameters. "
+             "Stateful modules to be inlined must be submodules of the callee.";
+    }
+    return std::make_shared<ModuleValue>(mod);
   } else if (py::isinstance<py::module>(obj)) {
     return std::make_shared<PythonModuleValue>(obj);
   } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
@@ -592,7 +566,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
         py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
     if (!compiled_fn.is(py::none())) {
       auto mod = py::cast<std::shared_ptr<Module>>(compiled_fn);
-      return moduleToMethod(mod);
+      return std::make_shared<ModuleValue>(mod);
     }
   }
 
@@ -602,6 +576,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
   }
 
+  py::object overloads =
+      py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj);
+  if (!overloads.is_none()) {
+    return std::make_shared<OverloadedFunctionValue>(std::move(overloads));
+  }
+
   return std::make_shared<PythonValue>(obj);
 }
 
@@ -641,7 +621,7 @@ static void gatherParametersAndBuffers(
 namespace {
 
 Resolver pythonResolver(const ResolutionCallback& rcb) {
-  return [rcb](const std::string& name, Function& m, const SourceRange& loc)
+  return [rcb](const std::string& name, Method& m, const SourceRange& loc)
              -> std::shared_ptr<SugaredValue> {
     AutoGIL ag;
     py::object obj = rcb(name);
@@ -693,13 +673,6 @@ FunctionSchema getSchemaWithNameAndDefaults(
       schema.is_varret());
 }
 
-static Self moduleSelf(const std::shared_ptr<Module>& m) {
-  return [m](Value* v) {
-    v->setType(m->module_object()->type());
-    return std::make_shared<ModuleValue>(v, m);
-  };
-}
-
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -739,12 +712,9 @@ void initJitScriptBindings(PyObject* module) {
              bool has_self) {
             c10::optional<Self> self;
             if (has_self) {
-              m->class_compilation_unit().define(
-                  script, pythonResolver(rcb), moduleSelf(m));
-            } else {
-              m->_define_lowered(script, pythonResolver(rcb));
+              self = Self(std::make_shared<ModuleValue>(m));
             }
-            didFinishEmitModule(m);
+            defineMethodsInModule(m, script, pythonResolver(rcb), self);
           })
       .def(
           "_create_methods",
@@ -757,13 +727,14 @@ void initJitScriptBindings(PyObject* module) {
             for (auto& callback : rcbs) {
               resolvers.push_back(pythonResolver(callback));
             }
-            m->class_compilation_unit().define(defs, resolvers, moduleSelf(m));
+            defineMethodsInModule(
+                m, defs, resolvers, Self(std::make_shared<ModuleValue>(m)));
+
             // Stitch in default arguments for each Def if provided
             auto defaults_it = defaults.begin();
             auto defs_it = defs.begin();
             while (defs_it != defs.end()) {
-              auto& method = m->class_compilation_unit().get_function(
-                  (*defs_it).name().name());
+              auto& method = m->get_method((*defs_it).name().name());
               method.setSchema(getSchemaWithNameAndDefaults(
                   defs_it->range(),
                   method.getSchema(),
@@ -813,7 +784,8 @@ void initJitScriptBindings(PyObject* module) {
               auto& p = parameters[i];
               py::tuple r(2);
               result[i] = std::make_tuple(
-                  p.name(), autograd::as_variable_ref(p.value().toTensor()));
+                  p.name(),
+                  autograd::as_variable_ref(p.value().toTensor()));
             }
             return result;
           })
@@ -869,7 +841,7 @@ void initJitScriptBindings(PyObject* module) {
           [](Module& self,
              const std::string& name,
              std::shared_ptr<Graph> graph) {
-            self._define_lowered(name, std::move(graph), {});
+            self.create_method(name, std::move(graph), {});
           })
       .def(
           "_create_method_from_trace",
@@ -893,8 +865,7 @@ void initJitScriptBindings(PyObject* module) {
                 var_lookup_fn,
                 force_outplace,
                 input_tuple.size());
-            self->_define_lowered(
-                name, std::move(graph), std::move(parameters));
+            self->create_method(name, std::move(graph), std::move(parameters));
             didFinishEmitModule(self);
           })
       .def(
@@ -919,7 +890,7 @@ void initJitScriptBindings(PyObject* module) {
           [](Module& self) {
             if (self.find_method("forward")) {
               Method& m = self.get_method("forward");
-              return m.get_executor().getDebugState();
+              return m.getDebugState();
             }
             throw std::runtime_error(
                 "Attempted to call get_debug_state on a Module without a compiled forward()");
@@ -929,7 +900,7 @@ void initJitScriptBindings(PyObject* module) {
           [](Module& self) {
             if (self.find_method("forward")) {
               Method& m = self.get_method("forward");
-              m.get_executor().debugDisableAutodiffSubgraphInlining();
+              m.debugDisableAutodiffSubgraphInlining();
             }
           })
       .def(
@@ -987,8 +958,7 @@ void initJitScriptBindings(PyObject* module) {
             }
 
             Method* orig_method = orig->find_method(name);
-            m->_define_lowered(
-                name, orig_method->graph()->copy(), std::move(member_inputs));
+            m->create_method(name, orig_method->graph()->copy(), member_inputs);
           });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
@@ -1002,27 +972,10 @@ void initJitScriptBindings(PyObject* module) {
                 method, tuple_slice(std::move(args), 1), std::move(kwargs));
           })
       .def_property_readonly("graph", [](Method& m) { return m.graph(); })
-      .def(
-          "propagate_shapes",
-          [](Method& m, const std::vector<at::Tensor>& inputs, bool with_grad) {
-            return propagate_shapes(
-                *m.graph(), inputs, m.initial_ivalues(), with_grad);
-          })
+      .def("propagate_shapes", &Method::propagate_shapes)
       .def(
           "propagate_and_assign_input_and_output_shapes",
-          [](Method& m,
-             const std::vector<at::Tensor>& inputs,
-             std::vector<at::Tensor> outputs,
-             bool with_grad,
-             bool propagate) {
-            return propagate_and_assign_input_and_output_shapes(
-                *m.graph(),
-                inputs,
-                m.initial_ivalues(),
-                outputs,
-                with_grad,
-                propagate);
-          })
+          &Method::propagate_and_assign_input_and_output_shapes)
       .def(
           "initial_ivalues",
           [](Method& m) {
@@ -1042,18 +995,9 @@ void initJitScriptBindings(PyObject* module) {
           })
       .def(
           "debug_disable_autodiff_subgraph_inlining",
-          [](Method& m) {
-            return m.get_executor().debugDisableAutodiffSubgraphInlining();
-          })
+          &Method::debugDisableAutodiffSubgraphInlining)
       .def("schema", &Method::getSchema)
-      .def(
-          "pretty_print_schema",
-          [](Method& m) {
-            const FunctionSchema& schema = m.getSchema();
-            std::stringstream ss;
-            ss << schema;
-            return ss.str();
-          })
+      .def("pretty_print_schema", &Method::pretty_print_schema)
       .def(
           "python_print",
           [](Method& m) {
@@ -1078,30 +1022,29 @@ void initJitScriptBindings(PyObject* module) {
          ResolutionCallback rcb,
          FunctionDefaults defaults) {
         auto def_f = def.withName("forward");
-
-        mod->_define_lowered({def_f}, {pythonResolver(rcb)});
-        auto& func = mod->lowered_methods().get_function("forward");
-        func.setSchema(getSchemaWithNameAndDefaults(
-            def.range(), func.getSchema(), def.name().name(), defaults));
-        auto& func2 = mod->class_compilation_unit().get_function("forward");
-        func2.setSchema(getSchemaWithNameAndDefaults(
-            def.range(), func2.getSchema(), def.name().name(), defaults));
+        defineMethodsInModule(
+            mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt);
+        auto& method = mod->get_method("forward");
+        method.setSchema(getSchemaWithNameAndDefaults(
+            def.range(), method.getSchema(), def.name().name(), defaults));
         didFinishEmitModule(mod);
         return mod;
       });
 
   m.def(
       "_jit_script_class_compile",
-      [](const ClassDef& classDef, ResolutionCallback rcb) {
-        auto cu = std::make_shared<CompilationUnit>();
-        auto classType = ClassType::create(classDef.name().name(), cu);
+      [](std::shared_ptr<Module> module,
+         const ClassDef& classDef,
+         ResolutionCallback rcb) {
+        auto classType = ClassType::create(classDef.name().name(), module);
         std::vector<Resolver> rcbs;
         std::vector<Def> methodDefs;
         for (const auto& def : classDef.defs()) {
           methodDefs.push_back(def);
           rcbs.push_back(pythonResolver(rcb));
         }
-        cu->define(methodDefs, rcbs, simpleSelf(classType));
+        defineMethodsInModule(module, methodDefs, rcbs, Self(classType));
+        return module;
       });
 
   m.def("parse_type_comment", [](const std::string& comment) {
index 7772ab1..33cb951 100644 (file)
@@ -2,7 +2,6 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/schema_matching.h>
@@ -12,39 +11,32 @@ namespace jit {
 namespace script {
 
 struct RecursiveMethodCallError : public std::exception {};
-void placeholderCreator(Function&) {
+void placeholderCreator(Method&) {
   throw RecursiveMethodCallError();
 }
 
-void Function::ensure_defined() {
-  try {
-    if (function_creator_) {
-      auto creator = function_creator_;
-      function_creator_ = placeholderCreator;
-      creator(*this);
-      function_creator_ = nullptr;
-    }
-  } catch (RecursiveMethodCallError&) {
-    throw ErrorReport() // TODO: once lower_first_class methods is removed
-                        // re-establish callsite info for debugging
-        << " method '" << name() << "' is called recursively. "
-        << "Recursive calls are not supported";
-  }
-}
-
-Value* Function::try_emit_call(
+Value* try_emit_call_to(
     Graph& graph,
     const SourceRange& loc,
+    Method& callee,
     c10::optional<NamedValue> self,
     ArrayRef<NamedValue> args,
     ArrayRef<NamedValue> kwargs,
     std::stringstream& failure_messages,
+    Method* caller,
     bool conv_tensors_to_nums) {
-  ensure_defined();
-  auto fn = this->graph();
+  try {
+    callee.ensure_defined();
+  } catch (RecursiveMethodCallError&) {
+    throw ErrorReport(loc)
+        << " method '" << callee.name()
+        << "' is called recursively involving this call site. "
+        << "Recursive calls are not supported";
+  }
+  auto fn = callee.graph();
 
   auto matched_schema = tryMatchSchema(
-      getSchema(),
+      callee.getSchema(),
       loc,
       graph,
       std::move(self),
@@ -55,29 +47,52 @@ Value* Function::try_emit_call(
   if (!matched_schema)
     return nullptr;
 
-  check_single_output();
-  return inlineCallTo(graph, *fn, matched_schema->inputs).at(0);
+  // parameters to callee method (which become parameters to _this_ method
+  // if they were not already)
+  for (const auto& member : callee.initial_ivalues()) {
+    if (!caller) {
+      throw ErrorReport(loc)
+          << " attempting to call a method with parameters/attributes"
+             " from a raw graph. File a bug report";
+    }
+    matched_schema->inputs.push_back(
+        caller->get_or_add_attribute(member));
+  }
+  callee.check_single_output();
+  return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
 }
 
-Value* Function::emit_call(
-    Graph& graph,
+Value* Method::emit_call_to(
     const SourceRange& loc,
+    Method& callee,
     ArrayRef<NamedValue> args,
     ArrayRef<NamedValue> kwargs) {
+  AT_ASSERT(!executor);
   std::stringstream failure_messages;
-  if (auto result = try_emit_call(
-          graph,
+  if (auto result = try_emit_call_to(
+          *graph(),
           loc,
+          callee,
           c10::nullopt,
           args,
           kwargs,
           failure_messages,
+          this,
           /*conv_tensors_to_nums=*/true)) {
     return result;
   }
   throw ErrorReport(loc) << failure_messages.str();
 }
 
+void Method::ensure_defined() {
+  if (method_creator) {
+    auto creator = method_creator;
+    method_creator = placeholderCreator;
+    creator(*this);
+    method_creator = nullptr;
+  }
+}
+
 void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
   to_impl(device, dtype, non_blocking);
 }
@@ -122,229 +137,6 @@ void Module::to_impl(
   }
 }
 
-// lower_first_class_method and lift_lowered_method are transitionary functions
-// used to translate between module-as-first-class code generation,
-// and module-as-special execution. Once module-as-first-class execution is
-// debugged, then we can remove both and remove the lowered_functions_ table.
-
-// remove the first module argument, replacing any access of its
-// parameters/attributes with extra_ivalue input Slots that hold what value to
-// pass into the graph
-std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
-    const ModulePtr& self,
-    Graph& g_,
-    size_t self_offset = 0) {
-  std::shared_ptr<Graph> g = g_.copy();
-  std::vector<Slot> extra_ivalues;
-  std::unordered_map<Slot, size_t> slot_to_offset;
-  struct ToScan {
-    ModulePtr mod;
-    Node* n;
-    size_t offset;
-  };
-  std::vector<ToScan> to_scan;
-  std::vector<Node*> to_clean; // nodes that should be dead at the end
-
-  auto getOrAddSlot = [&](const Slot& slot) -> Value* {
-    auto it = slot_to_offset.find(slot);
-    if (it != slot_to_offset.end()) {
-      size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
-      return g->inputs().at(ivalues_start + it->second);
-    }
-    extra_ivalues.emplace_back(slot);
-    slot_to_offset[slot] = extra_ivalues.size() - 1;
-    return g->addInput()->setType(slot.type());
-  };
-
-  auto self_value = g->inputs().at(self_offset);
-
-  for (Use use : self_value->uses()) {
-    to_scan.emplace_back(ToScan{self, use.user, use.offset});
-  }
-  while (to_scan.size() > 0) {
-    auto e = to_scan.back();
-    to_scan.pop_back();
-
-    // when we lambda lift forks, first-class modules may be passed across
-    // forks. This code recursively lowers the module in the fork call.
-    if (e.n->kind() == prim::fork) {
-      auto subgraph = e.n->g(attr::Subgraph);
-      std::vector<Slot> new_slots;
-      std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
-      e.n->g_(attr::Subgraph, subgraph);
-      for (const Slot& slot : new_slots) {
-        e.n->addInput(getOrAddSlot(slot));
-      }
-      e.n->removeInput(e.offset);
-      continue;
-    }
-    if (e.n->kind() != prim::GetAttr) {
-      throw ErrorReport(e.n->getSourceLocation())
-          << "temporary: the only valid use of a module is looking up an attribute";
-    }
-    Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
-    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
-      if (c->name() == "Module") {
-        auto obj = slot.value().toObject();
-        for (Use use : e.n->output()->uses()) {
-          to_scan.emplace_back(ToScan{obj, use.user, use.offset});
-        }
-        to_clean.emplace_back(e.n);
-        continue;
-      }
-    }
-    e.n->output()->replaceAllUsesWith(getOrAddSlot(slot));
-    e.n->destroy();
-  }
-
-  while (to_clean.size() > 0) {
-    Node* n = to_clean.back();
-    AT_ASSERT(!n->hasUses());
-    n->destroy();
-    to_clean.pop_back();
-  }
-  AT_ASSERT(!self_value->hasUses());
-  g->eraseInput(self_offset);
-
-  return std::make_pair(std::move(g), std::move(extra_ivalues));
-}
-
-Method& Module::lower_first_class_method(Function* fn) {
-  fn->ensure_defined();
-  auto lowered = lower_graph(module_object(), *fn->graph());
-  Function& new_func =
-      lowered_methods_.create_function(fn->name(), lowered.first);
-
-  // generate the new schema
-  // slice away the self argument
-  std::vector<Argument> args(
-      fn->getSchema().arguments().begin() + 1,
-      fn->getSchema().arguments().end());
-  size_t id = 0;
-  for (const Slot& slot : lowered.second) {
-    std::ostringstream ss;
-    ss << "slot" << id++;
-    args.emplace_back(ss.str(), slot.type());
-  }
-  new_func.setSchema(fn->getSchema().cloneWithArguments(std::move(args)));
-  return _create_lowered_method(&new_func, std::move(lowered.second));
-}
-
-static void createFirstClassValues(
-    Module* module,
-    Value* self,
-    std::unordered_map<Slot, Value*>& result) {
-  auto& g = *self->owningGraph();
-
-  std::vector<Node*> created;
-  struct ToScan {
-    Module* mod;
-    Value* v; // value representing module in the graph
-  };
-  std::vector<ToScan> to_scan = {{module, self}};
-
-  while (!to_scan.empty()) {
-    auto s = to_scan.back();
-    to_scan.pop_back();
-    size_t offset = 0;
-    for (const std::string& name :
-         s.mod->module_object()->type()->attributeNames()) {
-      Value* v = g.insertGetAttr(s.v, name);
-      result[Slot(s.mod->module_object(), offset++)] = v;
-      if (std::shared_ptr<Module> sub = s.mod->find_module(name)) {
-        to_scan.emplace_back(ToScan{sub.get(), v});
-      }
-    }
-  }
-}
-
-void Module::lift_lowered_method(Method& m) {
-  auto graph = m.graph()->copy();
-  Value* self = graph->insertInput(0, "self")->setType(module_object()->type());
-  std::unordered_map<Slot, Value*> slot_to_value;
-  if (!m.initial_ivalues().empty()) {
-    WithInsertPoint guard(*graph->nodes().begin());
-    createFirstClassValues(this, self, slot_to_value);
-  }
-
-  size_t orig_graph_inputs_size = graph->inputs().size();
-  for (size_t i = 0; i < m.initial_ivalues().size(); ++i) {
-    size_t input_offset = orig_graph_inputs_size - i - 1;
-    size_t ivalue_offset = m.initial_ivalues().size() - i - 1;
-    graph->inputs()
-        .at(input_offset)
-        ->replaceAllUsesWith(
-            slot_to_value.at(m.initial_ivalues().at(ivalue_offset)));
-    graph->eraseInput(input_offset);
-  }
-
-  if (!m.initial_ivalues().empty()) {
-    // we added _all_ the submodules as first-class values but maybe did not use
-    // them. So remove any dead attribute lookups
-    EliminateDeadCode(graph);
-  }
-
-  Function& new_fn = class_cu().create_function(m.name(), std::move(graph));
-  // created lifted schema
-  // self argument is named '$self' to prevent accidental name collisions
-  // with another input that the user named 'self'
-  std::vector<Argument> new_args = {Argument("$self", module_object()->type())};
-  const auto& lowered_args = m.function().getSchema().arguments();
-  new_args.insert(
-      new_args.end(),
-      lowered_args.begin(),
-      lowered_args.begin() + m.num_inputs());
-  new_fn.setSchema(m.function().getSchema().cloneWithArguments(std::move(new_args)));
-}
-
-Method& Module::_create_lowered_method(
-    Function* func,
-    std::vector<Slot> member_inputs) {
-  std::unique_ptr<Method> m(new Method(this, func, std::move(member_inputs)));
-  return *insert(func->name(), methods_, EntityType::METHOD, std::move(m));
-}
-
-void Module::lift_lowered_methods(size_t start) {
-  for (size_t i = start; i < lowered_methods_.get_functions().size(); ++i) {
-    Method& m = _create_lowered_method(
-        lowered_methods_.get_functions().at(i).get(), {});
-    lift_lowered_method(m);
-  }
-}
-
-void Module::_define_lowered(
-    const std::vector<Def>& definitions,
-    const std::vector<Resolver>& resolvers) {
-  size_t start = lowered_methods_.get_functions().size();
-  lowered_methods_.define(definitions, resolvers, nullptr);
-  lift_lowered_methods(start);
-  // call lift_lowered_method for each definition
-}
-
-void Module::_define_lowered(const std::string& src, const Resolver& resolver) {
-  size_t start = lowered_methods_.get_functions().size();
-  lowered_methods_.define(src, resolver, nullptr);
-  lift_lowered_methods(start);
-}
-
-Method& Module::_define_lowered(
-    std::string name,
-    std::shared_ptr<Graph> graph,
-    std::vector<Slot> slots) {
-  Method& m = _create_lowered_method(
-      &lowered_methods_.create_function(std::move(name), std::move(graph)),
-      std::move(slots));
-  lift_lowered_method(m);
-  return m;
-}
-
-void Module::define(const std::string& src, const Resolver& resolver) {
-  class_cu().define(
-      src,
-      resolver ? resolver : nativeResolver,
-      simpleSelf(module_object()->type()));
-}
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 000ea9c..13233d9 100644 (file)
@@ -12,7 +12,6 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/api/include/torch/ordered_dict.h>
-#include <torch/csrc/jit/script/compilation_unit.h>
 #include <torch/csrc/utils/memory.h>
 
 #include <ATen/core/function_schema.h>
@@ -40,7 +39,6 @@ using ::c10::FunctionSchema;
 // Map which stores filename to content.
 using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 
-using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>;
 // A method in a module, e.g. f in:
 //
 // class M(ScriptModule):
@@ -55,110 +53,320 @@ struct Module;
 using ModuleLookup =
     std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
-struct TORCH_API Method {
-  Method(Module* owner, Function* function, std::vector<Slot> initial_members)
+struct Method {
+  Method(
+      Module* owner,
+      std::string name,
+      bool optimize,
+      std::shared_ptr<Graph> graph,
+      std::vector<Slot> initial_members,
+      std::function<void(Method&)> method_creator)
       : owner_(owner),
-        function_(function),
-        initial_ivalues_(std::move(initial_members)) {
-    AT_ASSERT(function->num_inputs() >= initial_ivalues_.size());
-  }
-
-  // the module that contains this method.
-  Module& owner() const {
-    return *owner_;
+        name_(std::move(name)),
+        graph_(std::move(graph)),
+        optimize(optimize),
+        initial_ivalues_(std::move(initial_members)),
+        method_creator(std::move(method_creator)) {
+    AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size());
+    int i = graph_->inputs().size() - initial_ivalues_.size();
+    for (auto member : initial_ivalues_) {
+      initial_ivalue_index[member] = i++;
+    }
   }
 
   void run(Stack& stack) {
     for (auto input : initial_ivalues_) {
       push(stack, input.value());
     }
-    function_->run(stack);
+    get_executor().run(stack);
   }
+
   void run(Stack&& stack) {
     run(stack);
   }
 
   IValue operator()(std::vector<IValue> stack) {
-    getSchema().checkAndNormalizeInputs(stack);
-    for (auto input : initial_ivalues_) {
-      push(stack, input.value());
-    }
-    // use run rather than operator() to skip the second schema check.
-    function_->run(std::move(stack));
+    checkInputsAgainstSchema(stack);
+    run(stack);
     return stack.front();
   }
 
+  std::shared_ptr<Graph> graph_for(Stack inputs) {
+    for (auto tp : initial_ivalues_) {
+      inputs.emplace_back(tp.value());
+    }
+    return get_executor().graphFor(inputs);
+  }
+  TORCH_API std::shared_ptr<Graph> graph() const {
+    return graph_;
+  }
+
+  TORCH_API const std::string& name() const {
+    return name_;
+  }
+  // emit a function call by inlining the callees Graph into this one
+  // adding any extra parameters necessary to do this call
+
+  // defined here to keep details of member_input handling confined to this
+  // class
+  Value* emit_call_to(
+      const SourceRange& loc,
+      Method& callee,
+      ArrayRef<NamedValue> args,
+      ArrayRef<NamedValue> kwargs);
+
+  // if this isn't yet defined, run its method_creator function
+  TORCH_API void ensure_defined();
+
+  size_t num_inputs() const {
+    return graph()->inputs().size() - initial_ivalues_.size();
+  }
+  TORCH_API Value* get_or_add_parameter(Slot slot) {
+    AT_ASSERT(slot.value().isTensor());
+    return get_or_add_attribute(slot);
+  }
+  TORCH_API Value* get_or_add_attribute(Slot slot) {
+    auto it = initial_ivalue_index.find(slot);
+    if (it != initial_ivalue_index.end()) {
+      return graph()->inputs().at(it->second);
+    }
+    initial_ivalues_.push_back(slot);
+    initial_ivalue_index[slot] = graph()->inputs().size();
+    return graph()->addInput()->setType(slot.type());
+  }
+
+  static void setInputTensorTypes(Graph& g, const Stack& stack) {
+    AT_ASSERT(stack.size() == g.inputs().size());
+    for (size_t i = 0; i < stack.size(); ++i) {
+      g.inputs().at(i)->setType(
+          DimensionedTensorType::create(stack.at(i).toTensor()));
+    }
+  }
+
+  std::shared_ptr<Graph> propagate_shapes(
+      std::vector<at::Tensor> inputs,
+      bool with_grad = false) {
+    auto retval = graph_->copy();
+    Stack stack;
+    stack.reserve(inputs.size() + initial_ivalues_.size());
+    for (at::Tensor& i : inputs) {
+      stack.emplace_back(std::move(i));
+    }
+    for (const Slot& inp : initial_ivalues_) {
+      stack.push_back(inp.value());
+    }
+    setInputTensorTypes(*retval, stack);
+    PropagateInputShapes(retval);
+    return retval;
+  }
+
+  std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+      std::vector<at::Tensor> inputs,
+      std::vector<at::Tensor> outputs,
+      bool with_grad = false,
+      bool propagate = true) {
+    auto retval = graph_->copy();
+    for (auto inp : initial_ivalues_) {
+      if (inp.value().isTensor()) {
+        inputs.push_back(inp.value().toTensor());
+      }
+    }
+    if (propagate) {
+      setInputTensorTypes(*retval, fmap<IValue>(inputs));
+      PropagateInputShapes(retval);
+    }
+    AT_ASSERT(retval->inputs().size() == inputs.size());
+    for (size_t i = 0; i < retval->inputs().size(); ++i) {
+      auto scalar_type = inputs[i].scalar_type();
+      auto sizes = inputs[i].sizes();
+      auto type =
+          torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+      retval->inputs()[i]->setType(type);
+    }
+    at::ArrayRef<Value*> output_values = retval->outputs();
+    // patch this to still work if we are returning a tuple of multiple values
+    if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+      AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+      output_values = output_values.at(0)->node()->inputs();
+    }
+    AT_ASSERT(output_values.size() == outputs.size());
+    for (size_t i = 0; i < retval->outputs().size(); ++i) {
+      auto scalar_type = outputs[i].scalar_type();
+      auto sizes = outputs[i].sizes();
+      auto type =
+          torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+      output_values[i]->setType(type);
+    }
+    return retval;
+  }
+
   const std::vector<Slot>& initial_ivalues() const {
     return initial_ivalues_;
   }
 
-  // proxies for underlying unbound Function
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    for (auto tp : initial_ivalues_) {
-      inputs.emplace_back(tp.value());
+  Method& setSchema(FunctionSchema schema_) {
+    schema = make_unique<FunctionSchema>(std::move(schema_));
+    return *this;
+  }
+
+  TORCH_API const FunctionSchema& getSchema() const {
+    if (schema == nullptr) {
+      schema = make_unique<FunctionSchema>(defaultSchemaFor(*this));
     }
-    return function_->get_executor().graphFor(inputs);
+    return *schema;
   }
 
-  std::shared_ptr<Graph> graph() const {
-    return function_->graph();
+  std::string pretty_print_schema() const {
+    AT_ASSERT(schema);
+    std::stringstream ss;
+    ss << *schema;
+    return ss.str();
   }
 
-  const std::string& name() const {
-    return function_->name();
+  GraphExecutorState getDebugState() {
+    return get_executor().getDebugState();
   }
 
-  size_t num_inputs() const {
-    return function_->num_inputs() - initial_ivalues_.size();
+  void debugDisableAutodiffSubgraphInlining() {
+    return get_executor().debugDisableAutodiffSubgraphInlining();
   }
 
-  FunctionSchema getSchema() const {
-    // we are required to slice out the slot inputs from the schema
-    // we can't cache this because setSchema on the underlying function
-    // will change the underlying schema
-    auto sliced = ArrayRef<Argument>(function_->getSchema().arguments())
-                      .slice(0, num_inputs());
-    return function_->getSchema().cloneWithArguments(sliced.vec());
+  bool is_optimized() const {
+    return optimize;
   }
 
-  GraphExecutor& get_executor() {
-    return function_->get_executor();
+  // the module that contains this method.
+  Module& owner() const {
+    return *owner_;
   }
 
-  Function& function() const {
-    return *function_;
+  void check_single_output() {
+    AT_CHECK(
+        graph()->outputs().size() == 1,
+        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
   }
 
  private:
+  static FunctionSchema defaultSchemaFor(const Method& method) {
+    std::vector<Argument> args;
+    std::vector<Argument> returns;
+    Graph& g = *method.graph();
+    size_t num_inputs = method.num_inputs();
+    for (size_t i = 0; i < num_inputs; ++i) {
+      const Value* v = g.inputs().at(i);
+      std::string name = v->hasUniqueName() ? v->uniqueNameBase()
+                                            : ("argument_" + std::to_string(i));
+      args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
+    }
+    for (size_t i = 0; i < g.outputs().size(); ++i) {
+      returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
+    }
+    return {method.name(), "", std::move(args), std::move(returns)};
+  }
+
+  GraphExecutor& get_executor() {
+    std::call_once(executor_init, [&] {
+      check_single_output();
+      executor = GraphExecutor(graph(), optimize);
+    });
+    return executor;
+  }
+
+  void checkInputsAgainstSchema(std::vector<IValue>& inputs) {
+    const auto& schema = getSchema();
+    // Do we have more inputs than the schema accepts?
+    AT_CHECK(
+        inputs.size() <= schema.arguments().size(),
+        "Expected at most ",
+        schema.arguments().size(),
+        " argument(s) for operator '",
+        schema.name(),
+        "', but received ",
+        inputs.size(),
+        " argument(s). Declaration: ",
+        schema);
+
+    for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
+      const auto& argument = schema.arguments()[pos];
+      if (pos < inputs.size()) {
+        if (!isSubvalueOf(inputs[pos], argument.type())) {
+          AT_ERROR(
+              "Expected value of type ",
+              *argument.type(),
+              " for argument '",
+              argument.name(),
+              "' in position ",
+              pos,
+              ", but instead got value of type ",
+              attemptToRecoverType(inputs[pos])->str(),
+              ". Declaration: ",
+              schema);
+        }
+      } else if (argument.default_value()) {
+        inputs.push_back(*argument.default_value());
+      } else {
+        AT_ERROR(
+            schema.name(),
+            "() is missing value for argument '",
+            argument.name(),
+            "'. Declaration: ",
+            schema);
+      }
+    }
+  }
+
   // Methods are uniqued onwed by a single module. This raw pointer allows
   // looking up the module.
   Module* owner_;
 
-  // Underlying unbound function
-  Function* function_;
-
-  // parameters and attributes loaded from the Module and appending
-  // before calling function_
+  std::string name_;
+  std::shared_ptr<Graph> graph_; // for debugging and for inlining
+  bool optimize;
+
+  GraphExecutor executor; // for execution
+  // initial_ivalues are a list of additional arguments appended to graph
+  // that are inputs that come from the members of the Module or its submodules.
+  // each is a pointer to a slot in the module that owns this parameter
+  // parameters and submodules can only be _added_ to script Modules to ensure
+  // these pointers always stay valid
   std::vector<Slot> initial_ivalues_;
+
+  // map from a IValue* in initial_ivalues to the offset it appears at
+  // in graph. used to accelerate get_or_add_parameter
+  std::unordered_map<Slot, size_t> initial_ivalue_index;
+
+  // TODO: support that case where we allow _writes_ to parameters from
+  // compiled functions.
+  // This requires more sophisticated tracking of ssa values in Graphs so that
+  // stores to all modules can be lifted to the end of a graph execution.
+  // It also adds more complexity to adding actual module invocations
+  // to the executor, so currently it is not done.
+  // std::vector<at::Tensor*> member_outputs;
+
+  std::once_flag executor_init;
+
+  // an optional function that actually creates the method when
+  // emit_call_to(this,...) is first called. this is used by the compiler so
+  // that it can construct methods out of order
+  std::function<void(Method&)> method_creator;
+
+  // if absent, then we generate a default schema based on the graph
+  // mutable because getSchema caches the default schema if one is requested
+  // before a call to setSchema
+  mutable std::unique_ptr<FunctionSchema> schema;
 };
 
 struct Module;
 
-struct TORCH_API Module {
+struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
   Module()
       : name_("__main__"),
         module_value_(c10::ivalue::Object::create(
-            ClassType::createModuleType(std::make_shared<CompilationUnit>()),
-            0)) {}
+            ClassType::createModuleType(),
+            0)),
+        optimize_(true) {}
 
-  ~Module() {
-    // ClassType own the compilation unit of their Functions, but each
-    // Function has a self argument which owns the ClassType, created a
-    // referernce cycle. By dropping all the methods of the module's class
-    // here we break the cycle.
-    class_cu().drop_all_functions();
-  }
   const std::string& name() const {
     return name_;
   }
@@ -166,11 +374,11 @@ struct TORCH_API Module {
   // note this doesn't change the flags of existing methods just ones
   // added afterward.
   void set_optimized(bool o) {
-    class_cu().set_optimized(o);
+    optimize_ = o;
   }
 
   bool is_optimized() const {
-    return class_cu().is_optimized();
+    return optimize_;
   }
 
   IValue forward(std::vector<IValue> inputs) {
@@ -187,7 +395,7 @@ struct TORCH_API Module {
         name,
         attributes_,
         EntityType::ATTRIBUTE,
-        appendSlot(name, TensorType::get(), std::move(v)));
+        appendSlot(name, TensorType::get(),std::move(v)));
   }
 
   void register_parameter(
@@ -212,11 +420,7 @@ struct TORCH_API Module {
       const std::string& name,
       const TypePtr type,
       IValue ivalue) {
-    insert(
-        name,
-        attributes_,
-        EntityType::ATTRIBUTE,
-        appendSlot(name, type, ivalue));
+    insert(name, attributes_, EntityType::ATTRIBUTE, appendSlot(name, type, ivalue));
   }
   void register_module(
       const std::string& name,
@@ -230,10 +434,9 @@ struct TORCH_API Module {
     //   AT_WARN(
     //       "Attempting to assign submodule '",
     //       name,
-    //       "' but it is already a submodule of another ScriptModule '",
-    //       module->parent_->name(), "'", " Modules of this form do not import
-    //       and export correctly. This use is deprecated and may be" " removed
-    //       in a future version.");
+    //       "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'",
+    //       " Modules of this form do not import and export correctly. This use is deprecated and may be"
+    //       " removed in a future version.");
     // }
     module->parent_ = this;
     module->name_ = name;
@@ -241,6 +444,34 @@ struct TORCH_API Module {
     insert(name, modules_, EntityType::MODULE, std::move(module));
   }
 
+  Method& create_method(
+      const std::string& name,
+      std::shared_ptr<Graph> graph,
+      std::vector<Slot> member_inputs) {
+    AT_ASSERT(graph);
+    std::unique_ptr<Method> method(new Method(
+        this,
+        name,
+        optimize_,
+        std::move(graph),
+        std::move(member_inputs),
+        nullptr));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
+  }
+
+  Method& create_method(
+      const std::string& name,
+      std::function<void(Method&)> creator) {
+    std::unique_ptr<Method> method(new Method(
+        this,
+        name,
+        optimize_,
+        std::make_shared<Graph>(),
+        {},
+        std::move(creator)));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
+  }
+
   Slot parameter_slot(const std::string& name) const {
     return parameters_[get_offset(name, EntityType::PARAMETER)];
   }
@@ -264,14 +495,7 @@ struct TORCH_API Module {
   // each module owns its method. The reference returned here
   // is guarenteed to stay valid until this module has been destroyed
   Method& get_method(const std::string& name) const {
-    if (Method* method = find_method(name)) {
-      return *method;
-    }
-    // temporary: force the error message
-    // once the on-demand creation of Method is removed, this code
-    // can be removed as well
-    get_offset(name, EntityType::METHOD);
-    AT_ERROR("unreachable");
+    return *methods_[get_offset(name, EntityType::METHOD)];
   }
 
   std::shared_ptr<Module> get_module(const std::string& name) const {
@@ -287,12 +511,7 @@ struct TORCH_API Module {
   c10::ArrayRef<Slot> get_attributes() const {
     return attributes_;
   }
-  const std::vector<std::unique_ptr<Method>>& get_methods() const {
-    // force methods_ to be up to date by querying all
-    // methods. This will go away when lowered_methods_ is deleted
-    for (const auto& m : class_cu().get_functions()) {
-      get_method(m->name());
-    }
+  c10::ArrayRef<std::unique_ptr<Method>> get_methods() const {
     return methods_;
   }
 
@@ -315,22 +534,9 @@ struct TORCH_API Module {
     auto offset = find_offset(name, EntityType::MODULE);
     return offset ? modules_[*offset] : nullptr;
   }
-  Method* find_method(const std::string& name) const {
+  Method* find_method(const std::string& name) {
     auto offset = find_offset(name, EntityType::METHOD);
-    if (offset) {
-      return methods_[*offset].get();
-    }
-
-    if (Function* fn = class_cu().find_function(name).get()) {
-      // temporary lock because technically this is marked const,
-      // but we have to update the internal Method cache.
-      // This can be removed when class_cu() is the source of truth for
-      // methods.
-      std::lock_guard<std::recursive_mutex> guard(find_method_guard_);
-      return &const_cast<Module*>(this)->lower_first_class_method(fn);
-    }
-
-    return nullptr;
+    return offset ? methods_[*offset].get() : nullptr;
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
@@ -365,7 +571,10 @@ struct TORCH_API Module {
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(at::Device device, at::ScalarType dtype, bool non_blocking = false);
+  TORCH_API void to(
+      at::Device device,
+      at::ScalarType dtype,
+      bool non_blocking = false);
 
   /// Recursively casts all parameters to the given dtype.
   ///
@@ -373,7 +582,7 @@ struct TORCH_API Module {
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(at::ScalarType dtype, bool non_blocking = false);
+  TORCH_API void to(at::ScalarType dtype, bool non_blocking = false);
 
   /// Recursively moves all parameters to the given device.
   ///
@@ -381,7 +590,7 @@ struct TORCH_API Module {
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(at::Device device, bool non_blocking = false);
+  TORCH_API void to(at::Device device, bool non_blocking = false);
 
   /// Run a method from this module.
   ///
@@ -437,57 +646,26 @@ struct TORCH_API Module {
       mod->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
-
-    for (auto& fn : class_cu().get_functions()) {
-      curr->class_cu().clone_function(*fn);
+    for (auto& method : get_methods()) {
+      std::vector<Slot> initial_ivalues;
+      for (auto& p : method->initial_ivalues()) {
+        initial_ivalues.push_back(parameter_remap.at(p));
+      }
+      curr->create_method(
+          method->name(), method->graph()->copy(), initial_ivalues);
     }
   }
 
   enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };
 
   at::optional<EntityType> kind_of(const std::string& name) const {
-    // force lazy creation of Method if needed
-    // remove once lowered_methods_ is removed.
-    find_method(name);
-
     auto it = dict_.find(name);
     if (it == dict_.end())
       return at::nullopt;
     return it->second.type;
   }
 
-  ModulePtr module_object() const {
-    return module_value_;
-  }
-  CompilationUnit& class_compilation_unit() {
-    return module_object()->type()->compilation_unit();
-  }
-  CompilationUnit& lowered_methods() const {
-    return lowered_methods_;
-  }
-
-  // so that C++ users can easily add methods
-  void define(const std::string& src, const Resolver& resolver = nullptr);
-
-  void _define_lowered(
-      const std::vector<Def>& definitions,
-      const std::vector<Resolver>& resolvers);
-  void _define_lowered(const std::string& src, const Resolver& resolver);
-
-  Method& _define_lowered(
-      std::string name,
-      std::shared_ptr<Graph> graph,
-      std::vector<Slot> slots);
-
  private:
-  Method& _create_lowered_method(
-      Function* func,
-      std::vector<Slot> member_inputs);
-
-  Method& lower_first_class_method(Function* fn);
-  void lift_lowered_method(Method& fn);
-  void lift_lowered_methods(size_t start);
-
   void to_impl(
       const c10::optional<at::Device>& device,
       const c10::optional<at::ScalarType>& dtype,
@@ -575,13 +753,6 @@ struct TORCH_API Module {
     return Slot(module_value_, slot_index);
   }
 
-  CompilationUnit& class_cu() {
-    return module_value_->type()->compilation_unit();
-  }
-  const CompilationUnit& class_cu() const {
-    return module_value_->type()->compilation_unit();
-  }
-
   // modules have a single namespace, but spread over 4 different concepts:
   // parameters, attributes, methods, and sub-modules
   // we store individual lists of each concept, and a single map to
@@ -599,97 +770,29 @@ struct TORCH_API Module {
   std::unordered_map<std::string, Entry> dict_;
   std::string name_;
 
-  ModulePtr module_value_;
-
-  // back reference to parent of this Module if present
-  Module* parent_ = nullptr;
 
-  // Currently we are in a transitionary state
-  // where we construct such first class functions but we lower them
-  // to a form where the modules does not exist before execution.
+  c10::intrusive_ptr<at::ivalue::Object> module_value_;
 
-  // So each Method is actually stored twice once in first-class Module
-  // form and once in lowered form.
 
-  // first-class: module_value_->type().compilation_unit() holds Functions that
-  // treat modules as first class.
-
-  // lowered: In this lowered form, all the attributes/parameters are appended
-  // as additional inputs. lowered_methods_ holds this lowered form
-  // mutable because it is a cache for class_cu() methods
-  mutable CompilationUnit lowered_methods_;
-  mutable std::recursive_mutex find_method_guard_;
+  // back reference to parent of this Module if present
+  Module* parent_ = nullptr;
+  bool optimize_;
 };
 
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
-  AT_ASSERT(stack.size() == g.inputs().size());
-  for (size_t i = 0; i < stack.size(); ++i) {
-    g.inputs().at(i)->setType(
-        DimensionedTensorType::create(stack.at(i).toTensor()));
-  }
-}
-
-inline std::shared_ptr<Graph> propagate_shapes(
-    Graph& graph,
-    const std::vector<at::Tensor>& inputs,
-    const std::vector<Slot>& initial_ivalues,
-    bool with_grad = false) {
-  auto retval = graph.copy();
-  Stack stack;
-  stack.reserve(inputs.size() + initial_ivalues.size());
-  for (const at::Tensor& i : inputs) {
-    stack.emplace_back(std::move(i));
-  }
-  for (const Slot& inp : initial_ivalues) {
-    stack.push_back(inp.value());
-  }
-  setInputTensorTypes(*retval, stack);
-  PropagateInputShapes(retval);
-  return retval;
-}
-
-inline std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+// returns nullptr and fills in failure_messages if the callee does not
+// match the functions schema
+Value* try_emit_call_to(
     Graph& graph,
-    std::vector<at::Tensor> inputs,
-    const std::vector<Slot>& initial_ivalues,
-    std::vector<at::Tensor> outputs,
-    bool with_grad = false,
-    bool propagate = true) {
-  auto retval = graph.copy();
-  for (auto inp : initial_ivalues) {
-    if (inp.value().isTensor()) {
-      inputs.push_back(inp.value().toTensor());
-    }
-  }
-  if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs));
-    PropagateInputShapes(retval);
-  }
-  AT_ASSERT(retval->inputs().size() == inputs.size());
-  for (size_t i = 0; i < retval->inputs().size(); ++i) {
-    auto scalar_type = inputs[i].scalar_type();
-    auto sizes = inputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    retval->inputs()[i]->setType(type);
-  }
-  at::ArrayRef<Value*> output_values = retval->outputs();
-  // patch this to still work if we are returning a tuple of multiple values
-  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
-    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
-    output_values = output_values.at(0)->node()->inputs();
-  }
-  AT_ASSERT(output_values.size() == outputs.size());
-  for (size_t i = 0; i < retval->outputs().size(); ++i) {
-    auto scalar_type = outputs[i].scalar_type();
-    auto sizes = outputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    output_values[i]->setType(type);
-  }
-  return retval;
-}
-
+    const SourceRange& loc,
+    Method& callee,
+    c10::optional<NamedValue> self,
+    ArrayRef<NamedValue> args,
+    ArrayRef<NamedValue> kwargs,
+    std::stringstream& failure_messages,
+    // when callee uses no parameters (e.g. it is a function in a compilation
+    // unit, and not a method), then nullptr can be passed as caller.
+    Method* caller,
+    bool conv_tensors_to_nums);
 } // namespace script
 } // namespace jit
 } // namespace torch
index 0c5676c..30c273d 100644 (file)
@@ -421,14 +421,16 @@ Value* emitBuiltinCall(
         return emitBuiltinNode(*matched_schema, loc, graph, name);
       }
     }
-    for (Function* method : builtin_functions) {
-      if (auto result = method->try_emit_call(
+    for (Method* method : builtin_functions) {
+      if (auto result = try_emit_call_to(
               graph,
               loc,
+              *method,
               self,
               inputs,
               attributes,
               failure_messages,
+              nullptr,
               allow_conversions)) {
         return result;
       }
index 0304c0e..3e01731 100644 (file)
@@ -33,7 +33,6 @@ private:
   c10::intrusive_ptr<c10::ivalue::Object> container_;
   size_t offset_;
   friend struct std::hash<Slot>;
-  friend struct Module;
 };
 
 }}}
index 6410282..48c5337 100644 (file)
@@ -16,7 +16,7 @@ struct NoneValue : SugaredValue {
 
 std::shared_ptr<SugaredValue> PrintValue::call(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     at::ArrayRef<NamedValue> inputs,
     at::ArrayRef<NamedValue> attributes,
     size_t n_binders) {
@@ -58,7 +58,7 @@ builtin_cast_methods() {
 
 std::shared_ptr<SugaredValue> BuiltinFunction::call(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     at::ArrayRef<NamedValue> inputs,
     at::ArrayRef<NamedValue> attributes,
     size_t n_binders) {
@@ -70,7 +70,7 @@ std::shared_ptr<SugaredValue> BuiltinFunction::call(
 // callable value that will resolve to foo(x, y, z) when called.
 std::shared_ptr<SugaredValue> SimpleValue::attr(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     const std::string& field) {
   // Allow method-style casts on Tensor types. e.g. x.int()
   if (value_->type()->isSubtypeOf(TensorType::get())) {
@@ -116,7 +116,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
   if (auto classType = value_->type()->cast<ClassType>()) {
     // This is a class, emit the proper attribute lookup
     if (auto method = classType->getMethod(field)) {
-      return std::make_shared<MethodValue>(getValue(), *method);
+      return std::make_shared<MethodValue>(shared_from_this(), *method);
     }
 
     if (!classType->hasAttribute(field)) {
@@ -135,7 +135,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
 
 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     const c10::optional<size_t>& size_hint) {
   static const auto make_simple_value =
       [](Value* v) -> std::shared_ptr<SugaredValue> {
@@ -161,7 +161,7 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
 
 void SimpleValue::setAttr(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     const std::string& field,
     Value* newValue) {
   const auto classType = value_->type()->cast<ClassType>();
@@ -217,7 +217,7 @@ void SimpleValue::setAttr(
 
 std::shared_ptr<SugaredValue> ClassValue::call(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     // note: names for args will be 'argument 0', 'argument 1', etc..
     at::ArrayRef<NamedValue> inputs,
     at::ArrayRef<NamedValue> attributes,
@@ -226,7 +226,8 @@ std::shared_ptr<SugaredValue> ClassValue::call(
 
   // Generate a new object of the right type, then call `__init__` on it
   auto& g = *m.graph();
-  auto self = g.insertNode(g.createObject(type_))->output();
+  auto createNode = g.insertNode(g.createObject(type_));
+  auto self = std::make_shared<SimpleValue>(createNode->output());
 
   auto initMethod = type_->getMethod("__init__");
   AT_ASSERT(initMethod);
@@ -234,12 +235,12 @@ std::shared_ptr<SugaredValue> ClassValue::call(
   // Call the init function
   MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders);
 
-  return std::make_shared<SimpleValue>(self);
+  return self;
 }
 
 std::shared_ptr<SugaredValue> ClassValue::attr(
     const SourceRange& loc,
-    Function& m,
+    Method& m,
     const std::string& field) {
   if (field != "__new__") {
     throw ErrorReport(loc) << "Tried to lookup unknown attribute on class";
index 6f4117d..e1fd725 100644 (file)
@@ -28,14 +28,14 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 
   // what can we do with this thing?
   // use it as a value e.g.  `this + 4`
-  virtual Value* asValue(const SourceRange& loc, Function& m) {
+  virtual Value* asValue(const SourceRange& loc, Method& m) {
     throw ErrorReport(loc) << kind() << " cannot be used as a value";
   }
 
   // select an attribute on it, e.g. `this.field`
   virtual std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) {
     throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
   }
@@ -43,7 +43,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // assign an attribute on it, e.g. `this.field = newValue`
   virtual void setAttr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field,
       Value* newValue) {
     throw ErrorReport(loc) << "attribute assignment is not defined on "
@@ -57,7 +57,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // a method invocation
   virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const c10::optional<size_t>& size_hint = {}) {
     throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
   }
@@ -65,7 +65,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // call it like a function, e.g. `outputs = this(inputs)`
   virtual std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       // note: names for args will be 'argument 0', 'argument 1', etc..
       at::ArrayRef<NamedValue> inputs_,
       at::ArrayRef<NamedValue> attributes,
@@ -97,7 +97,7 @@ struct TORCH_API SimpleValue : public SugaredValue {
   std::string kind() const override {
     return "value";
   }
-  Value* asValue(const SourceRange& range, Function& m) override {
+  Value* asValue(const SourceRange& range, Method& m) override {
     return value_;
   }
   NoneStatus isNone() override {
@@ -110,16 +110,16 @@ struct TORCH_API SimpleValue : public SugaredValue {
   }
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const c10::optional<size_t>& size_hint = {}) override;
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override;
 
   void setAttr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field,
       Value* newValue) override;
 
@@ -146,7 +146,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue {
   }
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       at::ArrayRef<NamedValue> attributes,
       at::ArrayRef<NamedValue> inputs,
       size_t n_binders) override;
@@ -161,7 +161,7 @@ struct TORCH_API BuiltinModule : public SugaredValue {
   }
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override {
     return std::make_shared<BuiltinFunction>(
         Symbol::fromQualString(name + "::" + field), c10::nullopt);
@@ -183,14 +183,14 @@ struct TORCH_API ClassValue : public SugaredValue {
   //    n = Foo(constructor_arg)
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override;
 
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& field) override;
 
   std::string kind() const override {
@@ -202,33 +202,34 @@ struct TORCH_API ClassValue : public SugaredValue {
 
 // defines how a method obtained from a module behaves in script
 struct MethodValue : public SugaredValue {
-  MethodValue(c10::optional<NamedValue> self, Function& method)
+  MethodValue(std::shared_ptr<SugaredValue> self, Method& method)
       : self_(std::move(self)), method(method) {}
   std::string kind() const override {
     return "method";
   }
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& f,
+      Method& caller,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
-    Graph& graph = *f.graph();
-    if (self_) {
+    if (auto classType = dynamic_cast<SimpleValue*>(self_.get())) {
+      // If self_ is a class, then it will be expected as part of
+      // the schema. Add it to the front of the inputs.
       std::vector<NamedValue> inputsWithSelf;
-      inputsWithSelf.emplace_back(loc, self_->value(graph));
+      inputsWithSelf.emplace_back(loc, classType->getValue());
       inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end());
       return std::make_shared<SimpleValue>(
-          method.emit_call(graph, loc, inputsWithSelf, attributes));
+          caller.emit_call_to(loc, method, inputsWithSelf, attributes));
     }
 
     return std::make_shared<SimpleValue>(
-        method.emit_call(graph, loc, inputs, attributes));
+        caller.emit_call_to(loc, method, inputs, attributes));
   }
 
  private:
-  c10::optional<NamedValue> self_;
-  Function& method;
+  std::shared_ptr<SugaredValue> self_;
+  Method& method;
 };
 
 struct TORCH_API PrintValue : public SugaredValue {
@@ -237,7 +238,7 @@ struct TORCH_API PrintValue : public SugaredValue {
   }
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override;
@@ -251,7 +252,7 @@ struct TORCH_API CastValue : public BuiltinFunction {
       : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
   std::shared_ptr<SugaredValue> call(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
@@ -315,7 +316,7 @@ struct TORCH_API ClassNewMethod : public SugaredValue {
 
   std::shared_ptr<SugaredValue> createObject(
       const SourceRange& loc,
-      Function& m,
+      Method& m,
       const std::string& classname) {
     if (classname != type_->name()) {
       throw ErrorReport(loc)
@@ -337,13 +338,6 @@ static inline std::vector<Value*> toValues(
   return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
 }
 
-static inline Self simpleSelf(const TypePtr& typ) {
-  return [typ](Value* v) {
-    v->setType(typ);
-    return std::make_shared<SimpleValue>(v);
-  };
-}
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 3433a9d..2a99f72 100644 (file)
@@ -1303,8 +1303,8 @@ bool isHelperFunction(const std::string& method_name) {
   return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
 }
 
-void loadModule(const script::CompilationUnit& module) {
-  for (const auto& method : module.get_functions()) {
+void loadModule(const std::shared_ptr<script::Module>& module) {
+  for (const auto& method : module->get_methods()) {
     if (isHelperFunction(method->name()))
       continue;
 
@@ -1356,8 +1356,9 @@ void loadModule(const script::CompilationUnit& module) {
 
 void loadFunctions() {
   for (const std::string& str : functions) {
-    script::CompilationUnit cu;
-    cu.define(str, script::nativeResolver, nullptr);
+    auto cu = std::make_shared<script::Module>();
+    script::defineMethodsInModule(
+        cu, str, script::nativeResolver, c10::nullopt);
     loadModule(cu);
   }
 }
index b4d33ce..593f508 100644 (file)
@@ -702,8 +702,14 @@ def _try_get_dispatched_fn(fn):
     return _jit_internal.boolean_dispatched.get(fn)
 
 
-def _try_get_overloaded_fn(mod, field):
-    return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
+def _try_get_overloaded_fn(fn):
+    if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule):
+        # Only allow overloads for bound methods
+        return None
+    overloads = fn.__self__._overloads.get(fn.__name__, None)
+    if overloads is None:
+        return None
+    return [getattr(fn.__self__, overload) for overload in overloads]
 
 
 def _try_compile_weak_script(fn):
@@ -732,20 +738,20 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
         return obj
     if _rcb is None:
         _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+    mod = ScriptModule()
     if inspect.isclass(obj):
         if not _is_new_style_class(obj):
             raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
         ast = get_jit_class_def(obj)
-        _jit_script_class_compile(ast, _rcb)
+        _jit_script_class_compile(mod, ast, _rcb)
         _add_script_class(obj, obj.__name__)
         return obj
     else:
-        mod = ScriptModule()
         ast = get_jit_def(obj)
         _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
-        # Forward docstrings
-        mod.__doc__ = obj.__doc__
-        return mod
+    # Forward docstrings
+    mod.__doc__ = obj.__doc__
+    return mod
 
 
 ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))