From ef406ee925b0cca35d227f458eab9af0b927d6ac Mon Sep 17 00:00:00 2001
From: Zachary DeVito
Date: Thu, 11 Apr 2019 13:30:42 -0700
Subject: [PATCH] First class modules in the compiler, round 2 (#19167)

Summary:
This PR propagates the use of first-class module objects into the compiler. This creates a transitional state where:

* compiler.cpp creates Graphs where `self` is a Module class and attributes/parameters/buffers/submodules are looked up with `prim::GetAttr`.
* GraphExecutor still runs "lowered graphs" where the self object has been removed by a compiler pass, `lower_first_class_method`.
* Tracing still creates "lowered graphs", and a pass `lift_lowered_method` creates the corresponding first-class method graph.
* This PR separates out Method and Function. A script::Function is a pure Graph with no `self` bound. Similar to Python, a script::Method is just a bound `self` and its underlying `script::Function`.
* This PR also separates CompilationUnit from Module. A CompilationUnit is just a list of named script::Functions. Classes have a CompilationUnit holding the class methods, and Modules also have a CompilationUnit holding their Methods. This avoids the weird circular case Module --has a--> Class --has a--> Module ...

Details:
* In this transitional state, we maintain two copies of each Graph: a first-class module graph and a lowered graph. The first-class one has a `self` argument whose type is the module's class type. The lowered one uses the initial_ivalues inputs instead.
* When defining lowered methods using `_define_lowered`, we immediately create the first-class equivalent. The reverse is done lazily, creating lowered_methods on demand from the class.
* The two-way conversions will be deleted in a future PR when the executor itself runs first-class objects. However, this requires more changes to (1) the tracer, (2) the Python bindings, and (3) the ONNX export pass, and would make this PR way too large.
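As a usage illustration, the new C++ test added in this PR (`testModuleDefine` in test_misc.h) exercises the resulting API end to end: a method is defined through the Module's CompilationUnit with `self` bound first-class (so `self.foo` lowers to `prim::GetAttr`) and then invoked with `run_method`. The sketch below is a lightly expanded version of that test; the includes and the `item<float>()` template argument are assumptions here, not part of the diff.

```cpp
// Sketch based on testModuleDefine: define a script method on a Module and run it.
#include <torch/csrc/jit/script/module.h>
#include <torch/torch.h>

void module_define_example() {
  auto m = std::make_shared<torch::jit::script::Module>();
  // Register a parameter that the script body will read via self.foo.
  m->register_parameter("foo", torch::ones({}), /*is_buffer=*/false);
  // define() compiles the source into the module's CompilationUnit; the
  // compiler binds `self` as a first-class argument of the module's class type.
  m->define(R"(
    def add_it(self, x, b : int = 4):
        return self.foo + x + b
  )");
  // 1 (foo) + 1 (x) + 4 (default b) == 6
  auto result = m->run_method("add_it", torch::ones({}));
  AT_ASSERT(result.toTensor().item<float>() == 6);
}
```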
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19167 Differential Revision: D14891966 Pulled By: zdevito fbshipit-source-id: 0b5f03118aa65448a15c7a7818e64089ec93d7ea --- aten/src/ATen/core/function_schema.h | 54 ++ aten/src/ATen/core/jit_type.h | 25 +- aten/src/ATen/core/type.cpp | 10 +- test/cpp/jit/test.cpp | 3 +- test/cpp/jit/test_misc.h | 60 ++- ...pt.test_onnx_export_script_inline_params.expect | 8 +- ...TestScript.test_onnx_export_speculate-f2.expect | 10 +- test/test_jit.py | 2 +- tools/build_variables.py | 1 + torch/CMakeLists.txt | 3 +- torch/csrc/api/include/torch/jit.h | 2 +- torch/csrc/api/src/jit.cpp | 7 +- torch/csrc/jit/import_source.cpp | 60 +-- torch/csrc/jit/ir.h | 4 + torch/csrc/jit/passes/graph_fuser.cpp | 7 +- torch/csrc/jit/passes/python_print.cpp | 13 +- torch/csrc/jit/python_ir.cpp | 1 + torch/csrc/jit/script/builtin_functions.cpp | 19 +- torch/csrc/jit/script/builtin_functions.h | 2 +- torch/csrc/jit/script/class_type.cpp | 17 +- torch/csrc/jit/script/compilation_unit.h | 285 +++++++++++ torch/csrc/jit/script/compiler.cpp | 77 ++- torch/csrc/jit/script/compiler.h | 46 +- torch/csrc/jit/script/init.cpp | 335 +++++++----- torch/csrc/jit/script/module.cpp | 294 +++++++++-- torch/csrc/jit/script/module.h | 564 +++++++++------------ torch/csrc/jit/script/schema_matching.cpp | 6 +- torch/csrc/jit/script/slot.h | 1 + torch/csrc/jit/script/sugared_value.cpp | 21 +- torch/csrc/jit/script/sugared_value.h | 58 ++- torch/csrc/jit/symbolic_script.cpp | 9 +- torch/jit/__init__.py | 20 +- 32 files changed, 1240 insertions(+), 784 deletions(-) create mode 100644 torch/csrc/jit/script/compilation_unit.h diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index c5185c1..888cfcb 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -164,6 +164,18 @@ public: } return c10::nullopt; } + FunctionSchema cloneWithArguments(std::vector new_arguments) const { + return FunctionSchema( + name(), + overload_name(), + std::move(new_arguments), + returns(), + is_vararg(), + is_varret()); + } + // Check that inputs have the correct types and appends any missing default + // values. + void checkAndNormalizeInputs(std::vector& inputs) const; }; inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { @@ -227,4 +239,46 @@ inline std::string toString(const FunctionSchema& schema) { return str.str(); } +inline void FunctionSchema::checkAndNormalizeInputs(std::vector& inputs) const { + // Do we have more inputs than the schema accepts? + AT_CHECK( + inputs.size() <= arguments().size(), + "Expected at most ", + arguments().size(), + " argument(s) for operator '", + name(), + "', but received ", + inputs.size(), + " argument(s). Declaration: ", + *this); + + for (size_t pos = 0; pos < arguments().size(); ++pos) { + const auto& argument = arguments()[pos]; + if (pos < inputs.size()) { + if (!isSubvalueOf(inputs[pos], argument.type())) { + AT_ERROR( + "Expected value of type ", + *argument.type(), + " for argument '", + argument.name(), + "' in position ", + pos, + ", but instead got value of type ", + attemptToRecoverType(inputs[pos])->str(), + ". Declaration: ", + *this); + } + } else if (argument.default_value()) { + inputs.push_back(*argument.default_value()); + } else { + AT_ERROR( + name(), + "() is missing value for argument '", + argument.name(), + "'. 
Declaration: ", + *this); + } + } +} + } // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 3656251..2399dcc 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -17,8 +17,8 @@ namespace torch { namespace jit { namespace script { -struct Module; -struct Method; +struct CompilationUnit; +struct Function; } } // namespace jit } // namespace torch @@ -1100,19 +1100,19 @@ CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); struct ClassType; using ClassTypePtr = std::shared_ptr; -using ::torch::jit::script::Module; -using ::torch::jit::script::Method; +using ::torch::jit::script::CompilationUnit; +using ::torch::jit::script::Function; // This represents a class in TorchScript. struct CAFFE2_API ClassType : public Type { // Create a user type and register it globally. static ClassTypePtr create( const std::string& name, - std::shared_ptr module); + std::shared_ptr module); // Create a type representing a Module, // These do not have methods, and are not globally registered - static ClassTypePtr createModuleType(); + static ClassTypePtr createModuleType(std::shared_ptr module); // returns nullptr if there is no type with that name static ClassTypePtr get(const std::string& name); @@ -1168,8 +1168,11 @@ struct CAFFE2_API ClassType : public Type { return attributeNames_[slot]; } - Method* getMethod(const std::string& name) const; - std::vector methods() const; + Function* getMethod(const std::string& name) const; + CompilationUnit& compilation_unit(); + const CompilationUnit& compilation_unit() const; + std::vector methods() const; + const std::string& name() const { return typename_; @@ -1226,10 +1229,10 @@ struct CAFFE2_API ClassType : public Type { static const TypeKind Kind = TypeKind::ClassType; private: - ClassType(std::string name, std::shared_ptr module) + ClassType(std::string name, std::shared_ptr cu) : Type(TypeKind::ClassType), typename_(std::move(name)), - module_(std::move(module)) {} + compilation_unit_(std::move(cu)) {} // Name of type (note that this has to be globally unique). 
std::string typename_; @@ -1243,7 +1246,7 @@ struct CAFFE2_API ClassType : public Type { std::vector attributeNames_; std::vector attributeTypes_; // Holds method attributes - std::shared_ptr module_; + std::shared_ptr compilation_unit_; }; } // namespace c10 diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index e6c3628..534c9f3 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -472,18 +472,18 @@ ClassTypeRegistry& getRegistry() { ClassTypePtr ClassType::create( const std::string& name, - std::shared_ptr module) { - auto ptr = ClassTypePtr(new ClassType(name, std::move(module))); + std::shared_ptr cu) { + auto ptr = ClassTypePtr(new ClassType(name, std::move(cu))); getRegistry().registerType(name, ptr); return ptr; } -ClassTypePtr ClassType::createModuleType() { - return ClassTypePtr(new ClassType("Module", nullptr)); +ClassTypePtr ClassType::createModuleType(std::shared_ptr cu) { + return ClassTypePtr(new ClassType("Module", std::move(cu))); } ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { - auto ptr = ClassTypePtr(new ClassType(typename_, module_)); + auto ptr = ClassTypePtr(new ClassType(typename_, compilation_unit_)); AT_ASSERT(numAttributes() == refined_slots.size()); for(size_t i = 0; i < attributeNames_.size(); ++i) { AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i])); diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp index 7145c1a..90b38f3 100644 --- a/test/cpp/jit/test.cpp +++ b/test/cpp/jit/test.cpp @@ -65,7 +65,8 @@ namespace jit { _(NoneSchemaMatch) \ _(ClassParser) \ _(PeepholeOptimize) \ - _(RecordFunction) + _(RecordFunction) \ + _(ModuleDefine) #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index 8b93c21..ea20266 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -40,6 +40,7 @@ #include "ATen/core/ivalue.h" #include "torch/csrc/jit/script/compiler.h" #include "torch/csrc/jit/script/module.h" +#include "torch/jit.h" #include "onnx/onnx_pb.h" @@ -369,11 +370,10 @@ static const auto cf_examples = R"JIT( return a )JIT"; void testControlFlow() { - auto cu = std::make_shared(); - script::defineMethodsInModule( - cu, cf_examples, script::nativeResolver, c10::nullopt); + auto cu = compile(cf_examples); + auto run = [&](const std::string& name, std::vector stack) { - auto graph = cu->get_method(name).graph(); + auto graph = cu->get_function(name).graph(); Code code(graph); InterpreterState interp(code); interp.run(stack); @@ -576,12 +576,11 @@ void testTopologicalIndex() { } void invokeTestRecordFunction(at::Tensor& t) { - autograd::profiler::GetPackedInputsCallback inputs_cb = - [t]() { - Stack st; - pack(st, t); - return st; - }; + autograd::profiler::GetPackedInputsCallback inputs_cb = [t]() { + Stack st; + pack(st, t); + return st; + }; autograd::profiler::RecordFunction guard("test", inputs_cb); t.add_(torch::ones_like(t)); } @@ -605,15 +604,15 @@ void invokeTestRecordFunctionNested() { void testRecordFunction() { std::vector> input_sizes; - autograd::profiler::pushCallback([&input_sizes]( - const autograd::profiler::RecordFunction& fn) { - for (const auto& input : fn.inputs()) { - if (input.isTensor()) { - std::vector t = input.toTensor().sizes().vec(); - input_sizes.push_back(t); - } - } - }); + autograd::profiler::pushCallback( + [&input_sizes](const autograd::profiler::RecordFunction& fn) { + for (const auto& input : fn.inputs()) { + if (input.isTensor()) { + std::vector t = 
input.toTensor().sizes().vec(); + input_sizes.push_back(t); + } + } + }); auto t = torch::randn({1, 2, 3}, at::kCPU); invokeTestRecordFunction(t); @@ -625,14 +624,15 @@ void testRecordFunction() { // test nested RecordFunctions std::vector nested_names; - autograd::profiler::pushCallback([&nested_names]( - const autograd::profiler::RecordFunction& fn) { - nested_names.push_back(getFullName(&fn)); - }); + autograd::profiler::pushCallback( + [&nested_names](const autograd::profiler::RecordFunction& fn) { + nested_names.push_back(getFullName(&fn)); + }); { autograd::profiler::RecordFunction guard("outer"); - invokeTestRecordFunctionNested();; + invokeTestRecordFunctionNested(); + ; } autograd::profiler::popCallback(); @@ -709,6 +709,18 @@ void testNoneSchemaMatch() { // checking that constant propagation ran wo/failure AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1); } + +void testModuleDefine() { + auto m = std::make_shared(); + m->register_parameter("foo", torch::ones({}), false); + m->define(R"( + def add_it(self, x, b : int = 4): + return self.foo + x + b + )"); + auto result = m->run_method("add_it", torch::ones({})); + AT_ASSERT(result.toTensor().item() == 6) +} + } // namespace test } // namespace jit } // namespace torch diff --git a/test/expect/TestScript.test_onnx_export_script_inline_params.expect b/test/expect/TestScript.test_onnx_export_script_inline_params.expect index ffa284a..1cb1092 100644 --- a/test/expect/TestScript.test_onnx_export_script_inline_params.expect +++ b/test/expect/TestScript.test_onnx_export_script_inline_params.expect @@ -5,14 +5,14 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}] + inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 4},{name: "2", type:Tensor dims: 3 3}] outputs: [{name: "6", type:Tensor dims: 2 4}] - initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]] + initializers: [TensorProto shape: [3 4],TensorProto shape: [3 3]] nodes: [ Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, - Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}, + Node {type: "Gemm", inputs: [x,2,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, - Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]} + Node {type: "Gemm", inputs: [4,1,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]} ] } opset_import: [OperatorSetIdProto { domain: }], diff --git a/test/expect/TestScript.test_onnx_export_speculate-f2.expect b/test/expect/TestScript.test_onnx_export_speculate-f2.expect index 3126f1d..29ce206 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f2.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f2.expect @@ -5,9 +5,9 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}] + inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 
20},{name: "2", type:Tensor dims: 20 10}] outputs: [{name: "8", type:Tensor dims: 1 20}] - initializers: [TensorProto shape: [20 10],TensorProto shape: [20]] + initializers: [TensorProto shape: [20],TensorProto shape: [20 10]] nodes: [ Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []}, Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]}, @@ -28,7 +28,7 @@ ModelProto { outputs: [{name: "10", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,2,1], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -39,7 +39,7 @@ ModelProto { outputs: [{name: "11", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,2,1], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -54,7 +54,7 @@ ModelProto { outputs: [{name: "12", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,2,1], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } diff --git a/test/test_jit.py b/test/test_jit.py index 787dee4..18ca724 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7478,7 +7478,7 @@ a") def foo(self, input): self.call_foo(input) - with self.assertRaisesRegex(RuntimeError, 'called recursively involving'): + with self.assertRaisesRegex(RuntimeError, 'called recursively'): M() def test_script_kwargs_fn_call(self): diff --git a/tools/build_variables.py b/tools/build_variables.py index 89a5ed8..ff27ce3 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -95,6 +95,7 @@ libtorch_sources = [ "torch/csrc/jit/register_quantized_ops.cpp", "torch/csrc/jit/scope.cpp", "torch/csrc/jit/script/compiler.cpp", + "torch/csrc/api/src/jit.cpp", "torch/csrc/jit/script/edit_distance.cpp", "torch/csrc/jit/script/logging.cpp", "torch/csrc/jit/script/final_returns.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 60f883a..4b2281b 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -175,6 +175,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/scope.cpp ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp + ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp @@ -236,7 +237,6 @@ if (NOT NO_API) ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp - ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp 
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp @@ -528,7 +528,6 @@ if (BUILD_PYTHON) ${TORCH_SRC_DIR}/csrc/jit/python_tracer.cpp ${TORCH_SRC_DIR}/csrc/jit/script/init.cpp ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp - ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp ${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp ${TORCH_SRC_DIR}/csrc/nn/THNN.cpp diff --git a/torch/csrc/api/include/torch/jit.h b/torch/csrc/api/include/torch/jit.h index 9814ead..7e2e4c9 100644 --- a/torch/csrc/api/include/torch/jit.h +++ b/torch/csrc/api/include/torch/jit.h @@ -32,7 +32,7 @@ namespace jit { /// )JIT"); /// IValue output = module->run_method("relu_script", a, b); /// \endrst -TORCH_API std::shared_ptr compile(const std::string& source); +TORCH_API std::shared_ptr compile(const std::string& source); } // namespace jit } // namespace torch diff --git a/torch/csrc/api/src/jit.cpp b/torch/csrc/api/src/jit.cpp index 29ea39f..a66e947 100644 --- a/torch/csrc/api/src/jit.cpp +++ b/torch/csrc/api/src/jit.cpp @@ -9,10 +9,9 @@ namespace torch { namespace jit { -std::shared_ptr compile(const std::string& source) { - auto module = std::make_shared(); - defineMethodsInModule( - module, source, script::nativeResolver, /*self=*/c10::nullopt); +std::shared_ptr compile(const std::string& source) { + auto module = std::make_shared(); + module->define(source, script::nativeResolver, nullptr); return module; } diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp index 6a74810..e27988d 100644 --- a/torch/csrc/jit/import_source.cpp +++ b/torch/csrc/jit/import_source.cpp @@ -6,39 +6,6 @@ namespace torch { namespace jit { namespace script { -// this is a much simpler accessor that only handles modules, parameters, and -// and methods. It does not depend on python to work. -struct ModuleAccessorValue : public SugaredValue { - ModuleAccessorValue(std::shared_ptr module) - : module(std::move(module)) {} - std::string kind() const override { - return "module"; - } - // select an attribute on it, e.g. 
`this.field` - std::shared_ptr attr( - const SourceRange& loc, - Method& m, - const std::string& field) override { - if (std::shared_ptr v = module->find_module(field)) { - return std::make_shared(std::move(v)); - } else if (script::Slot* v = module->find_parameter(field)) { - return std::make_shared(m.get_or_add_parameter(*v)); - } else if (script::Slot* v = module->find_buffer(field)) { - return std::make_shared(m.get_or_add_parameter(*v)); - } else if (script::Slot* v = module->find_attribute(field)) { - return std::make_shared( - m.get_or_add_attribute(*v)); - } else if (Method* m = module->find_method(field)) { - return std::make_shared(shared_from_this(), *m); - } else { - throw ErrorReport(loc) << "unknown attr: " << field; - } - } - - private: - std::shared_ptr module; -}; - struct OpsValue : public SugaredValue { OpsValue(size_t version) : version_(version) {} std::string kind() const override { @@ -46,7 +13,7 @@ struct OpsValue : public SugaredValue { } std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { return std::make_shared(field, version_); } @@ -59,7 +26,7 @@ struct ConstantValue : public SugaredValue { std::string kind() const override { return "constant"; } - Value* asValue(const SourceRange& loc, Method& m) override { + Value* asValue(const SourceRange& loc, Function& m) override { return m.graph()->insertConstant(value_); } }; @@ -75,7 +42,7 @@ struct ConstantTableValue : public SugaredValue { // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { const char* field_s = field.c_str(); char* end; @@ -117,7 +84,7 @@ struct SourceImporter { }; resolver_ = [&](const std::string& name, - Method& m, + Function& m, const SourceRange& loc) -> std::shared_ptr { auto it = env_.find(name); if (it == env_.end()) { @@ -133,7 +100,7 @@ struct SourceImporter { const std::vector& constant_table_; std::unordered_map> env_; std::function(const std::string& name, Method& m, const SourceRange& loc)> + SugaredValue>(const std::string& name, Function& m, const SourceRange& loc)> resolver_; size_t parseVersionNumber() { @@ -167,8 +134,11 @@ void import_methods( definitions.emplace_back(def); resolvers.emplace_back(importer.resolver_); } - auto self = std::make_shared(mod); - defineMethodsInModule(mod, definitions, resolvers, Self(self)); + auto self = [&](Value* v) { + v->setType(mod->module_object()->type()); + return std::make_shared(v); + }; + mod->module_object()->type()->compilation_unit().define(definitions, resolvers, self); } void import_libs( @@ -186,9 +156,13 @@ void import_libs( resolvers.emplace_back(importer.resolver_); } - auto mod = std::make_shared(); - Self self(ClassType::create(class_def.name().name(), mod)); - defineMethodsInModule(mod, definitions, resolvers, self); + auto cu = std::make_shared(); + auto class_type = ClassType::create(class_def.name().name(), cu); + auto self = [&](Value* v) { + v->setType(class_type); + return std::make_shared(v); + }; + cu->define(definitions, resolvers, self); } } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index ce43b5b..42e7267 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1074,6 +1074,10 @@ struct Graph { const std::string& field, Value* newValue); TORCH_API Node* createGetAttr(Value* obj, const std::string& field); + TORCH_API Value* insertGetAttr(Value* obj, const std::string& field) { + return insertNode(createGetAttr(obj, field))->output(); + 
} + // Note: defined in python_ir.cpp and can be used only in python extension Node* createPythonOp( THPObjectPtr&& pyobj, diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index a8ae6d3..e7d5f94 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -267,10 +267,9 @@ struct GraphFuser { norm_invstd = 1 / (eps + torch.sqrt(norm_var)) return ((input - norm_mean) * norm_invstd) )SCRIPT"; - auto module = std::make_shared(); - defineMethodsInModule( - module, source, script::nativeResolver, /*self=*/c10::nullopt); - *graph_ptr = module->get_method("batch_norm").graph(); + script::CompilationUnit cu; + cu.define(source, script::nativeResolver, nullptr); + *graph_ptr = cu.get_function("batch_norm").graph(); }, &bn_graph); diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index b157cc6..f7bfd99 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -1133,6 +1133,16 @@ struct PythonPrintPass { [](const Argument& arg) { return arg.default_value(); }); printFunction(graph, name, is_class, defaults, ivalue_names); } + void printFunction( + script::Function& method, + bool is_class) { + const std::string& name = method.name(); + Graph& graph = *method.graph(); + auto defaults = fmap( + method.getSchema().arguments(), + [](const Argument& arg) { return arg.default_value(); }); + printFunction(graph, name, is_class, defaults, {}); + } void printModule(script::Module& module) { std::unordered_map extra_ivalue_names; createTensorToParameterNameMap( @@ -1153,9 +1163,8 @@ struct PythonPrintPass { out << "class " << classType->name() << ":\n"; { const auto guard = WithIndented(); - std::unordered_map extra_ivalue_names; for (auto& method : classType->methods()) { - printMethod(*method, /*is_class=*/true, extra_ivalue_names); + printFunction(*method, /*is_class=*/true); } } } diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 7652dc0..1a40aa4 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -137,6 +137,7 @@ void ConcretePythonOp::cloneFrom(Node* other_) { this->cconv = other->cconv; Py_INCREF(other->pyobj.get()); this->pyobj = THPObjectPtr(other->pyobj.get()); + this->ignore_on_export = other->ignore_on_export; for (auto& sa : other->scalar_args) { Py_INCREF(sa.get()); this->scalar_args.emplace_back(sa.get()); diff --git a/torch/csrc/jit/script/builtin_functions.cpp b/torch/csrc/jit/script/builtin_functions.cpp index a1ed46c..2ee7730 100644 --- a/torch/csrc/jit/script/builtin_functions.cpp +++ b/torch/csrc/jit/script/builtin_functions.cpp @@ -37,8 +37,8 @@ def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]: )SCRIPT"); struct BuiltinFunctionRegistry { - const std::vector& getAllBuiltinFunctionsFor(Symbol name) { - const static std::vector empty; + const std::vector& getAllBuiltinFunctionsFor(Symbol name) { + const static std::vector empty; // when initializing the builtin function library, we will re-enter // getAllBuiltinFunctionsFor since it is called in the compiler to // lookup builtins and initializing the builtin functions calls the @@ -62,11 +62,10 @@ struct BuiltinFunctionRegistry { private: void loadSource(const std::string& source) { - auto module = std::make_shared(); - defineMethodsInModule( - module, source, script::nativeResolver, /*self=*/c10::nullopt); - modules.push_back(module); - for (auto& method : module->get_methods()) { + std::shared_ptr 
cu = std::make_shared(); + modules.emplace_back(cu); + cu->define(source, script::nativeResolver, /*self=*/nullptr); + for (auto& method : cu->get_functions()) { builtins_by_name[Symbol::fromQualString("aten::" + method->name())] .push_back(method.get()); } @@ -97,11 +96,11 @@ struct BuiltinFunctionRegistry { } enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED; std::recursive_mutex mutex; - std::vector> modules; - std::unordered_map> builtins_by_name; + std::vector> modules; + std::unordered_map> builtins_by_name; }; -TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name) { +TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name) { static BuiltinFunctionRegistry registry; return registry.getAllBuiltinFunctionsFor(name); } diff --git a/torch/csrc/jit/script/builtin_functions.h b/torch/csrc/jit/script/builtin_functions.h index 42e15e7..f1a5f22 100644 --- a/torch/csrc/jit/script/builtin_functions.h +++ b/torch/csrc/jit/script/builtin_functions.h @@ -7,7 +7,7 @@ namespace torch { namespace jit { namespace script { -TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name); +TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name); } } // namespace jit diff --git a/torch/csrc/jit/script/class_type.cpp b/torch/csrc/jit/script/class_type.cpp index f669e03..0841e80 100644 --- a/torch/csrc/jit/script/class_type.cpp +++ b/torch/csrc/jit/script/class_type.cpp @@ -5,13 +5,20 @@ namespace c10 { // This file exists because we need to reference module.h, which we can't from // c10. Sigh... -Method* ClassType::getMethod(const std::string& name) const { - return module_? module_->find_method(name) : nullptr; +Function* ClassType::getMethod(const std::string& name) const { + return compilation_unit_->find_function(name).get(); } -std::vector ClassType::methods() const { - std::vector ret; - for (const auto& pr : module_->get_methods()) { +CompilationUnit& ClassType::compilation_unit() { + return *compilation_unit_; +} +const CompilationUnit& ClassType::compilation_unit() const { + return *compilation_unit_; +} + +std::vector ClassType::methods() const { + std::vector ret; + for (const auto& pr : compilation_unit().get_functions()) { ret.push_back(pr.get()); } return ret; diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h new file mode 100644 index 0000000..790d061 --- /dev/null +++ b/torch/csrc/jit/script/compilation_unit.h @@ -0,0 +1,285 @@ +#pragma once +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace script { + +struct Def; +struct SugaredValue; +struct Function; + +using Resolver = std::function( + const std::string& name, + Function& f, + const SourceRange& loc)>; +using Self = std::function(Value*)>; + +// A Function is a pure Graph with no implicit `self` object bound. +// It contains schema information, and the executor that manages the +// execution of the function. script::Method is a wrapper around a +// underlying Function that also provides a `self` object. 
+struct TORCH_API Function { + Function( + std::string name, + bool optimize, + std::shared_ptr graph, + std::function function_creator) + : name_(std::move(name)), + graph_(std::move(graph)), + optimize_(optimize), + function_creator_(std::move(function_creator)) {} + + void run(Stack& stack) { + get_executor().run(stack); + } + + void run(Stack&& stack) { + run(stack); + } + + IValue operator()(std::vector stack) { + getSchema().checkAndNormalizeInputs(stack); + run(stack); + return stack.front(); + } + + std::shared_ptr graph_for(Stack inputs) { + return get_executor().graphFor(inputs); + } + + std::shared_ptr graph() const { + return graph_; + } + + const std::string& name() const { + return name_; + } + + // if this isn't yet defined, run its method_creator function + void ensure_defined(); + + size_t num_inputs() const { + return graph()->inputs().size(); + } + + Function& setSchema(FunctionSchema schema) { + schema_ = make_unique(std::move(schema)); + return *this; + } + + const FunctionSchema& getSchema() const { + if (schema_ == nullptr) { + schema_ = make_unique(defaultSchemaFor(*this)); + } + return *schema_; + } + + std::string pretty_print_schema() const { + AT_ASSERT(schema_); + std::stringstream ss; + ss << *schema_; + return ss.str(); + } + + GraphExecutorState getDebugState() { + return get_executor().getDebugState(); + } + + void debugDisableAutodiffSubgraphInlining() { + return get_executor().debugDisableAutodiffSubgraphInlining(); + } + + bool is_optimized() const { + return optimize_; + } + + void check_single_output() { + AT_CHECK( + graph()->outputs().size() == 1, + "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs"); + } + + GraphExecutor& get_executor() { + std::call_once(executor_init_, [&] { + check_single_output(); + executor_ = GraphExecutor(graph(), optimize_); + }); + return executor_; + } + + // returns nullptr and fills in failure_messages if the callee does not + // match the functions schema + + // TODO: defined in module.cpp, move to compilation_unit.cpp + Value* try_emit_call( + Graph& graph, + const SourceRange& loc, + c10::optional self, + ArrayRef args, + ArrayRef kwargs, + std::stringstream& failure_messages, + bool conv_tensors_to_nums); + + Value* emit_call( + Graph& graph, + const SourceRange& loc, + ArrayRef args, + ArrayRef kwargs); + + private: + static FunctionSchema defaultSchemaFor(const Function& function) { + std::vector args; + std::vector returns; + Graph& g = *function.graph(); + size_t num_inputs = function.num_inputs(); + for (size_t i = 0; i < num_inputs; ++i) { + const Value* v = g.inputs().at(i); + std::string name = v->hasUniqueName() ? v->uniqueNameBase() + : ("argument_" + std::to_string(i)); + args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); + } + for (size_t i = 0; i < g.outputs().size(); ++i) { + returns.emplace_back("", unshapedType(g.outputs()[i]->type())); + } + return {function.name(), "", std::move(args), std::move(returns)}; + } + + std::string name_; + std::shared_ptr graph_; // for debugging and for inlining + bool optimize_; + + GraphExecutor executor_; // for execution + + std::once_flag executor_init_; + + // an optional function that actually creates the method when + // emit_call_to(this,...) is first called. 
this is used by the compiler so + // that it can construct methods out of order + std::function function_creator_; + + // if absent, then we generate a default schema based on the graph + // mutable because getSchema caches the default schema if one is requested + // before a call to setSchema + mutable std::unique_ptr schema_; +}; + + +// A CompilationUnit is a list of named script::Functions +// with helper methods to iterate the list, or invoke the function. +// Classes have a CompilationUnit holding the class methods +// and Modules also have a CompilationUnit holding the Functions that +// are used to implement their Methods + +struct TORCH_API CompilationUnit { + std::shared_ptr find_function(const std::string& name) const { + auto it = dict_.find(name); + if (it == dict_.end()) + return nullptr; + return functions_[it->second]; + } + + Function& get_function(const std::string& name) const { + if (auto r = find_function(name)) + return *r; + AT_ERROR("attempted to get undefined function ", name); + } + + void set_optimized(bool o) { + optimized_ = o; + } + + bool is_optimized() const { + return optimized_; + } + + // for historic reasons, these are defined in compiler.cpp + void define( + const std::vector& definitions, + const std::vector& resolvers, /* determines how we handle free + variables in each definition*/ + // if non-null, the first argument to each def, is bound to this value + const Self& self); + + // same as above but parse the definitions from source + void define( + const std::string& source, + const Resolver& resolver, + const Self& self); + + void clone_function(const Function& remote) { + create_function(remote.name(), remote.graph()->copy()); + } + + Function& create_function(std::string name, std::shared_ptr graph) { + auto fn = std::make_shared( + std::move(name), is_optimized(), std::move(graph), nullptr); + return register_function(std::move(fn)); + } + + const std::vector>& get_functions() const { + return functions_; + } + + /// Run a method from this compilation. + /// + /// For example: + /// @code + /// IValue output = module->run("relu_script", a, b); + /// @endcode + /// + /// To get a compile a module from a source string, see torch::jit::compile + /// + /// @param method_name The name of the method to run + /// @param args Arguments to be passed to the method + /// @return An IValue containing the return value (or values if it is a tuple) + /// from the method + template + IValue run_method(const std::string& method_name, Types&&... 
args) { + return get_function(method_name)({IValue(std::forward(args))...}); + } + + void drop_all_functions() { + dict_.clear(); + functions_.clear(); + } + + private: + Function& register_function(std::shared_ptr fn) { + AT_CHECK( + 0 == dict_.count(fn->name()), + "method '", + fn->name(), + "' already defined."); + functions_.emplace_back(std::move(fn)); + dict_[functions_.back()->name()] = functions_.size() - 1; + return *functions_.back(); + } + std::vector> functions_; + // for fast lookup + std::unordered_map dict_; + bool optimized_ = true; +}; + +} // namespace script +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 11358e2..6048ed6 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -25,7 +25,7 @@ namespace jit { namespace script { using SugaredValuePtr = std::shared_ptr; -using FunctionTable = std::unordered_map; +using FunctionTable = std::unordered_map; using ValueTable = std::unordered_map; using AttributeMap = std::unordered_map; using ListAttributeMap = std::unordered_map>; @@ -190,7 +190,7 @@ static bool meaningfulName(const std::string& name) { // delete unnecessary ones later with replaceAllusesWith(). struct Environment { Environment( - Method& method, + Function& method, Resolver resolver, Block* b, std::shared_ptr next = nullptr) @@ -199,7 +199,7 @@ struct Environment { b(b), next(std::move(next)) {} - Method& method; + Function& method; Resolver resolver; std::vector captured_inputs; std::unordered_map error_messages; @@ -518,8 +518,8 @@ struct to_ir { to_ir( const Def& def, Resolver resolver_, - const c10::optional& self, - Method& method) // method being constructed + const Self& self, + Function& method) // method being constructed : method(method), graph(method.graph()), resolver(std::move(resolver_)), @@ -541,7 +541,7 @@ struct to_ir { } private: - Method& method; + Function& method; std::shared_ptr graph; Resolver resolver; std::unordered_map integral_constants; @@ -577,7 +577,7 @@ struct to_ir { FunctionSchema emitDef( const Def& def, - const c10::optional& self, + const Self& self, Block* block) { auto schema = extractSchemaFromDef(def, self); // TODO need guards on init returning none @@ -624,15 +624,16 @@ struct to_ir { blank_decl, List::create(r, {ret})); auto m = std::make_shared(); - defineMethodsInModule(m, {def}, {resolver}, c10::nullopt); + CompilationUnit cu; + cu.define({def}, {resolver}, nullptr); Stack stack; - m->get_method("defaults").run(stack); + cu.get_function("defaults").run(stack); return stack.at(0).toTuple()->elements(); } std::vector parseArgsFromDecl( const Decl& decl, - const c10::optional& self) { + const Self& self) { auto params_begin = decl.params().begin(); auto params_end = decl.params().end(); if (self) { @@ -706,7 +707,7 @@ struct to_ir { } FunctionSchema extractSchemaFromDef( const Def& def, - const c10::optional& self) { + const Self& self) { const auto name = def.name().name(); std::vector args = parseArgsFromDecl(def.decl(), self); std::vector returns = parseReturnFromDecl(def.decl()); @@ -716,9 +717,10 @@ struct to_ir { std::vector emitFormalArguments( const Def& def, - const c10::optional& self, + const Self& self, const FunctionSchema& schema, Block* block) { + std::vector arguments; // for schema // inputs auto it = def.decl().params().begin(); @@ -738,14 +740,9 @@ struct to_ir { if (self) { AT_ASSERT(it != end); const auto& name = (*it).ident().name(); - if (auto type = self->asFirstClass()) { - 
Value* new_input = - block->addInput()->setUniqueName(name)->setType(type); - environment_stack->setVar((*it).ident().range(), name, new_input); - arguments.emplace_back(name, type); - } else { - environment_stack->setSugaredVar(def.range(), name, self->asSugared()); - } + Value* new_input = block->addInput()->setUniqueName(name); + environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input)); + arguments.emplace_back(name, new_input->type()); ++it; } size_t arg_annotation_idx = 0; @@ -831,7 +828,7 @@ struct to_ir { pushFrame(block, /*starts_def=*/true); emitDef( def, - c10::nullopt, + nullptr, block); // ignore schema return, we just wont use it for now since we // never create a Method for the closure popFrame(/*ends_def=*/true); @@ -2263,7 +2260,6 @@ struct to_ir { node_output = fork_node->output()->setType( FutureType::create(fn_simple_output->type())); } - // Lambda lift block(0) into attr::Subgraph lambdaLiftFork(fork_node); @@ -2755,15 +2751,14 @@ struct to_ir { } }; -void defineMethodsInModule( - const std::shared_ptr& m, +void CompilationUnit::define( const std::vector& definitions, const std::vector& resolvers, - const c10::optional& self) { + const Self& self) { AT_ASSERT(definitions.size() == resolvers.size()); auto resolver_it = resolvers.begin(); - std::vector methods; - std::unordered_map function_table; + std::vector methods; + std::unordered_map function_table; for (const Def& def : definitions) { const std::string& name = def.name().name(); auto resolver = *resolver_it++; @@ -2774,37 +2769,34 @@ void defineMethodsInModule( // the function table so the methods can see each other resolver = [resolver, &function_table]( const std::string& name, - Method& m, + Function& m, const SourceRange& loc) -> std::shared_ptr { auto it = function_table.find(name); if (it != function_table.end()) { - return std::make_shared(nullptr, *it->second); + return std::make_shared(c10::nullopt, *it->second); } return resolver(name, m, loc); }; } - auto creator = [def, resolver, self](Method& method) { + auto creator = [def, resolver, self](Function& method) { AT_ASSERT(resolver); to_ir(def, resolver, self, method); }; - Method& method = m->create_method(name, creator); - function_table[name] = &method; - methods.push_back(&method); + std::unique_ptr fn( + new Function(name, is_optimized(), std::make_shared(), creator)); + function_table[name] = fn.get(); + methods.push_back(fn.get()); + register_function(std::move(fn)); } - for (Method* method : methods) { + for (Function* method : methods) { method->ensure_defined(); } - if (!self || !self->asFirstClass()) { - // Disable module hooks if the module is only used to store a class's code. 
- didFinishEmitModule(m); - } } -void defineMethodsInModule( - const std::shared_ptr& m, +void CompilationUnit::define( const std::string& source, const Resolver& resolver, - const c10::optional& self) { + const Self& self) { Parser p(source); std::vector definitions; std::vector resolvers; @@ -2813,7 +2805,7 @@ void defineMethodsInModule( definitions.push_back(def); resolvers.push_back(resolver); } - defineMethodsInModule(m, definitions, resolvers, self); + define(definitions, resolvers, self); } void lambdaLiftFork(Node* fork_node) { @@ -2838,6 +2830,7 @@ void lambdaLiftFork(Node* fork_node) { fork_node->g_(attr::Subgraph, forked_graph); fork_node->eraseBlock(0); } + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index 3c2bb2d..3965467 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -13,12 +13,9 @@ namespace torch { namespace jit { namespace script { -using Resolver = std::function(const std::string& name, Method& m, const SourceRange& loc)>; - inline std::shared_ptr nativeResolver( const std::string& name, - Method& m, + Function& m, const SourceRange& loc) { if (name == "torch") { return std::make_shared("aten"); @@ -26,47 +23,6 @@ inline std::shared_ptr nativeResolver( return nullptr; } -// Represents the `self` argument to a method. This wrapper class is necessary -// because sometimes `self` sometimes is first class and sometimes not. -// -// `self` is first class when it refers to a ClassType. It will be bound as a -// graph input argument. -// `self` is sugared when it refers to a ModuleValue. -class Self { - public: - explicit Self(std::shared_ptr sugared) - : sugared_(std::move(sugared)) {} - explicit Self(ClassTypePtr type) : firstClass_(std::move(type)) {} - - ClassTypePtr asFirstClass() const { - return firstClass_; - } - std::shared_ptr asSugared() const { - return sugared_; - } - - private: - // Used when `self` is not first-class and so we don't represent it in the - // graph. This is only ModuleValue. - std::shared_ptr sugared_ = nullptr; - // Used when `self` is a first-class type - ClassTypePtr firstClass_ = nullptr; -}; - -TORCH_API void defineMethodsInModule( - const std::shared_ptr& m, - const std::vector& definitions, - const std::vector& resolvers, /* determines how we handle free - variables in each definition*/ - // if non-null, the first argument to each def, is bound to this value - const c10::optional& self); - -// same as above but parse the definitions from source -TORCH_API void defineMethodsInModule( - const std::shared_ptr& m, - const std::string& source, - const Resolver& resolver, - const c10::optional& self); TORCH_API void lambdaLiftFork(Node* fork_node); diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index bff6d85..1aeadea 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -64,10 +64,9 @@ inline std::shared_ptr toSimple(Value* v) { // type, *add it in this function's implementation*. std::shared_ptr toSugaredValue( py::object obj, - Method& m, + Function& m, SourceRange loc, - bool is_constant = false, - bool is_submodule = false); + bool is_constant = false); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue(py::object self) : self(std::move(self)) {} @@ -125,7 +124,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { // call it like a function, e.g. 
`outputs = this(inputs)` std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs_, at::ArrayRef attributes, size_t n_binders) override { @@ -182,7 +181,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::vector> asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint = {}) override { const std::string type_str = typeString(self); std::stringstream ss; @@ -193,7 +192,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { const std::string type_str = typeString(self); std::stringstream ss; @@ -219,7 +218,7 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { py::object member = getattr(loc, field); // note: is_constant = true because we consider that global properties @@ -234,7 +233,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { : PythonValue(std::move(tup)) {} std::vector> asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint = {}) override { py::tuple tup = self; std::vector> result; @@ -246,7 +245,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { return result; } - Value* asValue(const SourceRange& loc, Method& m) override { + Value* asValue(const SourceRange& loc, Function& m) override { std::vector values; for (const auto& sugared_item : asTuple(loc, m)) { values.push_back(sugared_item->asValue(loc, m)); @@ -258,33 +257,65 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { - ConstantParameterList(std::shared_ptr module) - : module_(std::move(module)) {} - + ConstantParameterList(Value* the_list) : the_list_(the_list) {} std::string kind() const override { return "constant parameter list"; } + std::shared_ptr call( + const SourceRange& loc, + Function& caller, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + return toSimple(the_list_); + } + + private: + Value* the_list_; +}; + +struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { + OverloadedFunctionValue(Value* module, std::vector method_names) + : module_(module), method_names_(std::move(method_names)) {} + + std::string kind() const override { + return "overloaded function"; + } std::shared_ptr call( const SourceRange& loc, - Method& caller, + Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { - // Add all module parameters as inputs to the graph - std::vector params; - const auto& param_list = module_->get_parameters(); - for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { - auto& param = *it; - params.push_back(caller.get_or_add_parameter(param)); + std::stringstream err; + std::vector new_inputs = inputs.vec(); + new_inputs.insert(new_inputs.begin(), module_); + + for (const std::string& method_name : method_names_) { + auto cls = module_->type()->expect(); + Function* fn = cls->getMethod(method_name); + auto match = tryMatchSchema( + fn->getSchema(), + loc, + *caller.graph().get(), + c10::nullopt, + new_inputs, + attributes, + err, + true); + if (match) { + return MethodValue(module_, *fn) + .call(loc, caller, inputs, attributes, 
n_binders); + } } - auto list = caller.graph()->createList(TensorType::get(), params); - caller.graph()->insertNode(list); - return toSimple(list->output()); + throw ErrorReport(loc) << "Could not find any matching overloads\n" + << err.str(); } private: - std::shared_ptr module_; + Value* module_; + std::vector method_names_; }; // defines how modules/methods behave inside the script subset. @@ -295,7 +326,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { // holding the actual nn.Module class. struct ModuleValue : public SugaredValue { - ModuleValue(std::shared_ptr module) : module(std::move(module)) {} + ModuleValue(Value* self, std::shared_ptr module) + : self_(self), module_(std::move(module)) {} std::string kind() const override { return "module"; @@ -304,45 +336,60 @@ struct ModuleValue : public SugaredValue { // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { // workaround to make self.training work // it adds a buffer 'training' to the model if one doesn't exist // and then loads that parameter, casting it to bool if (field == "training") { - Slot* v = module->find_buffer(field); + Slot* v = module_->find_buffer(field); if (!v) { - py::object py_module = py::cast(module); + py::object py_module = py::cast(module_); bool training = py::cast(py::getattr(py_module, "training")); auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong)); - module->register_buffer("training", std::move(t)); - v = module->find_buffer(field); + module_->register_buffer("training", std::move(t)); + v = module_->find_buffer(field); } - Value* the_tensor = m.get_or_add_parameter(*v); + Value* the_tensor = m.graph()->insertGetAttr(self_, "training"); Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor}); return std::make_shared(the_bool); } - if (std::shared_ptr v = module->find_module(field)) { - return std::make_shared(v); - } else if (Method* v = module->find_method(field)) { - return std::make_shared(shared_from_this(), *v); - } else if (Slot* v = module->find_parameter(field)) { - return std::make_shared(m.get_or_add_parameter(*v)); - } else if (Slot* v = module->find_attribute(field)) { - return std::make_shared( - m.get_or_add_attribute(*v)); + if (std::shared_ptr v = module_->find_module(field)) { + return std::make_shared( + m.graph()->insertGetAttr(self_, field), v); + } else if (auto kind = module_->kind_of(field)) { + // methods, parameters, attributes, and buffers are all first class + return SimpleValue(self_).attr(loc, m, field); } // This can also be a call to a non-script module, or a plain // python method. If so return this as a python value. 
- py::object py_module = py::cast(module); + py::object py_module = py::cast(module_); + + py::object overloads = + py_module.attr("_overloads").attr("get")(field, py::none()); + if (!overloads.is_none()) { + return std::make_shared( + self_, py::cast>(overloads)); + } + if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) { if (py::isinstance(attr) && py::hasattr(attr, "_is_parameter_list") && py::cast(py::getattr(attr, "_is_parameter_list"))) { - return std::make_shared(module); + Graph& g = *m.graph(); + // Add all module parameters as inputs to the graph + std::vector params; + const auto& param_list = module_->get_parameters(); + for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { + auto& param = *it; + params.emplace_back(g.insertGetAttr(self_, param.name())); + } + auto list = + g.insertNode(g.createTuple(params))->output(); + return std::make_shared(list); } if (py::isinstance(attr) || py::isinstance(attr, py::module::import("torch.nn").attr("Module")) || @@ -364,7 +411,7 @@ struct ModuleValue : public SugaredValue { // call module.forward std::shared_ptr call( const SourceRange& loc, - Method& caller, + Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -374,28 +421,35 @@ struct ModuleValue : public SugaredValue { std::vector> asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint = {}) override { - py::object py_module = py::cast(module); + py::object py_module = py::cast(module_); if (!py::isinstance( py_module, py::module::import("torch.jit").attr("_ConstModuleList"))) return SugaredValue::asTuple(loc, m, size_hint); std::vector> result; - for (py::handle module : py_module) { - py::object obj = py::reinterpret_borrow(module); - result.push_back(toSugaredValue( - obj, - m, - loc, - /*is_constant =*/false, - /*is_submodule =*/true)); + for (py::handle py_submodule : py_module) { + py::object obj = py::reinterpret_borrow(py_submodule); + if (py::isinstance(obj)) { + auto sub_module = py::cast>(obj); + Value* module_v = m.graph()->insertGetAttr(self_, sub_module->name()); + result.emplace_back( + std::make_shared(module_v, sub_module)); + } else { + result.push_back(toSugaredValue( + obj, + m, + loc, + /*is_constant =*/false)); + } } return result; } private: - std::shared_ptr module; + Value* self_; + std::shared_ptr module_; }; struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { @@ -408,7 +462,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, - Method& caller, + Function& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -446,54 +500,31 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { py::dict dispatched_fn_; }; -struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { - OverloadedFunctionValue(py::list functions) - : possible_functions_(std::move(functions)) {} - - std::string kind() const override { - return "overloaded function"; +std::shared_ptr moduleToMethod( + const std::shared_ptr& mod) { + // this path only supports calling raw script functions + // but because they are not distinguished from models, we have to check + // that they are function-like here. They must not have state, and they + // must have a forward method. When we expose functions to python + // this will be replaced with a direct py::isinstance call. 
+ + if (mod->get_parameters().size() != 0) { + throw ErrorReport() + << "Attempted to inline a Module with parameters. " + "Stateful modules to be inlined must be submodules of the callee."; } - - std::shared_ptr call( - const SourceRange& loc, - Method& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, - size_t n_binders) override { - std::stringstream err; - auto possible_functions = - py::cast>(possible_functions_); - - for (const py::object& fn : possible_functions) { - auto& method = py::cast(fn); - auto match = tryMatchSchema( - method.getSchema(), - loc, - *caller.graph().get(), - c10::nullopt, - inputs, - attributes, - err, - true); - if (match) { - return MethodValue(nullptr, method) - .call(loc, caller, inputs, attributes, n_binders); - } - } - throw ErrorReport(loc) << "Could not find any matching overloads\n" - << err.str(); + Method* forward = mod->find_method("forward"); + if (!forward) { + throw ErrorReport() << " expected this module to have a forward function."; } - - private: - py::list possible_functions_; -}; + return std::make_shared(at::nullopt, forward->function()); +} std::shared_ptr toSugaredValue( py::object obj, - Method& m, + Function& m, SourceRange loc, - bool is_constant, - bool is_submodule) { + bool is_constant) { // directly create SimpleValues when possible, because they are first-class // and can be re-assigned. Otherwise, this would be invalid: // f = python_constant @@ -534,17 +565,12 @@ std::shared_ptr toSugaredValue( obj = weak_obj; } if (py::isinstance(obj)) { - auto mod = py::cast>(obj); // In the case that this Python object is not a submodule, inline *ONLY // PURE* ScriptModules. This allows us to call arbitrary @script functions // within a scripting context while still enforcing that parameters from // stateful submodules are properly accounted for. - if (!is_submodule && mod->get_parameters().size() != 0) { - throw ErrorReport() - << "Attempted to inline a Module with parameters. 
" - "Stateful modules to be inlined must be submodules of the callee."; - } - return std::make_shared(mod); + auto mod = py::cast>(obj); + return moduleToMethod(mod); } else if (py::isinstance(obj)) { return std::make_shared(obj); } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) { @@ -566,7 +592,7 @@ std::shared_ptr toSugaredValue( py::module::import("torch.jit").attr("_try_compile_weak_script")(obj); if (!compiled_fn.is(py::none())) { auto mod = py::cast>(compiled_fn); - return std::make_shared(mod); + return moduleToMethod(mod); } } @@ -576,12 +602,6 @@ std::shared_ptr toSugaredValue( return std::make_shared(std::move(dispatched_fn)); } - py::object overloads = - py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj); - if (!overloads.is_none()) { - return std::make_shared(std::move(overloads)); - } - return std::make_shared(obj); } @@ -621,7 +641,7 @@ static void gatherParametersAndBuffers( namespace { Resolver pythonResolver(const ResolutionCallback& rcb) { - return [rcb](const std::string& name, Method& m, const SourceRange& loc) + return [rcb](const std::string& name, Function& m, const SourceRange& loc) -> std::shared_ptr { AutoGIL ag; py::object obj = rcb(name); @@ -673,6 +693,13 @@ FunctionSchema getSchemaWithNameAndDefaults( schema.is_varret()); } +static Self moduleSelf(const std::shared_ptr& m) { + return [m](Value* v) { + v->setType(m->module_object()->type()); + return std::make_shared(v, m); + }; +} + void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -712,9 +739,12 @@ void initJitScriptBindings(PyObject* module) { bool has_self) { c10::optional self; if (has_self) { - self = Self(std::make_shared(m)); + m->class_compilation_unit().define( + script, pythonResolver(rcb), moduleSelf(m)); + } else { + m->_define_lowered(script, pythonResolver(rcb)); } - defineMethodsInModule(m, script, pythonResolver(rcb), self); + didFinishEmitModule(m); }) .def( "_create_methods", @@ -727,14 +757,13 @@ void initJitScriptBindings(PyObject* module) { for (auto& callback : rcbs) { resolvers.push_back(pythonResolver(callback)); } - defineMethodsInModule( - m, defs, resolvers, Self(std::make_shared(m))); - + m->class_compilation_unit().define(defs, resolvers, moduleSelf(m)); // Stitch in default arguments for each Def if provided auto defaults_it = defaults.begin(); auto defs_it = defs.begin(); while (defs_it != defs.end()) { - auto& method = m->get_method((*defs_it).name().name()); + auto& method = m->class_compilation_unit().get_function( + (*defs_it).name().name()); method.setSchema(getSchemaWithNameAndDefaults( defs_it->range(), method.getSchema(), @@ -784,8 +813,7 @@ void initJitScriptBindings(PyObject* module) { auto& p = parameters[i]; py::tuple r(2); result[i] = std::make_tuple( - p.name(), - autograd::as_variable_ref(p.value().toTensor())); + p.name(), autograd::as_variable_ref(p.value().toTensor())); } return result; }) @@ -841,7 +869,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& self, const std::string& name, std::shared_ptr graph) { - self.create_method(name, std::move(graph), {}); + self._define_lowered(name, std::move(graph), {}); }) .def( "_create_method_from_trace", @@ -865,7 +893,8 @@ void initJitScriptBindings(PyObject* module) { var_lookup_fn, force_outplace, input_tuple.size()); - self->create_method(name, std::move(graph), std::move(parameters)); + self->_define_lowered( + name, std::move(graph), std::move(parameters)); didFinishEmitModule(self); }) .def( @@ -890,7 +919,7 @@ void 
initJitScriptBindings(PyObject* module) { [](Module& self) { if (self.find_method("forward")) { Method& m = self.get_method("forward"); - return m.getDebugState(); + return m.get_executor().getDebugState(); } throw std::runtime_error( "Attempted to call get_debug_state on a Module without a compiled forward()"); @@ -900,7 +929,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& self) { if (self.find_method("forward")) { Method& m = self.get_method("forward"); - m.debugDisableAutodiffSubgraphInlining(); + m.get_executor().debugDisableAutodiffSubgraphInlining(); } }) .def( @@ -958,7 +987,8 @@ void initJitScriptBindings(PyObject* module) { } Method* orig_method = orig->find_method(name); - m->create_method(name, orig_method->graph()->copy(), member_inputs); + m->_define_lowered( + name, orig_method->graph()->copy(), std::move(member_inputs)); }); py::class_(m, "ScriptMethod", py::dynamic_attr()) @@ -972,10 +1002,27 @@ void initJitScriptBindings(PyObject* module) { method, tuple_slice(std::move(args), 1), std::move(kwargs)); }) .def_property_readonly("graph", [](Method& m) { return m.graph(); }) - .def("propagate_shapes", &Method::propagate_shapes) + .def( + "propagate_shapes", + [](Method& m, const std::vector& inputs, bool with_grad) { + return propagate_shapes( + *m.graph(), inputs, m.initial_ivalues(), with_grad); + }) .def( "propagate_and_assign_input_and_output_shapes", - &Method::propagate_and_assign_input_and_output_shapes) + [](Method& m, + const std::vector& inputs, + std::vector outputs, + bool with_grad, + bool propagate) { + return propagate_and_assign_input_and_output_shapes( + *m.graph(), + inputs, + m.initial_ivalues(), + outputs, + with_grad, + propagate); + }) .def( "initial_ivalues", [](Method& m) { @@ -995,9 +1042,18 @@ void initJitScriptBindings(PyObject* module) { }) .def( "debug_disable_autodiff_subgraph_inlining", - &Method::debugDisableAutodiffSubgraphInlining) + [](Method& m) { + return m.get_executor().debugDisableAutodiffSubgraphInlining(); + }) .def("schema", &Method::getSchema) - .def("pretty_print_schema", &Method::pretty_print_schema) + .def( + "pretty_print_schema", + [](Method& m) { + const FunctionSchema& schema = m.getSchema(); + std::stringstream ss; + ss << schema; + return ss.str(); + }) .def( "python_print", [](Method& m) { @@ -1022,29 +1078,30 @@ void initJitScriptBindings(PyObject* module) { ResolutionCallback rcb, FunctionDefaults defaults) { auto def_f = def.withName("forward"); - defineMethodsInModule( - mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt); - auto& method = mod->get_method("forward"); - method.setSchema(getSchemaWithNameAndDefaults( - def.range(), method.getSchema(), def.name().name(), defaults)); + + mod->_define_lowered({def_f}, {pythonResolver(rcb)}); + auto& func = mod->lowered_methods().get_function("forward"); + func.setSchema(getSchemaWithNameAndDefaults( + def.range(), func.getSchema(), def.name().name(), defaults)); + auto& func2 = mod->class_compilation_unit().get_function("forward"); + func2.setSchema(getSchemaWithNameAndDefaults( + def.range(), func2.getSchema(), def.name().name(), defaults)); didFinishEmitModule(mod); return mod; }); m.def( "_jit_script_class_compile", - [](std::shared_ptr module, - const ClassDef& classDef, - ResolutionCallback rcb) { - auto classType = ClassType::create(classDef.name().name(), module); + [](const ClassDef& classDef, ResolutionCallback rcb) { + auto cu = std::make_shared(); + auto classType = ClassType::create(classDef.name().name(), cu); std::vector rcbs; std::vector 
methodDefs; for (const auto& def : classDef.defs()) { methodDefs.push_back(def); rcbs.push_back(pythonResolver(rcb)); } - defineMethodsInModule(module, methodDefs, rcbs, Self(classType)); - return module; + cu->define(methodDefs, rcbs, simpleSelf(classType)); }); m.def("parse_type_comment", [](const std::string& comment) { diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index 33cb951..7772ab1 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -11,32 +12,39 @@ namespace jit { namespace script { struct RecursiveMethodCallError : public std::exception {}; -void placeholderCreator(Method&) { +void placeholderCreator(Function&) { throw RecursiveMethodCallError(); } -Value* try_emit_call_to( +void Function::ensure_defined() { + try { + if (function_creator_) { + auto creator = function_creator_; + function_creator_ = placeholderCreator; + creator(*this); + function_creator_ = nullptr; + } + } catch (RecursiveMethodCallError&) { + throw ErrorReport() // TODO: once lower_first_class methods is removed + // re-establish callsite info for debugging + << " method '" << name() << "' is called recursively. " + << "Recursive calls are not supported"; + } +} + +Value* Function::try_emit_call( Graph& graph, const SourceRange& loc, - Method& callee, c10::optional self, ArrayRef args, ArrayRef kwargs, std::stringstream& failure_messages, - Method* caller, bool conv_tensors_to_nums) { - try { - callee.ensure_defined(); - } catch (RecursiveMethodCallError&) { - throw ErrorReport(loc) - << " method '" << callee.name() - << "' is called recursively involving this call site. " - << "Recursive calls are not supported"; - } - auto fn = callee.graph(); + ensure_defined(); + auto fn = this->graph(); auto matched_schema = tryMatchSchema( - callee.getSchema(), + getSchema(), loc, graph, std::move(self), @@ -47,52 +55,29 @@ Value* try_emit_call_to( if (!matched_schema) return nullptr; - // parameters to callee method (which become parameters to _this_ method - // if they were not already) - for (const auto& member : callee.initial_ivalues()) { - if (!caller) { - throw ErrorReport(loc) - << " attempting to call a method with parameters/attributes" - " from a raw graph. 
File a bug report"; - } - matched_schema->inputs.push_back( - caller->get_or_add_attribute(member)); - } - callee.check_single_output(); - return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0); + check_single_output(); + return inlineCallTo(graph, *fn, matched_schema->inputs).at(0); } -Value* Method::emit_call_to( +Value* Function::emit_call( + Graph& graph, const SourceRange& loc, - Method& callee, ArrayRef args, ArrayRef kwargs) { - AT_ASSERT(!executor); std::stringstream failure_messages; - if (auto result = try_emit_call_to( - *graph(), + if (auto result = try_emit_call( + graph, loc, - callee, c10::nullopt, args, kwargs, failure_messages, - this, /*conv_tensors_to_nums=*/true)) { return result; } throw ErrorReport(loc) << failure_messages.str(); } -void Method::ensure_defined() { - if (method_creator) { - auto creator = method_creator; - method_creator = placeholderCreator; - creator(*this); - method_creator = nullptr; - } -} - void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) { to_impl(device, dtype, non_blocking); } @@ -137,6 +122,229 @@ void Module::to_impl( } } +// lower_first_class_method and lift_lowered_method are transitionary functions +// used to translate between module-as-first-class code generation, +// and module-as-special execution. Once module-as-first-class execution is +// debugged, then we can remove both and remove the lowered_functions_ table. + +// remove the first module argument, replacing any access of its +// parameters/attributes with extra_ivalue input Slots that hold what value to +// pass into the graph +std::pair, std::vector> lower_graph( + const ModulePtr& self, + Graph& g_, + size_t self_offset = 0) { + std::shared_ptr g = g_.copy(); + std::vector extra_ivalues; + std::unordered_map slot_to_offset; + struct ToScan { + ModulePtr mod; + Node* n; + size_t offset; + }; + std::vector to_scan; + std::vector to_clean; // nodes that should be dead at the end + + auto getOrAddSlot = [&](const Slot& slot) -> Value* { + auto it = slot_to_offset.find(slot); + if (it != slot_to_offset.end()) { + size_t ivalues_start = g->inputs().size() - extra_ivalues.size(); + return g->inputs().at(ivalues_start + it->second); + } + extra_ivalues.emplace_back(slot); + slot_to_offset[slot] = extra_ivalues.size() - 1; + return g->addInput()->setType(slot.type()); + }; + + auto self_value = g->inputs().at(self_offset); + + for (Use use : self_value->uses()) { + to_scan.emplace_back(ToScan{self, use.user, use.offset}); + } + while (to_scan.size() > 0) { + auto e = to_scan.back(); + to_scan.pop_back(); + + // when we lambda lift forks, first-class modules may be passed across + // forks. This code recursively lowers the module in the fork call. 
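+      // Descriptive note (editor addition): the recursive call below returns the slots the fork subgraph now needs; each one is appended to the fork node's inputs (routed through getOrAddSlot so the enclosing graph also receives it), and the fork's original module argument is then dropped.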
+ if (e.n->kind() == prim::fork) { + auto subgraph = e.n->g(attr::Subgraph); + std::vector new_slots; + std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset); + e.n->g_(attr::Subgraph, subgraph); + for (const Slot& slot : new_slots) { + e.n->addInput(getOrAddSlot(slot)); + } + e.n->removeInput(e.offset); + continue; + } + if (e.n->kind() != prim::GetAttr) { + throw ErrorReport(e.n->getSourceLocation()) + << "temporary: the only valid use of a module is looking up an attribute"; + } + Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name))); + if (ClassTypePtr c = e.n->output()->type()->cast()) { + if (c->name() == "Module") { + auto obj = slot.value().toObject(); + for (Use use : e.n->output()->uses()) { + to_scan.emplace_back(ToScan{obj, use.user, use.offset}); + } + to_clean.emplace_back(e.n); + continue; + } + } + e.n->output()->replaceAllUsesWith(getOrAddSlot(slot)); + e.n->destroy(); + } + + while (to_clean.size() > 0) { + Node* n = to_clean.back(); + AT_ASSERT(!n->hasUses()); + n->destroy(); + to_clean.pop_back(); + } + AT_ASSERT(!self_value->hasUses()); + g->eraseInput(self_offset); + + return std::make_pair(std::move(g), std::move(extra_ivalues)); +} + +Method& Module::lower_first_class_method(Function* fn) { + fn->ensure_defined(); + auto lowered = lower_graph(module_object(), *fn->graph()); + Function& new_func = + lowered_methods_.create_function(fn->name(), lowered.first); + + // generate the new schema + // slice away the self argument + std::vector args( + fn->getSchema().arguments().begin() + 1, + fn->getSchema().arguments().end()); + size_t id = 0; + for (const Slot& slot : lowered.second) { + std::ostringstream ss; + ss << "slot" << id++; + args.emplace_back(ss.str(), slot.type()); + } + new_func.setSchema(fn->getSchema().cloneWithArguments(std::move(args))); + return _create_lowered_method(&new_func, std::move(lowered.second)); +} + +static void createFirstClassValues( + Module* module, + Value* self, + std::unordered_map& result) { + auto& g = *self->owningGraph(); + + std::vector created; + struct ToScan { + Module* mod; + Value* v; // value representing module in the graph + }; + std::vector to_scan = {{module, self}}; + + while (!to_scan.empty()) { + auto s = to_scan.back(); + to_scan.pop_back(); + size_t offset = 0; + for (const std::string& name : + s.mod->module_object()->type()->attributeNames()) { + Value* v = g.insertGetAttr(s.v, name); + result[Slot(s.mod->module_object(), offset++)] = v; + if (std::shared_ptr sub = s.mod->find_module(name)) { + to_scan.emplace_back(ToScan{sub.get(), v}); + } + } + } +} + +void Module::lift_lowered_method(Method& m) { + auto graph = m.graph()->copy(); + Value* self = graph->insertInput(0, "self")->setType(module_object()->type()); + std::unordered_map slot_to_value; + if (!m.initial_ivalues().empty()) { + WithInsertPoint guard(*graph->nodes().begin()); + createFirstClassValues(this, self, slot_to_value); + } + + size_t orig_graph_inputs_size = graph->inputs().size(); + for (size_t i = 0; i < m.initial_ivalues().size(); ++i) { + size_t input_offset = orig_graph_inputs_size - i - 1; + size_t ivalue_offset = m.initial_ivalues().size() - i - 1; + graph->inputs() + .at(input_offset) + ->replaceAllUsesWith( + slot_to_value.at(m.initial_ivalues().at(ivalue_offset))); + graph->eraseInput(input_offset); + } + + if (!m.initial_ivalues().empty()) { + // we added _all_ the submodules as first-class values but maybe did not use + // them. 
So remove any dead attribute lookups + EliminateDeadCode(graph); + } + + Function& new_fn = class_cu().create_function(m.name(), std::move(graph)); + // created lifted schema + // self argument is named '$self' to prevent accidental name collisions + // with another input that the user named 'self' + std::vector new_args = {Argument("$self", module_object()->type())}; + const auto& lowered_args = m.function().getSchema().arguments(); + new_args.insert( + new_args.end(), + lowered_args.begin(), + lowered_args.begin() + m.num_inputs()); + new_fn.setSchema(m.function().getSchema().cloneWithArguments(std::move(new_args))); +} + +Method& Module::_create_lowered_method( + Function* func, + std::vector member_inputs) { + std::unique_ptr m(new Method(this, func, std::move(member_inputs))); + return *insert(func->name(), methods_, EntityType::METHOD, std::move(m)); +} + +void Module::lift_lowered_methods(size_t start) { + for (size_t i = start; i < lowered_methods_.get_functions().size(); ++i) { + Method& m = _create_lowered_method( + lowered_methods_.get_functions().at(i).get(), {}); + lift_lowered_method(m); + } +} + +void Module::_define_lowered( + const std::vector& definitions, + const std::vector& resolvers) { + size_t start = lowered_methods_.get_functions().size(); + lowered_methods_.define(definitions, resolvers, nullptr); + lift_lowered_methods(start); + // call lift_lowered_method for each definition +} + +void Module::_define_lowered(const std::string& src, const Resolver& resolver) { + size_t start = lowered_methods_.get_functions().size(); + lowered_methods_.define(src, resolver, nullptr); + lift_lowered_methods(start); +} + +Method& Module::_define_lowered( + std::string name, + std::shared_ptr graph, + std::vector slots) { + Method& m = _create_lowered_method( + &lowered_methods_.create_function(std::move(name), std::move(graph)), + std::move(slots)); + lift_lowered_method(m); + return m; +} + +void Module::define(const std::string& src, const Resolver& resolver) { + class_cu().define( + src, + resolver ? resolver : nativeResolver, + simpleSelf(module_object()->type())); +} + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 13233d9..e5b3a65 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -39,6 +40,7 @@ using ::c10::FunctionSchema; // Map which stores filename to content. using ExtraFilesMap = std::unordered_map; +using ModulePtr = c10::intrusive_ptr; // A method in a module, e.g. 
f in: // // class M(ScriptModule): @@ -53,320 +55,110 @@ struct Module; using ModuleLookup = std::function(const std::vector&)>; -struct Method { - Method( - Module* owner, - std::string name, - bool optimize, - std::shared_ptr graph, - std::vector initial_members, - std::function method_creator) +struct TORCH_API Method { + Method(Module* owner, Function* function, std::vector initial_members) : owner_(owner), - name_(std::move(name)), - graph_(std::move(graph)), - optimize(optimize), - initial_ivalues_(std::move(initial_members)), - method_creator(std::move(method_creator)) { - AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size()); - int i = graph_->inputs().size() - initial_ivalues_.size(); - for (auto member : initial_ivalues_) { - initial_ivalue_index[member] = i++; - } + function_(function), + initial_ivalues_(std::move(initial_members)) { + AT_ASSERT(function->num_inputs() >= initial_ivalues_.size()); + } + + // the module that contains this method. + Module& owner() const { + return *owner_; } void run(Stack& stack) { for (auto input : initial_ivalues_) { push(stack, input.value()); } - get_executor().run(stack); + function_->run(stack); } - void run(Stack&& stack) { run(stack); } IValue operator()(std::vector stack) { - checkInputsAgainstSchema(stack); - run(stack); - return stack.front(); - } - - std::shared_ptr graph_for(Stack inputs) { - for (auto tp : initial_ivalues_) { - inputs.emplace_back(tp.value()); - } - return get_executor().graphFor(inputs); - } - TORCH_API std::shared_ptr graph() const { - return graph_; - } - - TORCH_API const std::string& name() const { - return name_; - } - // emit a function call by inlining the callees Graph into this one - // adding any extra parameters necessary to do this call - - // defined here to keep details of member_input handling confined to this - // class - Value* emit_call_to( - const SourceRange& loc, - Method& callee, - ArrayRef args, - ArrayRef kwargs); - - // if this isn't yet defined, run its method_creator function - TORCH_API void ensure_defined(); - - size_t num_inputs() const { - return graph()->inputs().size() - initial_ivalues_.size(); - } - TORCH_API Value* get_or_add_parameter(Slot slot) { - AT_ASSERT(slot.value().isTensor()); - return get_or_add_attribute(slot); - } - TORCH_API Value* get_or_add_attribute(Slot slot) { - auto it = initial_ivalue_index.find(slot); - if (it != initial_ivalue_index.end()) { - return graph()->inputs().at(it->second); - } - initial_ivalues_.push_back(slot); - initial_ivalue_index[slot] = graph()->inputs().size(); - return graph()->addInput()->setType(slot.type()); - } - - static void setInputTensorTypes(Graph& g, const Stack& stack) { - AT_ASSERT(stack.size() == g.inputs().size()); - for (size_t i = 0; i < stack.size(); ++i) { - g.inputs().at(i)->setType( - DimensionedTensorType::create(stack.at(i).toTensor())); - } - } - - std::shared_ptr propagate_shapes( - std::vector inputs, - bool with_grad = false) { - auto retval = graph_->copy(); - Stack stack; - stack.reserve(inputs.size() + initial_ivalues_.size()); - for (at::Tensor& i : inputs) { - stack.emplace_back(std::move(i)); - } - for (const Slot& inp : initial_ivalues_) { - stack.push_back(inp.value()); - } - setInputTensorTypes(*retval, stack); - PropagateInputShapes(retval); - return retval; - } - - std::shared_ptr propagate_and_assign_input_and_output_shapes( - std::vector inputs, - std::vector outputs, - bool with_grad = false, - bool propagate = true) { - auto retval = graph_->copy(); - for (auto inp : initial_ivalues_) { - 
if (inp.value().isTensor()) { - inputs.push_back(inp.value().toTensor()); - } - } - if (propagate) { - setInputTensorTypes(*retval, fmap(inputs)); - PropagateInputShapes(retval); - } - AT_ASSERT(retval->inputs().size() == inputs.size()); - for (size_t i = 0; i < retval->inputs().size(); ++i) { - auto scalar_type = inputs[i].scalar_type(); - auto sizes = inputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - retval->inputs()[i]->setType(type); - } - at::ArrayRef output_values = retval->outputs(); - // patch this to still work if we are returning a tuple of multiple values - if (output_values.at(0)->type()->kind() == TupleType::Kind) { - AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct); - output_values = output_values.at(0)->node()->inputs(); - } - AT_ASSERT(output_values.size() == outputs.size()); - for (size_t i = 0; i < retval->outputs().size(); ++i) { - auto scalar_type = outputs[i].scalar_type(); - auto sizes = outputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - output_values[i]->setType(type); + getSchema().checkAndNormalizeInputs(stack); + for (auto input : initial_ivalues_) { + push(stack, input.value()); } - return retval; + // use run rather than operator() to skip the second schema check. + function_->run(std::move(stack)); + return stack.front(); } const std::vector& initial_ivalues() const { return initial_ivalues_; } - Method& setSchema(FunctionSchema schema_) { - schema = make_unique(std::move(schema_)); - return *this; - } - - TORCH_API const FunctionSchema& getSchema() const { - if (schema == nullptr) { - schema = make_unique(defaultSchemaFor(*this)); + // proxies for underlying unbound Function + std::shared_ptr graph_for(Stack inputs) { + for (auto tp : initial_ivalues_) { + inputs.emplace_back(tp.value()); } - return *schema; + return function_->get_executor().graphFor(inputs); } - std::string pretty_print_schema() const { - AT_ASSERT(schema); - std::stringstream ss; - ss << *schema; - return ss.str(); + std::shared_ptr graph() const { + return function_->graph(); } - GraphExecutorState getDebugState() { - return get_executor().getDebugState(); + const std::string& name() const { + return function_->name(); } - void debugDisableAutodiffSubgraphInlining() { - return get_executor().debugDisableAutodiffSubgraphInlining(); + size_t num_inputs() const { + return function_->num_inputs() - initial_ivalues_.size(); } - bool is_optimized() const { - return optimize; + FunctionSchema getSchema() const { + // we are required to slice out the slot inputs from the schema + // we can't cache this because setSchema on the underlying function + // will change the underlying schema + auto sliced = ArrayRef(function_->getSchema().arguments()) + .slice(0, num_inputs()); + return function_->getSchema().cloneWithArguments(sliced.vec()); } - // the module that contains this method. - Module& owner() const { - return *owner_; + GraphExecutor& get_executor() { + return function_->get_executor(); } - void check_single_output() { - AT_CHECK( - graph()->outputs().size() == 1, - "Method (but not graphs in general) require a single output. 
Use None/Tuple for 0 or 2+ outputs"); + Function& function() const { + return *function_; } private: - static FunctionSchema defaultSchemaFor(const Method& method) { - std::vector args; - std::vector returns; - Graph& g = *method.graph(); - size_t num_inputs = method.num_inputs(); - for (size_t i = 0; i < num_inputs; ++i) { - const Value* v = g.inputs().at(i); - std::string name = v->hasUniqueName() ? v->uniqueNameBase() - : ("argument_" + std::to_string(i)); - args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); - } - for (size_t i = 0; i < g.outputs().size(); ++i) { - returns.emplace_back("", unshapedType(g.outputs()[i]->type())); - } - return {method.name(), "", std::move(args), std::move(returns)}; - } - - GraphExecutor& get_executor() { - std::call_once(executor_init, [&] { - check_single_output(); - executor = GraphExecutor(graph(), optimize); - }); - return executor; - } - - void checkInputsAgainstSchema(std::vector& inputs) { - const auto& schema = getSchema(); - // Do we have more inputs than the schema accepts? - AT_CHECK( - inputs.size() <= schema.arguments().size(), - "Expected at most ", - schema.arguments().size(), - " argument(s) for operator '", - schema.name(), - "', but received ", - inputs.size(), - " argument(s). Declaration: ", - schema); - - for (size_t pos = 0; pos < schema.arguments().size(); ++pos) { - const auto& argument = schema.arguments()[pos]; - if (pos < inputs.size()) { - if (!isSubvalueOf(inputs[pos], argument.type())) { - AT_ERROR( - "Expected value of type ", - *argument.type(), - " for argument '", - argument.name(), - "' in position ", - pos, - ", but instead got value of type ", - attemptToRecoverType(inputs[pos])->str(), - ". Declaration: ", - schema); - } - } else if (argument.default_value()) { - inputs.push_back(*argument.default_value()); - } else { - AT_ERROR( - schema.name(), - "() is missing value for argument '", - argument.name(), - "'. Declaration: ", - schema); - } - } - } - // Methods are uniqued onwed by a single module. This raw pointer allows // looking up the module. Module* owner_; - std::string name_; - std::shared_ptr graph_; // for debugging and for inlining - bool optimize; - - GraphExecutor executor; // for execution - // initial_ivalues are a list of additional arguments appended to graph - // that are inputs that come from the members of the Module or its submodules. - // each is a pointer to a slot in the module that owns this parameter - // parameters and submodules can only be _added_ to script Modules to ensure - // these pointers always stay valid - std::vector initial_ivalues_; + // Underlying unbound function + Function* function_; - // map from a IValue* in initial_ivalues to the offset it appears at - // in graph. used to accelerate get_or_add_parameter - std::unordered_map initial_ivalue_index; - - // TODO: support that case where we allow _writes_ to parameters from - // compiled functions. - // This requires more sophisticated tracking of ssa values in Graphs so that - // stores to all modules can be lifted to the end of a graph execution. - // It also adds more complexity to adding actual module invocations - // to the executor, so currently it is not done. - // std::vector member_outputs; - - std::once_flag executor_init; - - // an optional function that actually creates the method when - // emit_call_to(this,...) is first called. 
this is used by the compiler so - // that it can construct methods out of order - std::function method_creator; - - // if absent, then we generate a default schema based on the graph - // mutable because getSchema caches the default schema if one is requested - // before a call to setSchema - mutable std::unique_ptr schema; + // parameters and attributes loaded from the Module and appending + // before calling function_ + std::vector initial_ivalues_; }; struct Module; -struct Module { +struct TORCH_API Module { TH_DISALLOW_COPY_AND_ASSIGN(Module); Module() : name_("__main__"), module_value_(c10::ivalue::Object::create( - ClassType::createModuleType(), - 0)), - optimize_(true) {} + ClassType::createModuleType(std::make_shared()), + 0)) {} + ~Module() { + // ClassType own the compilation unit of their Functions, but each + // Function has a self argument which owns the ClassType, created a + // referernce cycle. By dropping all the methods of the module's class + // here we break the cycle. + class_cu().drop_all_functions(); + } const std::string& name() const { return name_; } @@ -374,11 +166,12 @@ struct Module { // note this doesn't change the flags of existing methods just ones // added afterward. void set_optimized(bool o) { - optimize_ = o; + lowered_methods().set_optimized(o); + class_cu().set_optimized(o); } bool is_optimized() const { - return optimize_; + return class_cu().is_optimized(); } IValue forward(std::vector inputs) { @@ -395,7 +188,7 @@ struct Module { name, attributes_, EntityType::ATTRIBUTE, - appendSlot(name, TensorType::get(),std::move(v))); + appendSlot(name, TensorType::get(), std::move(v))); } void register_parameter( @@ -420,7 +213,11 @@ struct Module { const std::string& name, const TypePtr type, IValue ivalue) { - insert(name, attributes_, EntityType::ATTRIBUTE, appendSlot(name, type, ivalue)); + insert( + name, + attributes_, + EntityType::ATTRIBUTE, + appendSlot(name, type, ivalue)); } void register_module( const std::string& name, @@ -434,9 +231,10 @@ struct Module { // AT_WARN( // "Attempting to assign submodule '", // name, - // "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'", - // " Modules of this form do not import and export correctly. This use is deprecated and may be" - // " removed in a future version."); + // "' but it is already a submodule of another ScriptModule '", + // module->parent_->name(), "'", " Modules of this form do not import + // and export correctly. This use is deprecated and may be" " removed + // in a future version."); // } module->parent_ = this; module->name_ = name; @@ -444,34 +242,6 @@ struct Module { insert(name, modules_, EntityType::MODULE, std::move(module)); } - Method& create_method( - const std::string& name, - std::shared_ptr graph, - std::vector member_inputs) { - AT_ASSERT(graph); - std::unique_ptr method(new Method( - this, - name, - optimize_, - std::move(graph), - std::move(member_inputs), - nullptr)); - return *insert(name, methods_, EntityType::METHOD, std::move(method)); - } - - Method& create_method( - const std::string& name, - std::function creator) { - std::unique_ptr method(new Method( - this, - name, - optimize_, - std::make_shared(), - {}, - std::move(creator))); - return *insert(name, methods_, EntityType::METHOD, std::move(method)); - } - Slot parameter_slot(const std::string& name) const { return parameters_[get_offset(name, EntityType::PARAMETER)]; } @@ -495,7 +265,14 @@ struct Module { // each module owns its method. 
The reference returned here // is guarenteed to stay valid until this module has been destroyed Method& get_method(const std::string& name) const { - return *methods_[get_offset(name, EntityType::METHOD)]; + if (Method* method = find_method(name)) { + return *method; + } + // temporary: force the error message + // once the on-demand creation of Method is removed, this code + // can be removed as well + get_offset(name, EntityType::METHOD); + AT_ERROR("unreachable"); } std::shared_ptr get_module(const std::string& name) const { @@ -511,7 +288,12 @@ struct Module { c10::ArrayRef get_attributes() const { return attributes_; } - c10::ArrayRef> get_methods() const { + const std::vector>& get_methods() const { + // force methods_ to be up to date by querying all + // methods. This will go away when lowered_methods_ is deleted + for (const auto& m : class_cu().get_functions()) { + get_method(m->name()); + } return methods_; } @@ -534,9 +316,22 @@ struct Module { auto offset = find_offset(name, EntityType::MODULE); return offset ? modules_[*offset] : nullptr; } - Method* find_method(const std::string& name) { + Method* find_method(const std::string& name) const { auto offset = find_offset(name, EntityType::METHOD); - return offset ? methods_[*offset].get() : nullptr; + if (offset) { + return methods_[*offset].get(); + } + + if (Function* fn = class_cu().find_function(name).get()) { + // temporary lock because technically this is marked const, + // but we have to update the internal Method cache. + // This can be removed when class_cu() is the source of truth for + // methods. + std::lock_guard guard(find_method_guard_); + return &const_cast(this)->lower_first_class_method(fn); + } + + return nullptr; } void apply(std::function fn) { for (auto& submod : get_modules()) { @@ -571,10 +366,7 @@ struct Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - TORCH_API void to( - at::Device device, - at::ScalarType dtype, - bool non_blocking = false); + void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); /// Recursively casts all parameters to the given dtype. /// @@ -582,7 +374,7 @@ struct Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - TORCH_API void to(at::ScalarType dtype, bool non_blocking = false); + void to(at::ScalarType dtype, bool non_blocking = false); /// Recursively moves all parameters to the given device. /// @@ -590,7 +382,7 @@ struct Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - TORCH_API void to(at::Device device, bool non_blocking = false); + void to(at::Device device, bool non_blocking = false); /// Run a method from this module. 
/// @@ -646,26 +438,57 @@ struct Module { mod->copy_into(module_lookup, parameter_remap, names); names.pop_back(); } - for (auto& method : get_methods()) { - std::vector initial_ivalues; - for (auto& p : method->initial_ivalues()) { - initial_ivalues.push_back(parameter_remap.at(p)); - } - curr->create_method( - method->name(), method->graph()->copy(), initial_ivalues); + + for (auto& fn : class_cu().get_functions()) { + curr->class_cu().clone_function(*fn); } } enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD }; at::optional kind_of(const std::string& name) const { + // force lazy creation of Method if needed + // remove once lowered_methods_ is removed. + find_method(name); + auto it = dict_.find(name); if (it == dict_.end()) return at::nullopt; return it->second.type; } + ModulePtr module_object() const { + return module_value_; + } + CompilationUnit& class_compilation_unit() { + return module_object()->type()->compilation_unit(); + } + CompilationUnit& lowered_methods() const { + return lowered_methods_; + } + + // so that C++ users can easily add methods + void define(const std::string& src, const Resolver& resolver = nullptr); + + void _define_lowered( + const std::vector& definitions, + const std::vector& resolvers); + void _define_lowered(const std::string& src, const Resolver& resolver); + + Method& _define_lowered( + std::string name, + std::shared_ptr graph, + std::vector slots); + private: + Method& _create_lowered_method( + Function* func, + std::vector member_inputs); + + Method& lower_first_class_method(Function* fn); + void lift_lowered_method(Method& fn); + void lift_lowered_methods(size_t start); + void to_impl( const c10::optional& device, const c10::optional& dtype, @@ -753,6 +576,13 @@ struct Module { return Slot(module_value_, slot_index); } + CompilationUnit& class_cu() { + return module_value_->type()->compilation_unit(); + } + const CompilationUnit& class_cu() const { + return module_value_->type()->compilation_unit(); + } + // modules have a single namespace, but spread over 4 different concepts: // parameters, attributes, methods, and sub-modules // we store individual lists of each concept, and a single map to @@ -770,29 +600,97 @@ struct Module { std::unordered_map dict_; std::string name_; - - c10::intrusive_ptr module_value_; - + ModulePtr module_value_; // back reference to parent of this Module if present Module* parent_ = nullptr; - bool optimize_; + + // Currently we are in a transitionary state + // where we construct such first class functions but we lower them + // to a form where the modules does not exist before execution. + + // So each Method is actually stored twice once in first-class Module + // form and once in lowered form. + + // first-class: module_value_->type().compilation_unit() holds Functions that + // treat modules as first class. + + // lowered: In this lowered form, all the attributes/parameters are appended + // as additional inputs. 
lowered_methods_ holds this lowered form + // mutable because it is a cache for class_cu() methods + mutable CompilationUnit lowered_methods_; + mutable std::recursive_mutex find_method_guard_; }; -// returns nullptr and fills in failure_messages if the callee does not -// match the functions schema -Value* try_emit_call_to( +static void setInputTensorTypes(Graph& g, const Stack& stack) { + AT_ASSERT(stack.size() == g.inputs().size()); + for (size_t i = 0; i < stack.size(); ++i) { + g.inputs().at(i)->setType( + DimensionedTensorType::create(stack.at(i).toTensor())); + } +} + +inline std::shared_ptr propagate_shapes( + Graph& graph, + const std::vector& inputs, + const std::vector& initial_ivalues, + bool with_grad = false) { + auto retval = graph.copy(); + Stack stack; + stack.reserve(inputs.size() + initial_ivalues.size()); + for (const at::Tensor& i : inputs) { + stack.emplace_back(std::move(i)); + } + for (const Slot& inp : initial_ivalues) { + stack.push_back(inp.value()); + } + setInputTensorTypes(*retval, stack); + PropagateInputShapes(retval); + return retval; +} + +inline std::shared_ptr propagate_and_assign_input_and_output_shapes( Graph& graph, - const SourceRange& loc, - Method& callee, - c10::optional self, - ArrayRef args, - ArrayRef kwargs, - std::stringstream& failure_messages, - // when callee uses no parameters (e.g. it is a function in a compilation - // unit, and not a method), then nullptr can be passed as caller. - Method* caller, - bool conv_tensors_to_nums); + std::vector inputs, + const std::vector& initial_ivalues, + std::vector outputs, + bool with_grad = false, + bool propagate = true) { + auto retval = graph.copy(); + for (auto inp : initial_ivalues) { + if (inp.value().isTensor()) { + inputs.push_back(inp.value().toTensor()); + } + } + if (propagate) { + setInputTensorTypes(*retval, fmap(inputs)); + PropagateInputShapes(retval); + } + AT_ASSERT(retval->inputs().size() == inputs.size()); + for (size_t i = 0; i < retval->inputs().size(); ++i) { + auto scalar_type = inputs[i].scalar_type(); + auto sizes = inputs[i].sizes(); + auto type = + torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); + retval->inputs()[i]->setType(type); + } + at::ArrayRef output_values = retval->outputs(); + // patch this to still work if we are returning a tuple of multiple values + if (output_values.at(0)->type()->kind() == TupleType::Kind) { + AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct); + output_values = output_values.at(0)->node()->inputs(); + } + AT_ASSERT(output_values.size() == outputs.size()); + for (size_t i = 0; i < retval->outputs().size(); ++i) { + auto scalar_type = outputs[i].scalar_type(); + auto sizes = outputs[i].sizes(); + auto type = + torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); + output_values[i]->setType(type); + } + return retval; +} + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp index 30c273d..0c5676c 100644 --- a/torch/csrc/jit/script/schema_matching.cpp +++ b/torch/csrc/jit/script/schema_matching.cpp @@ -421,16 +421,14 @@ Value* emitBuiltinCall( return emitBuiltinNode(*matched_schema, loc, graph, name); } } - for (Method* method : builtin_functions) { - if (auto result = try_emit_call_to( + for (Function* method : builtin_functions) { + if (auto result = method->try_emit_call( graph, loc, - *method, self, inputs, attributes, failure_messages, - nullptr, allow_conversions)) { 
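+        // Descriptive note (editor addition): a non-null result means this overload's schema matched and the callee was inlined; on a mismatch try_emit_call records the reason in failure_messages and we fall through to the next builtin.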
return result; } diff --git a/torch/csrc/jit/script/slot.h b/torch/csrc/jit/script/slot.h index 3e01731..0304c0e 100644 --- a/torch/csrc/jit/script/slot.h +++ b/torch/csrc/jit/script/slot.h @@ -33,6 +33,7 @@ private: c10::intrusive_ptr container_; size_t offset_; friend struct std::hash; + friend struct Module; }; }}} diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index 48c5337..6410282 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -16,7 +16,7 @@ struct NoneValue : SugaredValue { std::shared_ptr PrintValue::call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) { @@ -58,7 +58,7 @@ builtin_cast_methods() { std::shared_ptr BuiltinFunction::call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) { @@ -70,7 +70,7 @@ std::shared_ptr BuiltinFunction::call( // callable value that will resolve to foo(x, y, z) when called. std::shared_ptr SimpleValue::attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) { // Allow method-style casts on Tensor types. e.g. x.int() if (value_->type()->isSubtypeOf(TensorType::get())) { @@ -116,7 +116,7 @@ std::shared_ptr SimpleValue::attr( if (auto classType = value_->type()->cast()) { // This is a class, emit the proper attribute lookup if (auto method = classType->getMethod(field)) { - return std::make_shared(shared_from_this(), *method); + return std::make_shared(getValue(), *method); } if (!classType->hasAttribute(field)) { @@ -135,7 +135,7 @@ std::shared_ptr SimpleValue::attr( std::vector> SimpleValue::asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint) { static const auto make_simple_value = [](Value* v) -> std::shared_ptr { @@ -161,7 +161,7 @@ std::vector> SimpleValue::asTuple( void SimpleValue::setAttr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field, Value* newValue) { const auto classType = value_->type()->cast(); @@ -217,7 +217,7 @@ void SimpleValue::setAttr( std::shared_ptr ClassValue::call( const SourceRange& loc, - Method& m, + Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. at::ArrayRef inputs, at::ArrayRef attributes, @@ -226,8 +226,7 @@ std::shared_ptr ClassValue::call( // Generate a new object of the right type, then call `__init__` on it auto& g = *m.graph(); - auto createNode = g.insertNode(g.createObject(type_)); - auto self = std::make_shared(createNode->output()); + auto self = g.insertNode(g.createObject(type_))->output(); auto initMethod = type_->getMethod("__init__"); AT_ASSERT(initMethod); @@ -235,12 +234,12 @@ std::shared_ptr ClassValue::call( // Call the init function MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders); - return self; + return std::make_shared(self); } std::shared_ptr ClassValue::attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) { if (field != "__new__") { throw ErrorReport(loc) << "Tried to lookup unknown attribute on class"; diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h index e1fd725..6f4117d 100644 --- a/torch/csrc/jit/script/sugared_value.h +++ b/torch/csrc/jit/script/sugared_value.h @@ -28,14 +28,14 @@ struct SugaredValue : public std::enable_shared_from_this { // what can we do with this thing? // use it as a value e.g. 
`this + 4` - virtual Value* asValue(const SourceRange& loc, Method& m) { + virtual Value* asValue(const SourceRange& loc, Function& m) { throw ErrorReport(loc) << kind() << " cannot be used as a value"; } // select an attribute on it, e.g. `this.field` virtual std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) { throw ErrorReport(loc) << "attribute lookup is not defined on " << kind(); } @@ -43,7 +43,7 @@ struct SugaredValue : public std::enable_shared_from_this { // assign an attribute on it, e.g. `this.field = newValue` virtual void setAttr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field, Value* newValue) { throw ErrorReport(loc) << "attribute assignment is not defined on " @@ -57,7 +57,7 @@ struct SugaredValue : public std::enable_shared_from_this { // a method invocation virtual std::vector> asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint = {}) { throw ErrorReport(loc) << kind() << " cannot be used as a tuple"; } @@ -65,7 +65,7 @@ struct SugaredValue : public std::enable_shared_from_this { // call it like a function, e.g. `outputs = this(inputs)` virtual std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. at::ArrayRef inputs_, at::ArrayRef attributes, @@ -97,7 +97,7 @@ struct TORCH_API SimpleValue : public SugaredValue { std::string kind() const override { return "value"; } - Value* asValue(const SourceRange& range, Method& m) override { + Value* asValue(const SourceRange& range, Function& m) override { return value_; } NoneStatus isNone() override { @@ -110,16 +110,16 @@ struct TORCH_API SimpleValue : public SugaredValue { } std::vector> asTuple( const SourceRange& loc, - Method& m, + Function& m, const c10::optional& size_hint = {}) override; std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override; void setAttr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field, Value* newValue) override; @@ -146,7 +146,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue { } std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef attributes, at::ArrayRef inputs, size_t n_binders) override; @@ -161,7 +161,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { } std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override { return std::make_shared( Symbol::fromQualString(name + "::" + field), c10::nullopt); @@ -183,14 +183,14 @@ struct TORCH_API ClassValue : public SugaredValue { // n = Foo(constructor_arg) std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; std::shared_ptr attr( const SourceRange& loc, - Method& m, + Function& m, const std::string& field) override; std::string kind() const override { @@ -202,34 +202,33 @@ struct TORCH_API ClassValue : public SugaredValue { // defines how a method obtained from a module behaves in script struct MethodValue : public SugaredValue { - MethodValue(std::shared_ptr self, Method& method) + MethodValue(c10::optional self, Function& method) : self_(std::move(self)), method(method) {} std::string kind() const override { return "method"; } std::shared_ptr call( const SourceRange& loc, - Method& caller, + Function& f, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { - 
if (auto classType = dynamic_cast(self_.get())) { - // If self_ is a class, then it will be expected as part of - // the schema. Add it to the front of the inputs. + Graph& graph = *f.graph(); + if (self_) { std::vector inputsWithSelf; - inputsWithSelf.emplace_back(loc, classType->getValue()); + inputsWithSelf.emplace_back(loc, self_->value(graph)); inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end()); return std::make_shared( - caller.emit_call_to(loc, method, inputsWithSelf, attributes)); + method.emit_call(graph, loc, inputsWithSelf, attributes)); } return std::make_shared( - caller.emit_call_to(loc, method, inputs, attributes)); + method.emit_call(graph, loc, inputs, attributes)); } private: - std::shared_ptr self_; - Method& method; + c10::optional self_; + Function& method; }; struct TORCH_API PrintValue : public SugaredValue { @@ -238,7 +237,7 @@ struct TORCH_API PrintValue : public SugaredValue { } std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; @@ -252,7 +251,7 @@ struct TORCH_API CastValue : public BuiltinFunction { : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {} std::shared_ptr call( const SourceRange& loc, - Method& m, + Function& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -316,7 +315,7 @@ struct TORCH_API ClassNewMethod : public SugaredValue { std::shared_ptr createObject( const SourceRange& loc, - Method& m, + Function& m, const std::string& classname) { if (classname != type_->name()) { throw ErrorReport(loc) @@ -338,6 +337,13 @@ static inline std::vector toValues( return fmap(nvs, [&](const NamedValue& v) { return v.value(g); }); } +static inline Self simpleSelf(const TypePtr& typ) { + return [typ](Value* v) { + v->setType(typ); + return std::make_shared(v); + }; +} + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp index 2a99f72..3433a9d 100644 --- a/torch/csrc/jit/symbolic_script.cpp +++ b/torch/csrc/jit/symbolic_script.cpp @@ -1303,8 +1303,8 @@ bool isHelperFunction(const std::string& method_name) { return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0; } -void loadModule(const std::shared_ptr& module) { - for (const auto& method : module->get_methods()) { +void loadModule(const script::CompilationUnit& module) { + for (const auto& method : module.get_functions()) { if (isHelperFunction(method->name())) continue; @@ -1356,9 +1356,8 @@ void loadModule(const std::shared_ptr& module) { void loadFunctions() { for (const std::string& str : functions) { - auto cu = std::make_shared(); - script::defineMethodsInModule( - cu, str, script::nativeResolver, c10::nullopt); + script::CompilationUnit cu; + cu.define(str, script::nativeResolver, nullptr); loadModule(cu); } } diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 593f508..b4d33ced 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -702,14 +702,8 @@ def _try_get_dispatched_fn(fn): return _jit_internal.boolean_dispatched.get(fn) -def _try_get_overloaded_fn(fn): - if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule): - # Only allow overloads for bound methods - return None - overloads = fn.__self__._overloads.get(fn.__name__, None) - if overloads is None: - return None - return [getattr(fn.__self__, overload) for overload in overloads] +def _try_get_overloaded_fn(mod, field): + return 
mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None def _try_compile_weak_script(fn): @@ -738,20 +732,20 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None): return obj if _rcb is None: _rcb = _jit_internal.createResolutionCallback(_frames_up + 1) - mod = ScriptModule() if inspect.isclass(obj): if not _is_new_style_class(obj): raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'") ast = get_jit_class_def(obj) - _jit_script_class_compile(mod, ast, _rcb) + _jit_script_class_compile(ast, _rcb) _add_script_class(obj, obj.__name__) return obj else: + mod = ScriptModule() ast = get_jit_def(obj) _jit_script_compile(mod, ast, _rcb, get_default_args(obj)) - # Forward docstrings - mod.__doc__ = obj.__doc__ - return mod + # Forward docstrings + mod.__doc__ = obj.__doc__ + return mod ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method')) -- 2.7.4