From: Zachary DeVito Date: Thu, 11 Apr 2019 13:14:21 +0000 (-0700) Subject: Revert D14842057: Compiler uses first-class modules** X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~273 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f5165ade5b855d25316464455171d8d464fc8b3d;p=platform%2Fupstream%2Fpytorch.git Revert D14842057: Compiler uses first-class modules** Differential Revision: D14842057 Original commit changeset: ca6e7b5a4380 fbshipit-source-id: e8f1862a59bf20d5f78648b2fdc53a8b3750ead3 --- diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 888cfcb..c5185c1 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -164,18 +164,6 @@ public: } return c10::nullopt; } - FunctionSchema cloneWithArguments(std::vector new_arguments) const { - return FunctionSchema( - name(), - overload_name(), - std::move(new_arguments), - returns(), - is_vararg(), - is_varret()); - } - // Check that inputs have the correct types and appends any missing default - // values. - void checkAndNormalizeInputs(std::vector& inputs) const; }; inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { @@ -239,46 +227,4 @@ inline std::string toString(const FunctionSchema& schema) { return str.str(); } -inline void FunctionSchema::checkAndNormalizeInputs(std::vector& inputs) const { - // Do we have more inputs than the schema accepts? - AT_CHECK( - inputs.size() <= arguments().size(), - "Expected at most ", - arguments().size(), - " argument(s) for operator '", - name(), - "', but received ", - inputs.size(), - " argument(s). Declaration: ", - *this); - - for (size_t pos = 0; pos < arguments().size(); ++pos) { - const auto& argument = arguments()[pos]; - if (pos < inputs.size()) { - if (!isSubvalueOf(inputs[pos], argument.type())) { - AT_ERROR( - "Expected value of type ", - *argument.type(), - " for argument '", - argument.name(), - "' in position ", - pos, - ", but instead got value of type ", - attemptToRecoverType(inputs[pos])->str(), - ". Declaration: ", - *this); - } - } else if (argument.default_value()) { - inputs.push_back(*argument.default_value()); - } else { - AT_ERROR( - name(), - "() is missing value for argument '", - argument.name(), - "'. Declaration: ", - *this); - } - } -} - } // namespace c10 diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 2399dcc..3656251 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -17,8 +17,8 @@ namespace torch { namespace jit { namespace script { -struct CompilationUnit; -struct Function; +struct Module; +struct Method; } } // namespace jit } // namespace torch @@ -1100,19 +1100,19 @@ CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); struct ClassType; using ClassTypePtr = std::shared_ptr; -using ::torch::jit::script::CompilationUnit; -using ::torch::jit::script::Function; +using ::torch::jit::script::Module; +using ::torch::jit::script::Method; // This represents a class in TorchScript. struct CAFFE2_API ClassType : public Type { // Create a user type and register it globally. 
static ClassTypePtr create( const std::string& name, - std::shared_ptr module); + std::shared_ptr module); // Create a type representing a Module, // These do not have methods, and are not globally registered - static ClassTypePtr createModuleType(std::shared_ptr module); + static ClassTypePtr createModuleType(); // returns nullptr if there is no type with that name static ClassTypePtr get(const std::string& name); @@ -1168,11 +1168,8 @@ struct CAFFE2_API ClassType : public Type { return attributeNames_[slot]; } - Function* getMethod(const std::string& name) const; - CompilationUnit& compilation_unit(); - const CompilationUnit& compilation_unit() const; - std::vector methods() const; - + Method* getMethod(const std::string& name) const; + std::vector methods() const; const std::string& name() const { return typename_; @@ -1229,10 +1226,10 @@ struct CAFFE2_API ClassType : public Type { static const TypeKind Kind = TypeKind::ClassType; private: - ClassType(std::string name, std::shared_ptr cu) + ClassType(std::string name, std::shared_ptr module) : Type(TypeKind::ClassType), typename_(std::move(name)), - compilation_unit_(std::move(cu)) {} + module_(std::move(module)) {} // Name of type (note that this has to be globally unique). std::string typename_; @@ -1246,7 +1243,7 @@ struct CAFFE2_API ClassType : public Type { std::vector attributeNames_; std::vector attributeTypes_; // Holds method attributes - std::shared_ptr compilation_unit_; + std::shared_ptr module_; }; } // namespace c10 diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 534c9f3..e6c3628 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -472,18 +472,18 @@ ClassTypeRegistry& getRegistry() { ClassTypePtr ClassType::create( const std::string& name, - std::shared_ptr cu) { - auto ptr = ClassTypePtr(new ClassType(name, std::move(cu))); + std::shared_ptr module) { + auto ptr = ClassTypePtr(new ClassType(name, std::move(module))); getRegistry().registerType(name, ptr); return ptr; } -ClassTypePtr ClassType::createModuleType(std::shared_ptr cu) { - return ClassTypePtr(new ClassType("Module", std::move(cu))); +ClassTypePtr ClassType::createModuleType() { + return ClassTypePtr(new ClassType("Module", nullptr)); } ClassTypePtr ClassType::refine(at::ArrayRef refined_slots) const { - auto ptr = ClassTypePtr(new ClassType(typename_, compilation_unit_)); + auto ptr = ClassTypePtr(new ClassType(typename_, module_)); AT_ASSERT(numAttributes() == refined_slots.size()); for(size_t i = 0; i < attributeNames_.size(); ++i) { AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i])); diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp index 90b38f3..7145c1a 100644 --- a/test/cpp/jit/test.cpp +++ b/test/cpp/jit/test.cpp @@ -65,8 +65,7 @@ namespace jit { _(NoneSchemaMatch) \ _(ClassParser) \ _(PeepholeOptimize) \ - _(RecordFunction) \ - _(ModuleDefine) + _(RecordFunction) #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index ea20266..8b93c21 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -40,7 +40,6 @@ #include "ATen/core/ivalue.h" #include "torch/csrc/jit/script/compiler.h" #include "torch/csrc/jit/script/module.h" -#include "torch/jit.h" #include "onnx/onnx_pb.h" @@ -370,10 +369,11 @@ static const auto cf_examples = R"JIT( return a )JIT"; void testControlFlow() { - auto cu = compile(cf_examples); - + auto cu = std::make_shared(); + script::defineMethodsInModule( + cu, cf_examples, 
script::nativeResolver, c10::nullopt); auto run = [&](const std::string& name, std::vector stack) { - auto graph = cu->get_function(name).graph(); + auto graph = cu->get_method(name).graph(); Code code(graph); InterpreterState interp(code); interp.run(stack); @@ -576,11 +576,12 @@ void testTopologicalIndex() { } void invokeTestRecordFunction(at::Tensor& t) { - autograd::profiler::GetPackedInputsCallback inputs_cb = [t]() { - Stack st; - pack(st, t); - return st; - }; + autograd::profiler::GetPackedInputsCallback inputs_cb = + [t]() { + Stack st; + pack(st, t); + return st; + }; autograd::profiler::RecordFunction guard("test", inputs_cb); t.add_(torch::ones_like(t)); } @@ -604,15 +605,15 @@ void invokeTestRecordFunctionNested() { void testRecordFunction() { std::vector> input_sizes; - autograd::profiler::pushCallback( - [&input_sizes](const autograd::profiler::RecordFunction& fn) { - for (const auto& input : fn.inputs()) { - if (input.isTensor()) { - std::vector t = input.toTensor().sizes().vec(); - input_sizes.push_back(t); - } - } - }); + autograd::profiler::pushCallback([&input_sizes]( + const autograd::profiler::RecordFunction& fn) { + for (const auto& input : fn.inputs()) { + if (input.isTensor()) { + std::vector t = input.toTensor().sizes().vec(); + input_sizes.push_back(t); + } + } + }); auto t = torch::randn({1, 2, 3}, at::kCPU); invokeTestRecordFunction(t); @@ -624,15 +625,14 @@ void testRecordFunction() { // test nested RecordFunctions std::vector nested_names; - autograd::profiler::pushCallback( - [&nested_names](const autograd::profiler::RecordFunction& fn) { - nested_names.push_back(getFullName(&fn)); - }); + autograd::profiler::pushCallback([&nested_names]( + const autograd::profiler::RecordFunction& fn) { + nested_names.push_back(getFullName(&fn)); + }); { autograd::profiler::RecordFunction guard("outer"); - invokeTestRecordFunctionNested(); - ; + invokeTestRecordFunctionNested();; } autograd::profiler::popCallback(); @@ -709,18 +709,6 @@ void testNoneSchemaMatch() { // checking that constant propagation ran wo/failure AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1); } - -void testModuleDefine() { - auto m = std::make_shared(); - m->register_parameter("foo", torch::ones({}), false); - m->define(R"( - def add_it(self, x, b : int = 4): - return self.foo + x + b - )"); - auto result = m->run_method("add_it", torch::ones({})); - AT_ASSERT(result.toTensor().item() == 6) -} - } // namespace test } // namespace jit } // namespace torch diff --git a/test/expect/TestScript.test_onnx_export_script_inline_params.expect b/test/expect/TestScript.test_onnx_export_script_inline_params.expect index 1cb1092..ffa284a 100644 --- a/test/expect/TestScript.test_onnx_export_script_inline_params.expect +++ b/test/expect/TestScript.test_onnx_export_script_inline_params.expect @@ -5,14 +5,14 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 4},{name: "2", type:Tensor dims: 3 3}] + inputs: [{name: "x", type:Tensor dims: 2 3},{name: "1", type:Tensor dims: 3 3},{name: "2", type:Tensor dims: 3 4}] outputs: [{name: "6", type:Tensor dims: 2 4}] - initializers: [TensorProto shape: [3 4],TensorProto shape: [3 3]] + initializers: [TensorProto shape: [3 3],TensorProto shape: [3 4]] nodes: [ Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, - Node {type: "Gemm", inputs: [x,2,3], outputs: [4], attributes: [{ name: 'alpha', type: float, 
value: 1},{ name: 'beta', type: float, value: 0}]}, + Node {type: "Gemm", inputs: [x,1,3], outputs: [4], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]}, - Node {type: "Gemm", inputs: [4,1,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]} + Node {type: "Gemm", inputs: [4,2,5], outputs: [6], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0}]} ] } opset_import: [OperatorSetIdProto { domain: }], diff --git a/test/expect/TestScript.test_onnx_export_speculate-f2.expect b/test/expect/TestScript.test_onnx_export_speculate-f2.expect index 29ce206..3126f1d 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f2.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f2.expect @@ -5,9 +5,9 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20},{name: "2", type:Tensor dims: 20 10}] + inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}] outputs: [{name: "8", type:Tensor dims: 1 20}] - initializers: [TensorProto shape: [20],TensorProto shape: [20 10]] + initializers: [TensorProto shape: [20 10],TensorProto shape: [20]] nodes: [ Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []}, Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]}, @@ -28,7 +28,7 @@ ModelProto { outputs: [{name: "10", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,2,1], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -39,7 +39,7 @@ ModelProto { outputs: [{name: "11", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,2,1], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -54,7 +54,7 @@ ModelProto { outputs: [{name: "12", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,2,1], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } diff --git a/test/test_jit.py b/test/test_jit.py index 18ca724..787dee4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7478,7 +7478,7 @@ a") def foo(self, input): self.call_foo(input) - with self.assertRaisesRegex(RuntimeError, 'called recursively'): + with self.assertRaisesRegex(RuntimeError, 'called recursively involving'): M() def test_script_kwargs_fn_call(self): diff --git a/tools/build_variables.py b/tools/build_variables.py 
index ff27ce3..89a5ed8 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -95,7 +95,6 @@ libtorch_sources = [ "torch/csrc/jit/register_quantized_ops.cpp", "torch/csrc/jit/scope.cpp", "torch/csrc/jit/script/compiler.cpp", - "torch/csrc/api/src/jit.cpp", "torch/csrc/jit/script/edit_distance.cpp", "torch/csrc/jit/script/logging.cpp", "torch/csrc/jit/script/final_returns.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4b2281b..60f883a 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -175,7 +175,6 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/scope.cpp ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp - ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp @@ -237,6 +236,7 @@ if (NOT NO_API) ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp + ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp @@ -528,6 +528,7 @@ if (BUILD_PYTHON) ${TORCH_SRC_DIR}/csrc/jit/python_tracer.cpp ${TORCH_SRC_DIR}/csrc/jit/script/init.cpp ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp + ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp ${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp ${TORCH_SRC_DIR}/csrc/nn/THNN.cpp diff --git a/torch/csrc/api/include/torch/jit.h b/torch/csrc/api/include/torch/jit.h index 7e2e4c9..9814ead 100644 --- a/torch/csrc/api/include/torch/jit.h +++ b/torch/csrc/api/include/torch/jit.h @@ -32,7 +32,7 @@ namespace jit { /// )JIT"); /// IValue output = module->run_method("relu_script", a, b); /// \endrst -TORCH_API std::shared_ptr compile(const std::string& source); +TORCH_API std::shared_ptr compile(const std::string& source); } // namespace jit } // namespace torch diff --git a/torch/csrc/api/src/jit.cpp b/torch/csrc/api/src/jit.cpp index a66e947..29ea39f 100644 --- a/torch/csrc/api/src/jit.cpp +++ b/torch/csrc/api/src/jit.cpp @@ -9,9 +9,10 @@ namespace torch { namespace jit { -std::shared_ptr compile(const std::string& source) { - auto module = std::make_shared(); - module->define(source, script::nativeResolver, nullptr); +std::shared_ptr compile(const std::string& source) { + auto module = std::make_shared(); + defineMethodsInModule( + module, source, script::nativeResolver, /*self=*/c10::nullopt); return module; } diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp index e27988d..6a74810 100644 --- a/torch/csrc/jit/import_source.cpp +++ b/torch/csrc/jit/import_source.cpp @@ -6,6 +6,39 @@ namespace torch { namespace jit { namespace script { +// this is a much simpler accessor that only handles modules, parameters, and +// and methods. It does not depend on python to work. +struct ModuleAccessorValue : public SugaredValue { + ModuleAccessorValue(std::shared_ptr module) + : module(std::move(module)) {} + std::string kind() const override { + return "module"; + } + // select an attribute on it, e.g. 
`this.field` + std::shared_ptr attr( + const SourceRange& loc, + Method& m, + const std::string& field) override { + if (std::shared_ptr v = module->find_module(field)) { + return std::make_shared(std::move(v)); + } else if (script::Slot* v = module->find_parameter(field)) { + return std::make_shared(m.get_or_add_parameter(*v)); + } else if (script::Slot* v = module->find_buffer(field)) { + return std::make_shared(m.get_or_add_parameter(*v)); + } else if (script::Slot* v = module->find_attribute(field)) { + return std::make_shared( + m.get_or_add_attribute(*v)); + } else if (Method* m = module->find_method(field)) { + return std::make_shared(shared_from_this(), *m); + } else { + throw ErrorReport(loc) << "unknown attr: " << field; + } + } + + private: + std::shared_ptr module; +}; + struct OpsValue : public SugaredValue { OpsValue(size_t version) : version_(version) {} std::string kind() const override { @@ -13,7 +46,7 @@ struct OpsValue : public SugaredValue { } std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { return std::make_shared(field, version_); } @@ -26,7 +59,7 @@ struct ConstantValue : public SugaredValue { std::string kind() const override { return "constant"; } - Value* asValue(const SourceRange& loc, Function& m) override { + Value* asValue(const SourceRange& loc, Method& m) override { return m.graph()->insertConstant(value_); } }; @@ -42,7 +75,7 @@ struct ConstantTableValue : public SugaredValue { // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { const char* field_s = field.c_str(); char* end; @@ -84,7 +117,7 @@ struct SourceImporter { }; resolver_ = [&](const std::string& name, - Function& m, + Method& m, const SourceRange& loc) -> std::shared_ptr { auto it = env_.find(name); if (it == env_.end()) { @@ -100,7 +133,7 @@ struct SourceImporter { const std::vector& constant_table_; std::unordered_map> env_; std::function(const std::string& name, Function& m, const SourceRange& loc)> + SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)> resolver_; size_t parseVersionNumber() { @@ -134,11 +167,8 @@ void import_methods( definitions.emplace_back(def); resolvers.emplace_back(importer.resolver_); } - auto self = [&](Value* v) { - v->setType(mod->module_object()->type()); - return std::make_shared(v); - }; - mod->module_object()->type()->compilation_unit().define(definitions, resolvers, self); + auto self = std::make_shared(mod); + defineMethodsInModule(mod, definitions, resolvers, Self(self)); } void import_libs( @@ -156,13 +186,9 @@ void import_libs( resolvers.emplace_back(importer.resolver_); } - auto cu = std::make_shared(); - auto class_type = ClassType::create(class_def.name().name(), cu); - auto self = [&](Value* v) { - v->setType(class_type); - return std::make_shared(v); - }; - cu->define(definitions, resolvers, self); + auto mod = std::make_shared(); + Self self(ClassType::create(class_def.name().name(), mod)); + defineMethodsInModule(mod, definitions, resolvers, self); } } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 42e7267..ce43b5b 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1074,10 +1074,6 @@ struct Graph { const std::string& field, Value* newValue); TORCH_API Node* createGetAttr(Value* obj, const std::string& field); - TORCH_API Value* insertGetAttr(Value* obj, const std::string& field) { - return insertNode(createGetAttr(obj, field))->output(); - 
} - // Note: defined in python_ir.cpp and can be used only in python extension Node* createPythonOp( THPObjectPtr&& pyobj, diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index e7d5f94..a8ae6d3 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -267,9 +267,10 @@ struct GraphFuser { norm_invstd = 1 / (eps + torch.sqrt(norm_var)) return ((input - norm_mean) * norm_invstd) )SCRIPT"; - script::CompilationUnit cu; - cu.define(source, script::nativeResolver, nullptr); - *graph_ptr = cu.get_function("batch_norm").graph(); + auto module = std::make_shared(); + defineMethodsInModule( + module, source, script::nativeResolver, /*self=*/c10::nullopt); + *graph_ptr = module->get_method("batch_norm").graph(); }, &bn_graph); diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index f7bfd99..b157cc6 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -1133,16 +1133,6 @@ struct PythonPrintPass { [](const Argument& arg) { return arg.default_value(); }); printFunction(graph, name, is_class, defaults, ivalue_names); } - void printFunction( - script::Function& method, - bool is_class) { - const std::string& name = method.name(); - Graph& graph = *method.graph(); - auto defaults = fmap( - method.getSchema().arguments(), - [](const Argument& arg) { return arg.default_value(); }); - printFunction(graph, name, is_class, defaults, {}); - } void printModule(script::Module& module) { std::unordered_map extra_ivalue_names; createTensorToParameterNameMap( @@ -1163,8 +1153,9 @@ struct PythonPrintPass { out << "class " << classType->name() << ":\n"; { const auto guard = WithIndented(); + std::unordered_map extra_ivalue_names; for (auto& method : classType->methods()) { - printFunction(*method, /*is_class=*/true); + printMethod(*method, /*is_class=*/true, extra_ivalue_names); } } } diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 1a40aa4..7652dc0 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -137,7 +137,6 @@ void ConcretePythonOp::cloneFrom(Node* other_) { this->cconv = other->cconv; Py_INCREF(other->pyobj.get()); this->pyobj = THPObjectPtr(other->pyobj.get()); - this->ignore_on_export = other->ignore_on_export; for (auto& sa : other->scalar_args) { Py_INCREF(sa.get()); this->scalar_args.emplace_back(sa.get()); diff --git a/torch/csrc/jit/script/builtin_functions.cpp b/torch/csrc/jit/script/builtin_functions.cpp index 2ee7730..a1ed46c 100644 --- a/torch/csrc/jit/script/builtin_functions.cpp +++ b/torch/csrc/jit/script/builtin_functions.cpp @@ -37,8 +37,8 @@ def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]: )SCRIPT"); struct BuiltinFunctionRegistry { - const std::vector& getAllBuiltinFunctionsFor(Symbol name) { - const static std::vector empty; + const std::vector& getAllBuiltinFunctionsFor(Symbol name) { + const static std::vector empty; // when initializing the builtin function library, we will re-enter // getAllBuiltinFunctionsFor since it is called in the compiler to // lookup builtins and initializing the builtin functions calls the @@ -62,10 +62,11 @@ struct BuiltinFunctionRegistry { private: void loadSource(const std::string& source) { - std::shared_ptr cu = std::make_shared(); - modules.emplace_back(cu); - cu->define(source, script::nativeResolver, /*self=*/nullptr); - for (auto& method : cu->get_functions()) { + auto module = std::make_shared(); + 
defineMethodsInModule( + module, source, script::nativeResolver, /*self=*/c10::nullopt); + modules.push_back(module); + for (auto& method : module->get_methods()) { builtins_by_name[Symbol::fromQualString("aten::" + method->name())] .push_back(method.get()); } @@ -96,11 +97,11 @@ struct BuiltinFunctionRegistry { } enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED; std::recursive_mutex mutex; - std::vector> modules; - std::unordered_map> builtins_by_name; + std::vector> modules; + std::unordered_map> builtins_by_name; }; -TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name) { +TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name) { static BuiltinFunctionRegistry registry; return registry.getAllBuiltinFunctionsFor(name); } diff --git a/torch/csrc/jit/script/builtin_functions.h b/torch/csrc/jit/script/builtin_functions.h index f1a5f22..42e15e7 100644 --- a/torch/csrc/jit/script/builtin_functions.h +++ b/torch/csrc/jit/script/builtin_functions.h @@ -7,7 +7,7 @@ namespace torch { namespace jit { namespace script { -TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name); +TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name); } } // namespace jit diff --git a/torch/csrc/jit/script/class_type.cpp b/torch/csrc/jit/script/class_type.cpp index 0841e80..f669e03 100644 --- a/torch/csrc/jit/script/class_type.cpp +++ b/torch/csrc/jit/script/class_type.cpp @@ -5,20 +5,13 @@ namespace c10 { // This file exists because we need to reference module.h, which we can't from // c10. Sigh... -Function* ClassType::getMethod(const std::string& name) const { - return compilation_unit_->find_function(name).get(); +Method* ClassType::getMethod(const std::string& name) const { + return module_? module_->find_method(name) : nullptr; } -CompilationUnit& ClassType::compilation_unit() { - return *compilation_unit_; -} -const CompilationUnit& ClassType::compilation_unit() const { - return *compilation_unit_; -} - -std::vector ClassType::methods() const { - std::vector ret; - for (const auto& pr : compilation_unit().get_functions()) { +std::vector ClassType::methods() const { + std::vector ret; + for (const auto& pr : module_->get_methods()) { ret.push_back(pr.get()); } return ret; diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h deleted file mode 100644 index 790d061..0000000 --- a/torch/csrc/jit/script/compilation_unit.h +++ /dev/null @@ -1,285 +0,0 @@ -#pragma once -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -namespace script { - -struct Def; -struct SugaredValue; -struct Function; - -using Resolver = std::function( - const std::string& name, - Function& f, - const SourceRange& loc)>; -using Self = std::function(Value*)>; - -// A Function is a pure Graph with no implicit `self` object bound. -// It contains schema information, and the executor that manages the -// execution of the function. script::Method is a wrapper around a -// underlying Function that also provides a `self` object. 
-struct TORCH_API Function { - Function( - std::string name, - bool optimize, - std::shared_ptr graph, - std::function function_creator) - : name_(std::move(name)), - graph_(std::move(graph)), - optimize_(optimize), - function_creator_(std::move(function_creator)) {} - - void run(Stack& stack) { - get_executor().run(stack); - } - - void run(Stack&& stack) { - run(stack); - } - - IValue operator()(std::vector stack) { - getSchema().checkAndNormalizeInputs(stack); - run(stack); - return stack.front(); - } - - std::shared_ptr graph_for(Stack inputs) { - return get_executor().graphFor(inputs); - } - - std::shared_ptr graph() const { - return graph_; - } - - const std::string& name() const { - return name_; - } - - // if this isn't yet defined, run its method_creator function - void ensure_defined(); - - size_t num_inputs() const { - return graph()->inputs().size(); - } - - Function& setSchema(FunctionSchema schema) { - schema_ = make_unique(std::move(schema)); - return *this; - } - - const FunctionSchema& getSchema() const { - if (schema_ == nullptr) { - schema_ = make_unique(defaultSchemaFor(*this)); - } - return *schema_; - } - - std::string pretty_print_schema() const { - AT_ASSERT(schema_); - std::stringstream ss; - ss << *schema_; - return ss.str(); - } - - GraphExecutorState getDebugState() { - return get_executor().getDebugState(); - } - - void debugDisableAutodiffSubgraphInlining() { - return get_executor().debugDisableAutodiffSubgraphInlining(); - } - - bool is_optimized() const { - return optimize_; - } - - void check_single_output() { - AT_CHECK( - graph()->outputs().size() == 1, - "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs"); - } - - GraphExecutor& get_executor() { - std::call_once(executor_init_, [&] { - check_single_output(); - executor_ = GraphExecutor(graph(), optimize_); - }); - return executor_; - } - - // returns nullptr and fills in failure_messages if the callee does not - // match the functions schema - - // TODO: defined in module.cpp, move to compilation_unit.cpp - Value* try_emit_call( - Graph& graph, - const SourceRange& loc, - c10::optional self, - ArrayRef args, - ArrayRef kwargs, - std::stringstream& failure_messages, - bool conv_tensors_to_nums); - - Value* emit_call( - Graph& graph, - const SourceRange& loc, - ArrayRef args, - ArrayRef kwargs); - - private: - static FunctionSchema defaultSchemaFor(const Function& function) { - std::vector args; - std::vector returns; - Graph& g = *function.graph(); - size_t num_inputs = function.num_inputs(); - for (size_t i = 0; i < num_inputs; ++i) { - const Value* v = g.inputs().at(i); - std::string name = v->hasUniqueName() ? v->uniqueNameBase() - : ("argument_" + std::to_string(i)); - args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); - } - for (size_t i = 0; i < g.outputs().size(); ++i) { - returns.emplace_back("", unshapedType(g.outputs()[i]->type())); - } - return {function.name(), "", std::move(args), std::move(returns)}; - } - - std::string name_; - std::shared_ptr graph_; // for debugging and for inlining - bool optimize_; - - GraphExecutor executor_; // for execution - - std::once_flag executor_init_; - - // an optional function that actually creates the method when - // emit_call_to(this,...) is first called. 
this is used by the compiler so - // that it can construct methods out of order - std::function function_creator_; - - // if absent, then we generate a default schema based on the graph - // mutable because getSchema caches the default schema if one is requested - // before a call to setSchema - mutable std::unique_ptr schema_; -}; - - -// A CompilationUnit is a list of named script::Functions -// with helper methods to iterate the list, or invoke the function. -// Classes have a CompilationUnit holding the class methods -// and Modules also have a CompilationUnit holding the Functions that -// are used to implement their Methods - -struct TORCH_API CompilationUnit { - std::shared_ptr find_function(const std::string& name) const { - auto it = dict_.find(name); - if (it == dict_.end()) - return nullptr; - return functions_[it->second]; - } - - Function& get_function(const std::string& name) const { - if (auto r = find_function(name)) - return *r; - AT_ERROR("attempted to get undefined function ", name); - } - - void set_optimized(bool o) { - optimized_ = o; - } - - bool is_optimized() const { - return optimized_; - } - - // for historic reasons, these are defined in compiler.cpp - void define( - const std::vector& definitions, - const std::vector& resolvers, /* determines how we handle free - variables in each definition*/ - // if non-null, the first argument to each def, is bound to this value - const Self& self); - - // same as above but parse the definitions from source - void define( - const std::string& source, - const Resolver& resolver, - const Self& self); - - void clone_function(const Function& remote) { - create_function(remote.name(), remote.graph()->copy()); - } - - Function& create_function(std::string name, std::shared_ptr graph) { - auto fn = std::make_shared( - std::move(name), is_optimized(), std::move(graph), nullptr); - return register_function(std::move(fn)); - } - - const std::vector>& get_functions() const { - return functions_; - } - - /// Run a method from this compilation. - /// - /// For example: - /// @code - /// IValue output = module->run("relu_script", a, b); - /// @endcode - /// - /// To get a compile a module from a source string, see torch::jit::compile - /// - /// @param method_name The name of the method to run - /// @param args Arguments to be passed to the method - /// @return An IValue containing the return value (or values if it is a tuple) - /// from the method - template - IValue run_method(const std::string& method_name, Types&&... 
args) { - return get_function(method_name)({IValue(std::forward(args))...}); - } - - void drop_all_functions() { - dict_.clear(); - functions_.clear(); - } - - private: - Function& register_function(std::shared_ptr fn) { - AT_CHECK( - 0 == dict_.count(fn->name()), - "method '", - fn->name(), - "' already defined."); - functions_.emplace_back(std::move(fn)); - dict_[functions_.back()->name()] = functions_.size() - 1; - return *functions_.back(); - } - std::vector> functions_; - // for fast lookup - std::unordered_map dict_; - bool optimized_ = true; -}; - -} // namespace script -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 6048ed6..11358e2 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -25,7 +25,7 @@ namespace jit { namespace script { using SugaredValuePtr = std::shared_ptr; -using FunctionTable = std::unordered_map; +using FunctionTable = std::unordered_map; using ValueTable = std::unordered_map; using AttributeMap = std::unordered_map; using ListAttributeMap = std::unordered_map>; @@ -190,7 +190,7 @@ static bool meaningfulName(const std::string& name) { // delete unnecessary ones later with replaceAllusesWith(). struct Environment { Environment( - Function& method, + Method& method, Resolver resolver, Block* b, std::shared_ptr next = nullptr) @@ -199,7 +199,7 @@ struct Environment { b(b), next(std::move(next)) {} - Function& method; + Method& method; Resolver resolver; std::vector captured_inputs; std::unordered_map error_messages; @@ -518,8 +518,8 @@ struct to_ir { to_ir( const Def& def, Resolver resolver_, - const Self& self, - Function& method) // method being constructed + const c10::optional& self, + Method& method) // method being constructed : method(method), graph(method.graph()), resolver(std::move(resolver_)), @@ -541,7 +541,7 @@ struct to_ir { } private: - Function& method; + Method& method; std::shared_ptr graph; Resolver resolver; std::unordered_map integral_constants; @@ -577,7 +577,7 @@ struct to_ir { FunctionSchema emitDef( const Def& def, - const Self& self, + const c10::optional& self, Block* block) { auto schema = extractSchemaFromDef(def, self); // TODO need guards on init returning none @@ -624,16 +624,15 @@ struct to_ir { blank_decl, List::create(r, {ret})); auto m = std::make_shared(); - CompilationUnit cu; - cu.define({def}, {resolver}, nullptr); + defineMethodsInModule(m, {def}, {resolver}, c10::nullopt); Stack stack; - cu.get_function("defaults").run(stack); + m->get_method("defaults").run(stack); return stack.at(0).toTuple()->elements(); } std::vector parseArgsFromDecl( const Decl& decl, - const Self& self) { + const c10::optional& self) { auto params_begin = decl.params().begin(); auto params_end = decl.params().end(); if (self) { @@ -707,7 +706,7 @@ struct to_ir { } FunctionSchema extractSchemaFromDef( const Def& def, - const Self& self) { + const c10::optional& self) { const auto name = def.name().name(); std::vector args = parseArgsFromDecl(def.decl(), self); std::vector returns = parseReturnFromDecl(def.decl()); @@ -717,10 +716,9 @@ struct to_ir { std::vector emitFormalArguments( const Def& def, - const Self& self, + const c10::optional& self, const FunctionSchema& schema, Block* block) { - std::vector arguments; // for schema // inputs auto it = def.decl().params().begin(); @@ -740,9 +738,14 @@ struct to_ir { if (self) { AT_ASSERT(it != end); const auto& name = (*it).ident().name(); - Value* new_input = 
block->addInput()->setUniqueName(name); - environment_stack->setSugaredVar((*it).ident().range(), name, self(new_input)); - arguments.emplace_back(name, new_input->type()); + if (auto type = self->asFirstClass()) { + Value* new_input = + block->addInput()->setUniqueName(name)->setType(type); + environment_stack->setVar((*it).ident().range(), name, new_input); + arguments.emplace_back(name, type); + } else { + environment_stack->setSugaredVar(def.range(), name, self->asSugared()); + } ++it; } size_t arg_annotation_idx = 0; @@ -828,7 +831,7 @@ struct to_ir { pushFrame(block, /*starts_def=*/true); emitDef( def, - nullptr, + c10::nullopt, block); // ignore schema return, we just wont use it for now since we // never create a Method for the closure popFrame(/*ends_def=*/true); @@ -2260,6 +2263,7 @@ struct to_ir { node_output = fork_node->output()->setType( FutureType::create(fn_simple_output->type())); } + // Lambda lift block(0) into attr::Subgraph lambdaLiftFork(fork_node); @@ -2751,14 +2755,15 @@ struct to_ir { } }; -void CompilationUnit::define( +void defineMethodsInModule( + const std::shared_ptr& m, const std::vector& definitions, const std::vector& resolvers, - const Self& self) { + const c10::optional& self) { AT_ASSERT(definitions.size() == resolvers.size()); auto resolver_it = resolvers.begin(); - std::vector methods; - std::unordered_map function_table; + std::vector methods; + std::unordered_map function_table; for (const Def& def : definitions) { const std::string& name = def.name().name(); auto resolver = *resolver_it++; @@ -2769,34 +2774,37 @@ void CompilationUnit::define( // the function table so the methods can see each other resolver = [resolver, &function_table]( const std::string& name, - Function& m, + Method& m, const SourceRange& loc) -> std::shared_ptr { auto it = function_table.find(name); if (it != function_table.end()) { - return std::make_shared(c10::nullopt, *it->second); + return std::make_shared(nullptr, *it->second); } return resolver(name, m, loc); }; } - auto creator = [def, resolver, self](Function& method) { + auto creator = [def, resolver, self](Method& method) { AT_ASSERT(resolver); to_ir(def, resolver, self, method); }; - std::unique_ptr fn( - new Function(name, is_optimized(), std::make_shared(), creator)); - function_table[name] = fn.get(); - methods.push_back(fn.get()); - register_function(std::move(fn)); + Method& method = m->create_method(name, creator); + function_table[name] = &method; + methods.push_back(&method); } - for (Function* method : methods) { + for (Method* method : methods) { method->ensure_defined(); } + if (!self || !self->asFirstClass()) { + // Disable module hooks if the module is only used to store a class's code. 
+ didFinishEmitModule(m); + } } -void CompilationUnit::define( +void defineMethodsInModule( + const std::shared_ptr& m, const std::string& source, const Resolver& resolver, - const Self& self) { + const c10::optional& self) { Parser p(source); std::vector definitions; std::vector resolvers; @@ -2805,7 +2813,7 @@ void CompilationUnit::define( definitions.push_back(def); resolvers.push_back(resolver); } - define(definitions, resolvers, self); + defineMethodsInModule(m, definitions, resolvers, self); } void lambdaLiftFork(Node* fork_node) { @@ -2830,7 +2838,6 @@ void lambdaLiftFork(Node* fork_node) { fork_node->g_(attr::Subgraph, forked_graph); fork_node->eraseBlock(0); } - } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index 3965467..3c2bb2d 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -13,9 +13,12 @@ namespace torch { namespace jit { namespace script { +using Resolver = std::function(const std::string& name, Method& m, const SourceRange& loc)>; + inline std::shared_ptr nativeResolver( const std::string& name, - Function& m, + Method& m, const SourceRange& loc) { if (name == "torch") { return std::make_shared("aten"); @@ -23,6 +26,47 @@ inline std::shared_ptr nativeResolver( return nullptr; } +// Represents the `self` argument to a method. This wrapper class is necessary +// because sometimes `self` sometimes is first class and sometimes not. +// +// `self` is first class when it refers to a ClassType. It will be bound as a +// graph input argument. +// `self` is sugared when it refers to a ModuleValue. +class Self { + public: + explicit Self(std::shared_ptr sugared) + : sugared_(std::move(sugared)) {} + explicit Self(ClassTypePtr type) : firstClass_(std::move(type)) {} + + ClassTypePtr asFirstClass() const { + return firstClass_; + } + std::shared_ptr asSugared() const { + return sugared_; + } + + private: + // Used when `self` is not first-class and so we don't represent it in the + // graph. This is only ModuleValue. + std::shared_ptr sugared_ = nullptr; + // Used when `self` is a first-class type + ClassTypePtr firstClass_ = nullptr; +}; + +TORCH_API void defineMethodsInModule( + const std::shared_ptr& m, + const std::vector& definitions, + const std::vector& resolvers, /* determines how we handle free + variables in each definition*/ + // if non-null, the first argument to each def, is bound to this value + const c10::optional& self); + +// same as above but parse the definitions from source +TORCH_API void defineMethodsInModule( + const std::shared_ptr& m, + const std::string& source, + const Resolver& resolver, + const c10::optional& self); TORCH_API void lambdaLiftFork(Node* fork_node); diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 1aeadea..bff6d85 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -64,9 +64,10 @@ inline std::shared_ptr toSimple(Value* v) { // type, *add it in this function's implementation*. std::shared_ptr toSugaredValue( py::object obj, - Function& m, + Method& m, SourceRange loc, - bool is_constant = false); + bool is_constant = false, + bool is_submodule = false); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue(py::object self) : self(std::move(self)) {} @@ -124,7 +125,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { // call it like a function, e.g. 
`outputs = this(inputs)` std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs_, at::ArrayRef attributes, size_t n_binders) override { @@ -181,7 +182,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::vector> asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint = {}) override { const std::string type_str = typeString(self); std::stringstream ss; @@ -192,7 +193,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { const std::string type_str = typeString(self); std::stringstream ss; @@ -218,7 +219,7 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { py::object member = getattr(loc, field); // note: is_constant = true because we consider that global properties @@ -233,7 +234,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { : PythonValue(std::move(tup)) {} std::vector> asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint = {}) override { py::tuple tup = self; std::vector> result; @@ -245,7 +246,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { return result; } - Value* asValue(const SourceRange& loc, Function& m) override { + Value* asValue(const SourceRange& loc, Method& m) override { std::vector values; for (const auto& sugared_item : asTuple(loc, m)) { values.push_back(sugared_item->asValue(loc, m)); @@ -257,65 +258,33 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { - ConstantParameterList(Value* the_list) : the_list_(the_list) {} - std::string kind() const override { - return "constant parameter list"; - } - std::shared_ptr call( - const SourceRange& loc, - Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, - size_t n_binders) override { - return toSimple(the_list_); - } - - private: - Value* the_list_; -}; - -struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { - OverloadedFunctionValue(Value* module, std::vector method_names) - : module_(module), method_names_(std::move(method_names)) {} + ConstantParameterList(std::shared_ptr module) + : module_(std::move(module)) {} std::string kind() const override { - return "overloaded function"; + return "constant parameter list"; } std::shared_ptr call( const SourceRange& loc, - Function& caller, + Method& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { - std::stringstream err; - std::vector new_inputs = inputs.vec(); - new_inputs.insert(new_inputs.begin(), module_); - - for (const std::string& method_name : method_names_) { - auto cls = module_->type()->expect(); - Function* fn = cls->getMethod(method_name); - auto match = tryMatchSchema( - fn->getSchema(), - loc, - *caller.graph().get(), - c10::nullopt, - new_inputs, - attributes, - err, - true); - if (match) { - return MethodValue(module_, *fn) - .call(loc, caller, inputs, attributes, n_binders); - } + // Add all module parameters as inputs to the graph + std::vector params; + const auto& param_list = module_->get_parameters(); + for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { + auto& param = *it; + 
params.push_back(caller.get_or_add_parameter(param)); } - throw ErrorReport(loc) << "Could not find any matching overloads\n" - << err.str(); + auto list = caller.graph()->createList(TensorType::get(), params); + caller.graph()->insertNode(list); + return toSimple(list->output()); } private: - Value* module_; - std::vector method_names_; + std::shared_ptr module_; }; // defines how modules/methods behave inside the script subset. @@ -326,8 +295,7 @@ struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { // holding the actual nn.Module class. struct ModuleValue : public SugaredValue { - ModuleValue(Value* self, std::shared_ptr module) - : self_(self), module_(std::move(module)) {} + ModuleValue(std::shared_ptr module) : module(std::move(module)) {} std::string kind() const override { return "module"; @@ -336,60 +304,45 @@ struct ModuleValue : public SugaredValue { // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { // workaround to make self.training work // it adds a buffer 'training' to the model if one doesn't exist // and then loads that parameter, casting it to bool if (field == "training") { - Slot* v = module_->find_buffer(field); + Slot* v = module->find_buffer(field); if (!v) { - py::object py_module = py::cast(module_); + py::object py_module = py::cast(module); bool training = py::cast(py::getattr(py_module, "training")); auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong)); - module_->register_buffer("training", std::move(t)); - v = module_->find_buffer(field); + module->register_buffer("training", std::move(t)); + v = module->find_buffer(field); } - Value* the_tensor = m.graph()->insertGetAttr(self_, "training"); + Value* the_tensor = m.get_or_add_parameter(*v); Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor}); return std::make_shared(the_bool); } - if (std::shared_ptr v = module_->find_module(field)) { - return std::make_shared( - m.graph()->insertGetAttr(self_, field), v); - } else if (auto kind = module_->kind_of(field)) { - // methods, parameters, attributes, and buffers are all first class - return SimpleValue(self_).attr(loc, m, field); + if (std::shared_ptr v = module->find_module(field)) { + return std::make_shared(v); + } else if (Method* v = module->find_method(field)) { + return std::make_shared(shared_from_this(), *v); + } else if (Slot* v = module->find_parameter(field)) { + return std::make_shared(m.get_or_add_parameter(*v)); + } else if (Slot* v = module->find_attribute(field)) { + return std::make_shared( + m.get_or_add_attribute(*v)); } // This can also be a call to a non-script module, or a plain // python method. If so return this as a python value. 
- py::object py_module = py::cast(module_); - - py::object overloads = - py_module.attr("_overloads").attr("get")(field, py::none()); - if (!overloads.is_none()) { - return std::make_shared( - self_, py::cast>(overloads)); - } - + py::object py_module = py::cast(module); if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) { if (py::isinstance(attr) && py::hasattr(attr, "_is_parameter_list") && py::cast(py::getattr(attr, "_is_parameter_list"))) { - Graph& g = *m.graph(); - // Add all module parameters as inputs to the graph - std::vector params; - const auto& param_list = module_->get_parameters(); - for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { - auto& param = *it; - params.emplace_back(g.insertGetAttr(self_, param.name())); - } - auto list = - g.insertNode(g.createTuple(params))->output(); - return std::make_shared(list); + return std::make_shared(module); } if (py::isinstance(attr) || py::isinstance(attr, py::module::import("torch.nn").attr("Module")) || @@ -411,7 +364,7 @@ struct ModuleValue : public SugaredValue { // call module.forward std::shared_ptr call( const SourceRange& loc, - Function& caller, + Method& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -421,35 +374,28 @@ struct ModuleValue : public SugaredValue { std::vector> asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint = {}) override { - py::object py_module = py::cast(module_); + py::object py_module = py::cast(module); if (!py::isinstance( py_module, py::module::import("torch.jit").attr("_ConstModuleList"))) return SugaredValue::asTuple(loc, m, size_hint); std::vector> result; - for (py::handle py_submodule : py_module) { - py::object obj = py::reinterpret_borrow(py_submodule); - if (py::isinstance(obj)) { - auto sub_module = py::cast>(obj); - Value* module_v = m.graph()->insertGetAttr(self_, sub_module->name()); - result.emplace_back( - std::make_shared(module_v, sub_module)); - } else { - result.push_back(toSugaredValue( - obj, - m, - loc, - /*is_constant =*/false)); - } + for (py::handle module : py_module) { + py::object obj = py::reinterpret_borrow(module); + result.push_back(toSugaredValue( + obj, + m, + loc, + /*is_constant =*/false, + /*is_submodule =*/true)); } return result; } private: - Value* self_; - std::shared_ptr module_; + std::shared_ptr module; }; struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { @@ -462,7 +408,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, - Function& caller, + Method& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -500,31 +446,54 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { py::dict dispatched_fn_; }; -std::shared_ptr moduleToMethod( - const std::shared_ptr& mod) { - // this path only supports calling raw script functions - // but because they are not distinguished from models, we have to check - // that they are function-like here. They must not have state, and they - // must have a forward method. When we expose functions to python - // this will be replaced with a direct py::isinstance call. - - if (mod->get_parameters().size() != 0) { - throw ErrorReport() - << "Attempted to inline a Module with parameters. 
" - "Stateful modules to be inlined must be submodules of the callee."; +struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue { + OverloadedFunctionValue(py::list functions) + : possible_functions_(std::move(functions)) {} + + std::string kind() const override { + return "overloaded function"; } - Method* forward = mod->find_method("forward"); - if (!forward) { - throw ErrorReport() << " expected this module to have a forward function."; + + std::shared_ptr call( + const SourceRange& loc, + Method& caller, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + std::stringstream err; + auto possible_functions = + py::cast>(possible_functions_); + + for (const py::object& fn : possible_functions) { + auto& method = py::cast(fn); + auto match = tryMatchSchema( + method.getSchema(), + loc, + *caller.graph().get(), + c10::nullopt, + inputs, + attributes, + err, + true); + if (match) { + return MethodValue(nullptr, method) + .call(loc, caller, inputs, attributes, n_binders); + } + } + throw ErrorReport(loc) << "Could not find any matching overloads\n" + << err.str(); } - return std::make_shared(at::nullopt, forward->function()); -} + + private: + py::list possible_functions_; +}; std::shared_ptr toSugaredValue( py::object obj, - Function& m, + Method& m, SourceRange loc, - bool is_constant) { + bool is_constant, + bool is_submodule) { // directly create SimpleValues when possible, because they are first-class // and can be re-assigned. Otherwise, this would be invalid: // f = python_constant @@ -565,12 +534,17 @@ std::shared_ptr toSugaredValue( obj = weak_obj; } if (py::isinstance(obj)) { + auto mod = py::cast>(obj); // In the case that this Python object is not a submodule, inline *ONLY // PURE* ScriptModules. This allows us to call arbitrary @script functions // within a scripting context while still enforcing that parameters from // stateful submodules are properly accounted for. - auto mod = py::cast>(obj); - return moduleToMethod(mod); + if (!is_submodule && mod->get_parameters().size() != 0) { + throw ErrorReport() + << "Attempted to inline a Module with parameters. 
" + "Stateful modules to be inlined must be submodules of the callee."; + } + return std::make_shared(mod); } else if (py::isinstance(obj)) { return std::make_shared(obj); } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) { @@ -592,7 +566,7 @@ std::shared_ptr toSugaredValue( py::module::import("torch.jit").attr("_try_compile_weak_script")(obj); if (!compiled_fn.is(py::none())) { auto mod = py::cast>(compiled_fn); - return moduleToMethod(mod); + return std::make_shared(mod); } } @@ -602,6 +576,12 @@ std::shared_ptr toSugaredValue( return std::make_shared(std::move(dispatched_fn)); } + py::object overloads = + py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj); + if (!overloads.is_none()) { + return std::make_shared(std::move(overloads)); + } + return std::make_shared(obj); } @@ -641,7 +621,7 @@ static void gatherParametersAndBuffers( namespace { Resolver pythonResolver(const ResolutionCallback& rcb) { - return [rcb](const std::string& name, Function& m, const SourceRange& loc) + return [rcb](const std::string& name, Method& m, const SourceRange& loc) -> std::shared_ptr { AutoGIL ag; py::object obj = rcb(name); @@ -693,13 +673,6 @@ FunctionSchema getSchemaWithNameAndDefaults( schema.is_varret()); } -static Self moduleSelf(const std::shared_ptr& m) { - return [m](Value* v) { - v->setType(m->module_object()->type()); - return std::make_shared(v, m); - }; -} - void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -739,12 +712,9 @@ void initJitScriptBindings(PyObject* module) { bool has_self) { c10::optional self; if (has_self) { - m->class_compilation_unit().define( - script, pythonResolver(rcb), moduleSelf(m)); - } else { - m->_define_lowered(script, pythonResolver(rcb)); + self = Self(std::make_shared(m)); } - didFinishEmitModule(m); + defineMethodsInModule(m, script, pythonResolver(rcb), self); }) .def( "_create_methods", @@ -757,13 +727,14 @@ void initJitScriptBindings(PyObject* module) { for (auto& callback : rcbs) { resolvers.push_back(pythonResolver(callback)); } - m->class_compilation_unit().define(defs, resolvers, moduleSelf(m)); + defineMethodsInModule( + m, defs, resolvers, Self(std::make_shared(m))); + // Stitch in default arguments for each Def if provided auto defaults_it = defaults.begin(); auto defs_it = defs.begin(); while (defs_it != defs.end()) { - auto& method = m->class_compilation_unit().get_function( - (*defs_it).name().name()); + auto& method = m->get_method((*defs_it).name().name()); method.setSchema(getSchemaWithNameAndDefaults( defs_it->range(), method.getSchema(), @@ -813,7 +784,8 @@ void initJitScriptBindings(PyObject* module) { auto& p = parameters[i]; py::tuple r(2); result[i] = std::make_tuple( - p.name(), autograd::as_variable_ref(p.value().toTensor())); + p.name(), + autograd::as_variable_ref(p.value().toTensor())); } return result; }) @@ -869,7 +841,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& self, const std::string& name, std::shared_ptr graph) { - self._define_lowered(name, std::move(graph), {}); + self.create_method(name, std::move(graph), {}); }) .def( "_create_method_from_trace", @@ -893,8 +865,7 @@ void initJitScriptBindings(PyObject* module) { var_lookup_fn, force_outplace, input_tuple.size()); - self->_define_lowered( - name, std::move(graph), std::move(parameters)); + self->create_method(name, std::move(graph), std::move(parameters)); didFinishEmitModule(self); }) .def( @@ -919,7 +890,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& self) { if 
(self.find_method("forward")) { Method& m = self.get_method("forward"); - return m.get_executor().getDebugState(); + return m.getDebugState(); } throw std::runtime_error( "Attempted to call get_debug_state on a Module without a compiled forward()"); @@ -929,7 +900,7 @@ void initJitScriptBindings(PyObject* module) { [](Module& self) { if (self.find_method("forward")) { Method& m = self.get_method("forward"); - m.get_executor().debugDisableAutodiffSubgraphInlining(); + m.debugDisableAutodiffSubgraphInlining(); } }) .def( @@ -987,8 +958,7 @@ void initJitScriptBindings(PyObject* module) { } Method* orig_method = orig->find_method(name); - m->_define_lowered( - name, orig_method->graph()->copy(), std::move(member_inputs)); + m->create_method(name, orig_method->graph()->copy(), member_inputs); }); py::class_(m, "ScriptMethod", py::dynamic_attr()) @@ -1002,27 +972,10 @@ void initJitScriptBindings(PyObject* module) { method, tuple_slice(std::move(args), 1), std::move(kwargs)); }) .def_property_readonly("graph", [](Method& m) { return m.graph(); }) - .def( - "propagate_shapes", - [](Method& m, const std::vector& inputs, bool with_grad) { - return propagate_shapes( - *m.graph(), inputs, m.initial_ivalues(), with_grad); - }) + .def("propagate_shapes", &Method::propagate_shapes) .def( "propagate_and_assign_input_and_output_shapes", - [](Method& m, - const std::vector& inputs, - std::vector outputs, - bool with_grad, - bool propagate) { - return propagate_and_assign_input_and_output_shapes( - *m.graph(), - inputs, - m.initial_ivalues(), - outputs, - with_grad, - propagate); - }) + &Method::propagate_and_assign_input_and_output_shapes) .def( "initial_ivalues", [](Method& m) { @@ -1042,18 +995,9 @@ void initJitScriptBindings(PyObject* module) { }) .def( "debug_disable_autodiff_subgraph_inlining", - [](Method& m) { - return m.get_executor().debugDisableAutodiffSubgraphInlining(); - }) + &Method::debugDisableAutodiffSubgraphInlining) .def("schema", &Method::getSchema) - .def( - "pretty_print_schema", - [](Method& m) { - const FunctionSchema& schema = m.getSchema(); - std::stringstream ss; - ss << schema; - return ss.str(); - }) + .def("pretty_print_schema", &Method::pretty_print_schema) .def( "python_print", [](Method& m) { @@ -1078,30 +1022,29 @@ void initJitScriptBindings(PyObject* module) { ResolutionCallback rcb, FunctionDefaults defaults) { auto def_f = def.withName("forward"); - - mod->_define_lowered({def_f}, {pythonResolver(rcb)}); - auto& func = mod->lowered_methods().get_function("forward"); - func.setSchema(getSchemaWithNameAndDefaults( - def.range(), func.getSchema(), def.name().name(), defaults)); - auto& func2 = mod->class_compilation_unit().get_function("forward"); - func2.setSchema(getSchemaWithNameAndDefaults( - def.range(), func2.getSchema(), def.name().name(), defaults)); + defineMethodsInModule( + mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt); + auto& method = mod->get_method("forward"); + method.setSchema(getSchemaWithNameAndDefaults( + def.range(), method.getSchema(), def.name().name(), defaults)); didFinishEmitModule(mod); return mod; }); m.def( "_jit_script_class_compile", - [](const ClassDef& classDef, ResolutionCallback rcb) { - auto cu = std::make_shared(); - auto classType = ClassType::create(classDef.name().name(), cu); + [](std::shared_ptr module, + const ClassDef& classDef, + ResolutionCallback rcb) { + auto classType = ClassType::create(classDef.name().name(), module); std::vector rcbs; std::vector methodDefs; for (const auto& def : classDef.defs()) { 
methodDefs.push_back(def); rcbs.push_back(pythonResolver(rcb)); } - cu->define(methodDefs, rcbs, simpleSelf(classType)); + defineMethodsInModule(module, methodDefs, rcbs, Self(classType)); + return module; }); m.def("parse_type_comment", [](const std::string& comment) { diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index 7772ab1..33cb951 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -12,39 +11,32 @@ namespace jit { namespace script { struct RecursiveMethodCallError : public std::exception {}; -void placeholderCreator(Function&) { +void placeholderCreator(Method&) { throw RecursiveMethodCallError(); } -void Function::ensure_defined() { - try { - if (function_creator_) { - auto creator = function_creator_; - function_creator_ = placeholderCreator; - creator(*this); - function_creator_ = nullptr; - } - } catch (RecursiveMethodCallError&) { - throw ErrorReport() // TODO: once lower_first_class methods is removed - // re-establish callsite info for debugging - << " method '" << name() << "' is called recursively. " - << "Recursive calls are not supported"; - } -} - -Value* Function::try_emit_call( +Value* try_emit_call_to( Graph& graph, const SourceRange& loc, + Method& callee, c10::optional self, ArrayRef args, ArrayRef kwargs, std::stringstream& failure_messages, + Method* caller, bool conv_tensors_to_nums) { - ensure_defined(); - auto fn = this->graph(); + try { + callee.ensure_defined(); + } catch (RecursiveMethodCallError&) { + throw ErrorReport(loc) + << " method '" << callee.name() + << "' is called recursively involving this call site. " + << "Recursive calls are not supported"; + } + auto fn = callee.graph(); auto matched_schema = tryMatchSchema( - getSchema(), + callee.getSchema(), loc, graph, std::move(self), @@ -55,29 +47,52 @@ Value* Function::try_emit_call( if (!matched_schema) return nullptr; - check_single_output(); - return inlineCallTo(graph, *fn, matched_schema->inputs).at(0); + // parameters to callee method (which become parameters to _this_ method + // if they were not already) + for (const auto& member : callee.initial_ivalues()) { + if (!caller) { + throw ErrorReport(loc) + << " attempting to call a method with parameters/attributes" + " from a raw graph. 
File a bug report"; + } + matched_schema->inputs.push_back( + caller->get_or_add_attribute(member)); + } + callee.check_single_output(); + return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0); } -Value* Function::emit_call( - Graph& graph, +Value* Method::emit_call_to( const SourceRange& loc, + Method& callee, ArrayRef args, ArrayRef kwargs) { + AT_ASSERT(!executor); std::stringstream failure_messages; - if (auto result = try_emit_call( - graph, + if (auto result = try_emit_call_to( + *graph(), loc, + callee, c10::nullopt, args, kwargs, failure_messages, + this, /*conv_tensors_to_nums=*/true)) { return result; } throw ErrorReport(loc) << failure_messages.str(); } +void Method::ensure_defined() { + if (method_creator) { + auto creator = method_creator; + method_creator = placeholderCreator; + creator(*this); + method_creator = nullptr; + } +} + void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) { to_impl(device, dtype, non_blocking); } @@ -122,229 +137,6 @@ void Module::to_impl( } } -// lower_first_class_method and lift_lowered_method are transitionary functions -// used to translate between module-as-first-class code generation, -// and module-as-special execution. Once module-as-first-class execution is -// debugged, then we can remove both and remove the lowered_functions_ table. - -// remove the first module argument, replacing any access of its -// parameters/attributes with extra_ivalue input Slots that hold what value to -// pass into the graph -std::pair, std::vector> lower_graph( - const ModulePtr& self, - Graph& g_, - size_t self_offset = 0) { - std::shared_ptr g = g_.copy(); - std::vector extra_ivalues; - std::unordered_map slot_to_offset; - struct ToScan { - ModulePtr mod; - Node* n; - size_t offset; - }; - std::vector to_scan; - std::vector to_clean; // nodes that should be dead at the end - - auto getOrAddSlot = [&](const Slot& slot) -> Value* { - auto it = slot_to_offset.find(slot); - if (it != slot_to_offset.end()) { - size_t ivalues_start = g->inputs().size() - extra_ivalues.size(); - return g->inputs().at(ivalues_start + it->second); - } - extra_ivalues.emplace_back(slot); - slot_to_offset[slot] = extra_ivalues.size() - 1; - return g->addInput()->setType(slot.type()); - }; - - auto self_value = g->inputs().at(self_offset); - - for (Use use : self_value->uses()) { - to_scan.emplace_back(ToScan{self, use.user, use.offset}); - } - while (to_scan.size() > 0) { - auto e = to_scan.back(); - to_scan.pop_back(); - - // when we lambda lift forks, first-class modules may be passed across - // forks. This code recursively lowers the module in the fork call. 
- if (e.n->kind() == prim::fork) { - auto subgraph = e.n->g(attr::Subgraph); - std::vector new_slots; - std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset); - e.n->g_(attr::Subgraph, subgraph); - for (const Slot& slot : new_slots) { - e.n->addInput(getOrAddSlot(slot)); - } - e.n->removeInput(e.offset); - continue; - } - if (e.n->kind() != prim::GetAttr) { - throw ErrorReport(e.n->getSourceLocation()) - << "temporary: the only valid use of a module is looking up an attribute"; - } - Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name))); - if (ClassTypePtr c = e.n->output()->type()->cast()) { - if (c->name() == "Module") { - auto obj = slot.value().toObject(); - for (Use use : e.n->output()->uses()) { - to_scan.emplace_back(ToScan{obj, use.user, use.offset}); - } - to_clean.emplace_back(e.n); - continue; - } - } - e.n->output()->replaceAllUsesWith(getOrAddSlot(slot)); - e.n->destroy(); - } - - while (to_clean.size() > 0) { - Node* n = to_clean.back(); - AT_ASSERT(!n->hasUses()); - n->destroy(); - to_clean.pop_back(); - } - AT_ASSERT(!self_value->hasUses()); - g->eraseInput(self_offset); - - return std::make_pair(std::move(g), std::move(extra_ivalues)); -} - -Method& Module::lower_first_class_method(Function* fn) { - fn->ensure_defined(); - auto lowered = lower_graph(module_object(), *fn->graph()); - Function& new_func = - lowered_methods_.create_function(fn->name(), lowered.first); - - // generate the new schema - // slice away the self argument - std::vector args( - fn->getSchema().arguments().begin() + 1, - fn->getSchema().arguments().end()); - size_t id = 0; - for (const Slot& slot : lowered.second) { - std::ostringstream ss; - ss << "slot" << id++; - args.emplace_back(ss.str(), slot.type()); - } - new_func.setSchema(fn->getSchema().cloneWithArguments(std::move(args))); - return _create_lowered_method(&new_func, std::move(lowered.second)); -} - -static void createFirstClassValues( - Module* module, - Value* self, - std::unordered_map& result) { - auto& g = *self->owningGraph(); - - std::vector created; - struct ToScan { - Module* mod; - Value* v; // value representing module in the graph - }; - std::vector to_scan = {{module, self}}; - - while (!to_scan.empty()) { - auto s = to_scan.back(); - to_scan.pop_back(); - size_t offset = 0; - for (const std::string& name : - s.mod->module_object()->type()->attributeNames()) { - Value* v = g.insertGetAttr(s.v, name); - result[Slot(s.mod->module_object(), offset++)] = v; - if (std::shared_ptr sub = s.mod->find_module(name)) { - to_scan.emplace_back(ToScan{sub.get(), v}); - } - } - } -} - -void Module::lift_lowered_method(Method& m) { - auto graph = m.graph()->copy(); - Value* self = graph->insertInput(0, "self")->setType(module_object()->type()); - std::unordered_map slot_to_value; - if (!m.initial_ivalues().empty()) { - WithInsertPoint guard(*graph->nodes().begin()); - createFirstClassValues(this, self, slot_to_value); - } - - size_t orig_graph_inputs_size = graph->inputs().size(); - for (size_t i = 0; i < m.initial_ivalues().size(); ++i) { - size_t input_offset = orig_graph_inputs_size - i - 1; - size_t ivalue_offset = m.initial_ivalues().size() - i - 1; - graph->inputs() - .at(input_offset) - ->replaceAllUsesWith( - slot_to_value.at(m.initial_ivalues().at(ivalue_offset))); - graph->eraseInput(input_offset); - } - - if (!m.initial_ivalues().empty()) { - // we added _all_ the submodules as first-class values but maybe did not use - // them. 
So remove any dead attribute lookups - EliminateDeadCode(graph); - } - - Function& new_fn = class_cu().create_function(m.name(), std::move(graph)); - // created lifted schema - // self argument is named '$self' to prevent accidental name collisions - // with another input that the user named 'self' - std::vector new_args = {Argument("$self", module_object()->type())}; - const auto& lowered_args = m.function().getSchema().arguments(); - new_args.insert( - new_args.end(), - lowered_args.begin(), - lowered_args.begin() + m.num_inputs()); - new_fn.setSchema(m.function().getSchema().cloneWithArguments(std::move(new_args))); -} - -Method& Module::_create_lowered_method( - Function* func, - std::vector member_inputs) { - std::unique_ptr m(new Method(this, func, std::move(member_inputs))); - return *insert(func->name(), methods_, EntityType::METHOD, std::move(m)); -} - -void Module::lift_lowered_methods(size_t start) { - for (size_t i = start; i < lowered_methods_.get_functions().size(); ++i) { - Method& m = _create_lowered_method( - lowered_methods_.get_functions().at(i).get(), {}); - lift_lowered_method(m); - } -} - -void Module::_define_lowered( - const std::vector& definitions, - const std::vector& resolvers) { - size_t start = lowered_methods_.get_functions().size(); - lowered_methods_.define(definitions, resolvers, nullptr); - lift_lowered_methods(start); - // call lift_lowered_method for each definition -} - -void Module::_define_lowered(const std::string& src, const Resolver& resolver) { - size_t start = lowered_methods_.get_functions().size(); - lowered_methods_.define(src, resolver, nullptr); - lift_lowered_methods(start); -} - -Method& Module::_define_lowered( - std::string name, - std::shared_ptr graph, - std::vector slots) { - Method& m = _create_lowered_method( - &lowered_methods_.create_function(std::move(name), std::move(graph)), - std::move(slots)); - lift_lowered_method(m); - return m; -} - -void Module::define(const std::string& src, const Resolver& resolver) { - class_cu().define( - src, - resolver ? resolver : nativeResolver, - simpleSelf(module_object()->type())); -} - } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 000ea9c..13233d9 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -12,7 +12,6 @@ #include #include -#include #include #include @@ -40,7 +39,6 @@ using ::c10::FunctionSchema; // Map which stores filename to content. using ExtraFilesMap = std::unordered_map; -using ModulePtr = c10::intrusive_ptr; // A method in a module, e.g. f in: // // class M(ScriptModule): @@ -55,110 +53,320 @@ struct Module; using ModuleLookup = std::function(const std::vector&)>; -struct TORCH_API Method { - Method(Module* owner, Function* function, std::vector initial_members) +struct Method { + Method( + Module* owner, + std::string name, + bool optimize, + std::shared_ptr graph, + std::vector initial_members, + std::function method_creator) : owner_(owner), - function_(function), - initial_ivalues_(std::move(initial_members)) { - AT_ASSERT(function->num_inputs() >= initial_ivalues_.size()); - } - - // the module that contains this method. 
- Module& owner() const { - return *owner_; + name_(std::move(name)), + graph_(std::move(graph)), + optimize(optimize), + initial_ivalues_(std::move(initial_members)), + method_creator(std::move(method_creator)) { + AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size()); + int i = graph_->inputs().size() - initial_ivalues_.size(); + for (auto member : initial_ivalues_) { + initial_ivalue_index[member] = i++; + } } void run(Stack& stack) { for (auto input : initial_ivalues_) { push(stack, input.value()); } - function_->run(stack); + get_executor().run(stack); } + void run(Stack&& stack) { run(stack); } IValue operator()(std::vector stack) { - getSchema().checkAndNormalizeInputs(stack); - for (auto input : initial_ivalues_) { - push(stack, input.value()); - } - // use run rather than operator() to skip the second schema check. - function_->run(std::move(stack)); + checkInputsAgainstSchema(stack); + run(stack); return stack.front(); } + std::shared_ptr graph_for(Stack inputs) { + for (auto tp : initial_ivalues_) { + inputs.emplace_back(tp.value()); + } + return get_executor().graphFor(inputs); + } + TORCH_API std::shared_ptr graph() const { + return graph_; + } + + TORCH_API const std::string& name() const { + return name_; + } + // emit a function call by inlining the callees Graph into this one + // adding any extra parameters necessary to do this call + + // defined here to keep details of member_input handling confined to this + // class + Value* emit_call_to( + const SourceRange& loc, + Method& callee, + ArrayRef args, + ArrayRef kwargs); + + // if this isn't yet defined, run its method_creator function + TORCH_API void ensure_defined(); + + size_t num_inputs() const { + return graph()->inputs().size() - initial_ivalues_.size(); + } + TORCH_API Value* get_or_add_parameter(Slot slot) { + AT_ASSERT(slot.value().isTensor()); + return get_or_add_attribute(slot); + } + TORCH_API Value* get_or_add_attribute(Slot slot) { + auto it = initial_ivalue_index.find(slot); + if (it != initial_ivalue_index.end()) { + return graph()->inputs().at(it->second); + } + initial_ivalues_.push_back(slot); + initial_ivalue_index[slot] = graph()->inputs().size(); + return graph()->addInput()->setType(slot.type()); + } + + static void setInputTensorTypes(Graph& g, const Stack& stack) { + AT_ASSERT(stack.size() == g.inputs().size()); + for (size_t i = 0; i < stack.size(); ++i) { + g.inputs().at(i)->setType( + DimensionedTensorType::create(stack.at(i).toTensor())); + } + } + + std::shared_ptr propagate_shapes( + std::vector inputs, + bool with_grad = false) { + auto retval = graph_->copy(); + Stack stack; + stack.reserve(inputs.size() + initial_ivalues_.size()); + for (at::Tensor& i : inputs) { + stack.emplace_back(std::move(i)); + } + for (const Slot& inp : initial_ivalues_) { + stack.push_back(inp.value()); + } + setInputTensorTypes(*retval, stack); + PropagateInputShapes(retval); + return retval; + } + + std::shared_ptr propagate_and_assign_input_and_output_shapes( + std::vector inputs, + std::vector outputs, + bool with_grad = false, + bool propagate = true) { + auto retval = graph_->copy(); + for (auto inp : initial_ivalues_) { + if (inp.value().isTensor()) { + inputs.push_back(inp.value().toTensor()); + } + } + if (propagate) { + setInputTensorTypes(*retval, fmap(inputs)); + PropagateInputShapes(retval); + } + AT_ASSERT(retval->inputs().size() == inputs.size()); + for (size_t i = 0; i < retval->inputs().size(); ++i) { + auto scalar_type = inputs[i].scalar_type(); + auto sizes = inputs[i].sizes(); + auto 
type = + torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); + retval->inputs()[i]->setType(type); + } + at::ArrayRef output_values = retval->outputs(); + // patch this to still work if we are returning a tuple of multiple values + if (output_values.at(0)->type()->kind() == TupleType::Kind) { + AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct); + output_values = output_values.at(0)->node()->inputs(); + } + AT_ASSERT(output_values.size() == outputs.size()); + for (size_t i = 0; i < retval->outputs().size(); ++i) { + auto scalar_type = outputs[i].scalar_type(); + auto sizes = outputs[i].sizes(); + auto type = + torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); + output_values[i]->setType(type); + } + return retval; + } + const std::vector& initial_ivalues() const { return initial_ivalues_; } - // proxies for underlying unbound Function - std::shared_ptr graph_for(Stack inputs) { - for (auto tp : initial_ivalues_) { - inputs.emplace_back(tp.value()); + Method& setSchema(FunctionSchema schema_) { + schema = make_unique(std::move(schema_)); + return *this; + } + + TORCH_API const FunctionSchema& getSchema() const { + if (schema == nullptr) { + schema = make_unique(defaultSchemaFor(*this)); } - return function_->get_executor().graphFor(inputs); + return *schema; } - std::shared_ptr graph() const { - return function_->graph(); + std::string pretty_print_schema() const { + AT_ASSERT(schema); + std::stringstream ss; + ss << *schema; + return ss.str(); } - const std::string& name() const { - return function_->name(); + GraphExecutorState getDebugState() { + return get_executor().getDebugState(); } - size_t num_inputs() const { - return function_->num_inputs() - initial_ivalues_.size(); + void debugDisableAutodiffSubgraphInlining() { + return get_executor().debugDisableAutodiffSubgraphInlining(); } - FunctionSchema getSchema() const { - // we are required to slice out the slot inputs from the schema - // we can't cache this because setSchema on the underlying function - // will change the underlying schema - auto sliced = ArrayRef(function_->getSchema().arguments()) - .slice(0, num_inputs()); - return function_->getSchema().cloneWithArguments(sliced.vec()); + bool is_optimized() const { + return optimize; } - GraphExecutor& get_executor() { - return function_->get_executor(); + // the module that contains this method. + Module& owner() const { + return *owner_; } - Function& function() const { - return *function_; + void check_single_output() { + AT_CHECK( + graph()->outputs().size() == 1, + "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs"); } private: + static FunctionSchema defaultSchemaFor(const Method& method) { + std::vector args; + std::vector returns; + Graph& g = *method.graph(); + size_t num_inputs = method.num_inputs(); + for (size_t i = 0; i < num_inputs; ++i) { + const Value* v = g.inputs().at(i); + std::string name = v->hasUniqueName() ? 
v->uniqueNameBase() + : ("argument_" + std::to_string(i)); + args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); + } + for (size_t i = 0; i < g.outputs().size(); ++i) { + returns.emplace_back("", unshapedType(g.outputs()[i]->type())); + } + return {method.name(), "", std::move(args), std::move(returns)}; + } + + GraphExecutor& get_executor() { + std::call_once(executor_init, [&] { + check_single_output(); + executor = GraphExecutor(graph(), optimize); + }); + return executor; + } + + void checkInputsAgainstSchema(std::vector& inputs) { + const auto& schema = getSchema(); + // Do we have more inputs than the schema accepts? + AT_CHECK( + inputs.size() <= schema.arguments().size(), + "Expected at most ", + schema.arguments().size(), + " argument(s) for operator '", + schema.name(), + "', but received ", + inputs.size(), + " argument(s). Declaration: ", + schema); + + for (size_t pos = 0; pos < schema.arguments().size(); ++pos) { + const auto& argument = schema.arguments()[pos]; + if (pos < inputs.size()) { + if (!isSubvalueOf(inputs[pos], argument.type())) { + AT_ERROR( + "Expected value of type ", + *argument.type(), + " for argument '", + argument.name(), + "' in position ", + pos, + ", but instead got value of type ", + attemptToRecoverType(inputs[pos])->str(), + ". Declaration: ", + schema); + } + } else if (argument.default_value()) { + inputs.push_back(*argument.default_value()); + } else { + AT_ERROR( + schema.name(), + "() is missing value for argument '", + argument.name(), + "'. Declaration: ", + schema); + } + } + } + // Methods are uniqued onwed by a single module. This raw pointer allows // looking up the module. Module* owner_; - // Underlying unbound function - Function* function_; - - // parameters and attributes loaded from the Module and appending - // before calling function_ + std::string name_; + std::shared_ptr graph_; // for debugging and for inlining + bool optimize; + + GraphExecutor executor; // for execution + // initial_ivalues are a list of additional arguments appended to graph + // that are inputs that come from the members of the Module or its submodules. + // each is a pointer to a slot in the module that owns this parameter + // parameters and submodules can only be _added_ to script Modules to ensure + // these pointers always stay valid std::vector initial_ivalues_; + + // map from a IValue* in initial_ivalues to the offset it appears at + // in graph. used to accelerate get_or_add_parameter + std::unordered_map initial_ivalue_index; + + // TODO: support that case where we allow _writes_ to parameters from + // compiled functions. + // This requires more sophisticated tracking of ssa values in Graphs so that + // stores to all modules can be lifted to the end of a graph execution. + // It also adds more complexity to adding actual module invocations + // to the executor, so currently it is not done. + // std::vector member_outputs; + + std::once_flag executor_init; + + // an optional function that actually creates the method when + // emit_call_to(this,...) is first called. 
this is used by the compiler so + // that it can construct methods out of order + std::function method_creator; + + // if absent, then we generate a default schema based on the graph + // mutable because getSchema caches the default schema if one is requested + // before a call to setSchema + mutable std::unique_ptr schema; }; struct Module; -struct TORCH_API Module { +struct Module { TH_DISALLOW_COPY_AND_ASSIGN(Module); Module() : name_("__main__"), module_value_(c10::ivalue::Object::create( - ClassType::createModuleType(std::make_shared()), - 0)) {} + ClassType::createModuleType(), + 0)), + optimize_(true) {} - ~Module() { - // ClassType own the compilation unit of their Functions, but each - // Function has a self argument which owns the ClassType, created a - // referernce cycle. By dropping all the methods of the module's class - // here we break the cycle. - class_cu().drop_all_functions(); - } const std::string& name() const { return name_; } @@ -166,11 +374,11 @@ struct TORCH_API Module { // note this doesn't change the flags of existing methods just ones // added afterward. void set_optimized(bool o) { - class_cu().set_optimized(o); + optimize_ = o; } bool is_optimized() const { - return class_cu().is_optimized(); + return optimize_; } IValue forward(std::vector inputs) { @@ -187,7 +395,7 @@ struct TORCH_API Module { name, attributes_, EntityType::ATTRIBUTE, - appendSlot(name, TensorType::get(), std::move(v))); + appendSlot(name, TensorType::get(),std::move(v))); } void register_parameter( @@ -212,11 +420,7 @@ struct TORCH_API Module { const std::string& name, const TypePtr type, IValue ivalue) { - insert( - name, - attributes_, - EntityType::ATTRIBUTE, - appendSlot(name, type, ivalue)); + insert(name, attributes_, EntityType::ATTRIBUTE, appendSlot(name, type, ivalue)); } void register_module( const std::string& name, @@ -230,10 +434,9 @@ struct TORCH_API Module { // AT_WARN( // "Attempting to assign submodule '", // name, - // "' but it is already a submodule of another ScriptModule '", - // module->parent_->name(), "'", " Modules of this form do not import - // and export correctly. This use is deprecated and may be" " removed - // in a future version."); + // "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'", + // " Modules of this form do not import and export correctly. This use is deprecated and may be" + // " removed in a future version."); // } module->parent_ = this; module->name_ = name; @@ -241,6 +444,34 @@ struct TORCH_API Module { insert(name, modules_, EntityType::MODULE, std::move(module)); } + Method& create_method( + const std::string& name, + std::shared_ptr graph, + std::vector member_inputs) { + AT_ASSERT(graph); + std::unique_ptr method(new Method( + this, + name, + optimize_, + std::move(graph), + std::move(member_inputs), + nullptr)); + return *insert(name, methods_, EntityType::METHOD, std::move(method)); + } + + Method& create_method( + const std::string& name, + std::function creator) { + std::unique_ptr method(new Method( + this, + name, + optimize_, + std::make_shared(), + {}, + std::move(creator))); + return *insert(name, methods_, EntityType::METHOD, std::move(method)); + } + Slot parameter_slot(const std::string& name) const { return parameters_[get_offset(name, EntityType::PARAMETER)]; } @@ -264,14 +495,7 @@ struct TORCH_API Module { // each module owns its method. 
The reference returned here // is guarenteed to stay valid until this module has been destroyed Method& get_method(const std::string& name) const { - if (Method* method = find_method(name)) { - return *method; - } - // temporary: force the error message - // once the on-demand creation of Method is removed, this code - // can be removed as well - get_offset(name, EntityType::METHOD); - AT_ERROR("unreachable"); + return *methods_[get_offset(name, EntityType::METHOD)]; } std::shared_ptr get_module(const std::string& name) const { @@ -287,12 +511,7 @@ struct TORCH_API Module { c10::ArrayRef get_attributes() const { return attributes_; } - const std::vector>& get_methods() const { - // force methods_ to be up to date by querying all - // methods. This will go away when lowered_methods_ is deleted - for (const auto& m : class_cu().get_functions()) { - get_method(m->name()); - } + c10::ArrayRef> get_methods() const { return methods_; } @@ -315,22 +534,9 @@ struct TORCH_API Module { auto offset = find_offset(name, EntityType::MODULE); return offset ? modules_[*offset] : nullptr; } - Method* find_method(const std::string& name) const { + Method* find_method(const std::string& name) { auto offset = find_offset(name, EntityType::METHOD); - if (offset) { - return methods_[*offset].get(); - } - - if (Function* fn = class_cu().find_function(name).get()) { - // temporary lock because technically this is marked const, - // but we have to update the internal Method cache. - // This can be removed when class_cu() is the source of truth for - // methods. - std::lock_guard guard(find_method_guard_); - return &const_cast(this)->lower_first_class_method(fn); - } - - return nullptr; + return offset ? methods_[*offset].get() : nullptr; } void apply(std::function fn) { for (auto& submod : get_modules()) { @@ -365,7 +571,10 @@ struct TORCH_API Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); + TORCH_API void to( + at::Device device, + at::ScalarType dtype, + bool non_blocking = false); /// Recursively casts all parameters to the given dtype. /// @@ -373,7 +582,7 @@ struct TORCH_API Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - void to(at::ScalarType dtype, bool non_blocking = false); + TORCH_API void to(at::ScalarType dtype, bool non_blocking = false); /// Recursively moves all parameters to the given device. /// @@ -381,7 +590,7 @@ struct TORCH_API Module { /// destination is on the GPU or vice versa, the copy is performed /// asynchronously with respect to the host. Otherwise, the argument has no /// effect. - void to(at::Device device, bool non_blocking = false); + TORCH_API void to(at::Device device, bool non_blocking = false); /// Run a method from this module. 
/// @@ -437,57 +646,26 @@ struct TORCH_API Module { mod->copy_into(module_lookup, parameter_remap, names); names.pop_back(); } - - for (auto& fn : class_cu().get_functions()) { - curr->class_cu().clone_function(*fn); + for (auto& method : get_methods()) { + std::vector initial_ivalues; + for (auto& p : method->initial_ivalues()) { + initial_ivalues.push_back(parameter_remap.at(p)); + } + curr->create_method( + method->name(), method->graph()->copy(), initial_ivalues); } } enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD }; at::optional kind_of(const std::string& name) const { - // force lazy creation of Method if needed - // remove once lowered_methods_ is removed. - find_method(name); - auto it = dict_.find(name); if (it == dict_.end()) return at::nullopt; return it->second.type; } - ModulePtr module_object() const { - return module_value_; - } - CompilationUnit& class_compilation_unit() { - return module_object()->type()->compilation_unit(); - } - CompilationUnit& lowered_methods() const { - return lowered_methods_; - } - - // so that C++ users can easily add methods - void define(const std::string& src, const Resolver& resolver = nullptr); - - void _define_lowered( - const std::vector& definitions, - const std::vector& resolvers); - void _define_lowered(const std::string& src, const Resolver& resolver); - - Method& _define_lowered( - std::string name, - std::shared_ptr graph, - std::vector slots); - private: - Method& _create_lowered_method( - Function* func, - std::vector member_inputs); - - Method& lower_first_class_method(Function* fn); - void lift_lowered_method(Method& fn); - void lift_lowered_methods(size_t start); - void to_impl( const c10::optional& device, const c10::optional& dtype, @@ -575,13 +753,6 @@ struct TORCH_API Module { return Slot(module_value_, slot_index); } - CompilationUnit& class_cu() { - return module_value_->type()->compilation_unit(); - } - const CompilationUnit& class_cu() const { - return module_value_->type()->compilation_unit(); - } - // modules have a single namespace, but spread over 4 different concepts: // parameters, attributes, methods, and sub-modules // we store individual lists of each concept, and a single map to @@ -599,97 +770,29 @@ struct TORCH_API Module { std::unordered_map dict_; std::string name_; - ModulePtr module_value_; - - // back reference to parent of this Module if present - Module* parent_ = nullptr; - // Currently we are in a transitionary state - // where we construct such first class functions but we lower them - // to a form where the modules does not exist before execution. + c10::intrusive_ptr module_value_; - // So each Method is actually stored twice once in first-class Module - // form and once in lowered form. - // first-class: module_value_->type().compilation_unit() holds Functions that - // treat modules as first class. - - // lowered: In this lowered form, all the attributes/parameters are appended - // as additional inputs. 
lowered_methods_ holds this lowered form - // mutable because it is a cache for class_cu() methods - mutable CompilationUnit lowered_methods_; - mutable std::recursive_mutex find_method_guard_; + // back reference to parent of this Module if present + Module* parent_ = nullptr; + bool optimize_; }; -static void setInputTensorTypes(Graph& g, const Stack& stack) { - AT_ASSERT(stack.size() == g.inputs().size()); - for (size_t i = 0; i < stack.size(); ++i) { - g.inputs().at(i)->setType( - DimensionedTensorType::create(stack.at(i).toTensor())); - } -} - -inline std::shared_ptr propagate_shapes( - Graph& graph, - const std::vector& inputs, - const std::vector& initial_ivalues, - bool with_grad = false) { - auto retval = graph.copy(); - Stack stack; - stack.reserve(inputs.size() + initial_ivalues.size()); - for (const at::Tensor& i : inputs) { - stack.emplace_back(std::move(i)); - } - for (const Slot& inp : initial_ivalues) { - stack.push_back(inp.value()); - } - setInputTensorTypes(*retval, stack); - PropagateInputShapes(retval); - return retval; -} - -inline std::shared_ptr propagate_and_assign_input_and_output_shapes( +// returns nullptr and fills in failure_messages if the callee does not +// match the functions schema +Value* try_emit_call_to( Graph& graph, - std::vector inputs, - const std::vector& initial_ivalues, - std::vector outputs, - bool with_grad = false, - bool propagate = true) { - auto retval = graph.copy(); - for (auto inp : initial_ivalues) { - if (inp.value().isTensor()) { - inputs.push_back(inp.value().toTensor()); - } - } - if (propagate) { - setInputTensorTypes(*retval, fmap(inputs)); - PropagateInputShapes(retval); - } - AT_ASSERT(retval->inputs().size() == inputs.size()); - for (size_t i = 0; i < retval->inputs().size(); ++i) { - auto scalar_type = inputs[i].scalar_type(); - auto sizes = inputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - retval->inputs()[i]->setType(type); - } - at::ArrayRef output_values = retval->outputs(); - // patch this to still work if we are returning a tuple of multiple values - if (output_values.at(0)->type()->kind() == TupleType::Kind) { - AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct); - output_values = output_values.at(0)->node()->inputs(); - } - AT_ASSERT(output_values.size() == outputs.size()); - for (size_t i = 0; i < retval->outputs().size(); ++i) { - auto scalar_type = outputs[i].scalar_type(); - auto sizes = outputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - output_values[i]->setType(type); - } - return retval; -} - + const SourceRange& loc, + Method& callee, + c10::optional self, + ArrayRef args, + ArrayRef kwargs, + std::stringstream& failure_messages, + // when callee uses no parameters (e.g. it is a function in a compilation + // unit, and not a method), then nullptr can be passed as caller. 
+ Method* caller, + bool conv_tensors_to_nums); } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp index 0c5676c..30c273d 100644 --- a/torch/csrc/jit/script/schema_matching.cpp +++ b/torch/csrc/jit/script/schema_matching.cpp @@ -421,14 +421,16 @@ Value* emitBuiltinCall( return emitBuiltinNode(*matched_schema, loc, graph, name); } } - for (Function* method : builtin_functions) { - if (auto result = method->try_emit_call( + for (Method* method : builtin_functions) { + if (auto result = try_emit_call_to( graph, loc, + *method, self, inputs, attributes, failure_messages, + nullptr, allow_conversions)) { return result; } diff --git a/torch/csrc/jit/script/slot.h b/torch/csrc/jit/script/slot.h index 0304c0e..3e01731 100644 --- a/torch/csrc/jit/script/slot.h +++ b/torch/csrc/jit/script/slot.h @@ -33,7 +33,6 @@ private: c10::intrusive_ptr container_; size_t offset_; friend struct std::hash; - friend struct Module; }; }}} diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index 6410282..48c5337 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -16,7 +16,7 @@ struct NoneValue : SugaredValue { std::shared_ptr PrintValue::call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) { @@ -58,7 +58,7 @@ builtin_cast_methods() { std::shared_ptr BuiltinFunction::call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) { @@ -70,7 +70,7 @@ std::shared_ptr BuiltinFunction::call( // callable value that will resolve to foo(x, y, z) when called. std::shared_ptr SimpleValue::attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) { // Allow method-style casts on Tensor types. e.g. x.int() if (value_->type()->isSubtypeOf(TensorType::get())) { @@ -116,7 +116,7 @@ std::shared_ptr SimpleValue::attr( if (auto classType = value_->type()->cast()) { // This is a class, emit the proper attribute lookup if (auto method = classType->getMethod(field)) { - return std::make_shared(getValue(), *method); + return std::make_shared(shared_from_this(), *method); } if (!classType->hasAttribute(field)) { @@ -135,7 +135,7 @@ std::shared_ptr SimpleValue::attr( std::vector> SimpleValue::asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint) { static const auto make_simple_value = [](Value* v) -> std::shared_ptr { @@ -161,7 +161,7 @@ std::vector> SimpleValue::asTuple( void SimpleValue::setAttr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field, Value* newValue) { const auto classType = value_->type()->cast(); @@ -217,7 +217,7 @@ void SimpleValue::setAttr( std::shared_ptr ClassValue::call( const SourceRange& loc, - Function& m, + Method& m, // note: names for args will be 'argument 0', 'argument 1', etc.. 
at::ArrayRef inputs, at::ArrayRef attributes, @@ -226,7 +226,8 @@ std::shared_ptr ClassValue::call( // Generate a new object of the right type, then call `__init__` on it auto& g = *m.graph(); - auto self = g.insertNode(g.createObject(type_))->output(); + auto createNode = g.insertNode(g.createObject(type_)); + auto self = std::make_shared(createNode->output()); auto initMethod = type_->getMethod("__init__"); AT_ASSERT(initMethod); @@ -234,12 +235,12 @@ std::shared_ptr ClassValue::call( // Call the init function MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders); - return std::make_shared(self); + return self; } std::shared_ptr ClassValue::attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) { if (field != "__new__") { throw ErrorReport(loc) << "Tried to lookup unknown attribute on class"; diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h index 6f4117d..e1fd725 100644 --- a/torch/csrc/jit/script/sugared_value.h +++ b/torch/csrc/jit/script/sugared_value.h @@ -28,14 +28,14 @@ struct SugaredValue : public std::enable_shared_from_this { // what can we do with this thing? // use it as a value e.g. `this + 4` - virtual Value* asValue(const SourceRange& loc, Function& m) { + virtual Value* asValue(const SourceRange& loc, Method& m) { throw ErrorReport(loc) << kind() << " cannot be used as a value"; } // select an attribute on it, e.g. `this.field` virtual std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) { throw ErrorReport(loc) << "attribute lookup is not defined on " << kind(); } @@ -43,7 +43,7 @@ struct SugaredValue : public std::enable_shared_from_this { // assign an attribute on it, e.g. `this.field = newValue` virtual void setAttr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field, Value* newValue) { throw ErrorReport(loc) << "attribute assignment is not defined on " @@ -57,7 +57,7 @@ struct SugaredValue : public std::enable_shared_from_this { // a method invocation virtual std::vector> asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint = {}) { throw ErrorReport(loc) << kind() << " cannot be used as a tuple"; } @@ -65,7 +65,7 @@ struct SugaredValue : public std::enable_shared_from_this { // call it like a function, e.g. `outputs = this(inputs)` virtual std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, // note: names for args will be 'argument 0', 'argument 1', etc.. 
at::ArrayRef inputs_, at::ArrayRef attributes, @@ -97,7 +97,7 @@ struct TORCH_API SimpleValue : public SugaredValue { std::string kind() const override { return "value"; } - Value* asValue(const SourceRange& range, Function& m) override { + Value* asValue(const SourceRange& range, Method& m) override { return value_; } NoneStatus isNone() override { @@ -110,16 +110,16 @@ struct TORCH_API SimpleValue : public SugaredValue { } std::vector> asTuple( const SourceRange& loc, - Function& m, + Method& m, const c10::optional& size_hint = {}) override; std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override; void setAttr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field, Value* newValue) override; @@ -146,7 +146,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue { } std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef attributes, at::ArrayRef inputs, size_t n_binders) override; @@ -161,7 +161,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { } std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override { return std::make_shared( Symbol::fromQualString(name + "::" + field), c10::nullopt); @@ -183,14 +183,14 @@ struct TORCH_API ClassValue : public SugaredValue { // n = Foo(constructor_arg) std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; std::shared_ptr attr( const SourceRange& loc, - Function& m, + Method& m, const std::string& field) override; std::string kind() const override { @@ -202,33 +202,34 @@ struct TORCH_API ClassValue : public SugaredValue { // defines how a method obtained from a module behaves in script struct MethodValue : public SugaredValue { - MethodValue(c10::optional self, Function& method) + MethodValue(std::shared_ptr self, Method& method) : self_(std::move(self)), method(method) {} std::string kind() const override { return "method"; } std::shared_ptr call( const SourceRange& loc, - Function& f, + Method& caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { - Graph& graph = *f.graph(); - if (self_) { + if (auto classType = dynamic_cast(self_.get())) { + // If self_ is a class, then it will be expected as part of + // the schema. Add it to the front of the inputs. 
std::vector inputsWithSelf; - inputsWithSelf.emplace_back(loc, self_->value(graph)); + inputsWithSelf.emplace_back(loc, classType->getValue()); inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end()); return std::make_shared( - method.emit_call(graph, loc, inputsWithSelf, attributes)); + caller.emit_call_to(loc, method, inputsWithSelf, attributes)); } return std::make_shared( - method.emit_call(graph, loc, inputs, attributes)); + caller.emit_call_to(loc, method, inputs, attributes)); } private: - c10::optional self_; - Function& method; + std::shared_ptr self_; + Method& method; }; struct TORCH_API PrintValue : public SugaredValue { @@ -237,7 +238,7 @@ struct TORCH_API PrintValue : public SugaredValue { } std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override; @@ -251,7 +252,7 @@ struct TORCH_API CastValue : public BuiltinFunction { : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {} std::shared_ptr call( const SourceRange& loc, - Function& m, + Method& m, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { @@ -315,7 +316,7 @@ struct TORCH_API ClassNewMethod : public SugaredValue { std::shared_ptr createObject( const SourceRange& loc, - Function& m, + Method& m, const std::string& classname) { if (classname != type_->name()) { throw ErrorReport(loc) @@ -337,13 +338,6 @@ static inline std::vector toValues( return fmap(nvs, [&](const NamedValue& v) { return v.value(g); }); } -static inline Self simpleSelf(const TypePtr& typ) { - return [typ](Value* v) { - v->setType(typ); - return std::make_shared(v); - }; -} - } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp index 3433a9d..2a99f72 100644 --- a/torch/csrc/jit/symbolic_script.cpp +++ b/torch/csrc/jit/symbolic_script.cpp @@ -1303,8 +1303,8 @@ bool isHelperFunction(const std::string& method_name) { return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0; } -void loadModule(const script::CompilationUnit& module) { - for (const auto& method : module.get_functions()) { +void loadModule(const std::shared_ptr& module) { + for (const auto& method : module->get_methods()) { if (isHelperFunction(method->name())) continue; @@ -1356,8 +1356,9 @@ void loadModule(const script::CompilationUnit& module) { void loadFunctions() { for (const std::string& str : functions) { - script::CompilationUnit cu; - cu.define(str, script::nativeResolver, nullptr); + auto cu = std::make_shared(); + script::defineMethodsInModule( + cu, str, script::nativeResolver, c10::nullopt); loadModule(cu); } } diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index b4d33ced..593f508 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -702,8 +702,14 @@ def _try_get_dispatched_fn(fn): return _jit_internal.boolean_dispatched.get(fn) -def _try_get_overloaded_fn(mod, field): - return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None +def _try_get_overloaded_fn(fn): + if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule): + # Only allow overloads for bound methods + return None + overloads = fn.__self__._overloads.get(fn.__name__, None) + if overloads is None: + return None + return [getattr(fn.__self__, overload) for overload in overloads] def _try_compile_weak_script(fn): @@ -732,20 +738,20 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None): return obj if 
_rcb is None: _rcb = _jit_internal.createResolutionCallback(_frames_up + 1) + mod = ScriptModule() if inspect.isclass(obj): if not _is_new_style_class(obj): raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'") ast = get_jit_class_def(obj) - _jit_script_class_compile(ast, _rcb) + _jit_script_class_compile(mod, ast, _rcb) _add_script_class(obj, obj.__name__) return obj else: - mod = ScriptModule() ast = get_jit_def(obj) _jit_script_compile(mod, ast, _rcb, get_default_args(obj)) - # Forward docstrings - mod.__doc__ = obj.__doc__ - return mod + # Forward docstrings + mod.__doc__ = obj.__doc__ + return mod ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
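
Aside (illustrative only, not part of the diff above): a minimal, self-contained Python
sketch of the overload lookup performed by the restored _try_get_overloaded_fn. After
this revert, overloads are resolved only for methods bound to a ScriptModule-like owner
that carries an _overloads table mapping a method name to its overload names. FakeModule
and its method names below are hypothetical stand-ins used purely for illustration; the
real code additionally requires isinstance(fn.__self__, ScriptModule) and does not exist
as a standalone helper like this.

    # Toy re-statement of the bound-method overload lookup shown in the diff.
    # All names are hypothetical; this does not import torch.
    def try_get_overloaded_fn(fn):
        owner = getattr(fn, '__self__', None)  # only bound methods qualify
        if owner is None or not hasattr(owner, '_overloads'):
            return None
        names = owner._overloads.get(fn.__name__)
        if names is None:
            return None
        # return the overloads as methods bound to the same owner
        return [getattr(owner, name) for name in names]

    class FakeModule:
        # maps a method name to the names of its overload variants
        _overloads = {'forward': ['forward_int', 'forward_pair']}

        def forward(self, x):
            return x

        def forward_int(self, x):
            return x + 1

        def forward_pair(self, x, y):
            return x + y

    m = FakeModule()
    print(try_get_overloaded_fn(m.forward))    # two bound overload methods
    print(try_get_overloaded_fn(lambda x: x))  # None: not a bound method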