From: David Riazati
Date: Thu, 7 Mar 2019 18:41:13 +0000 (-0800)
Subject: Add module attributes (#17309)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~944
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a2381fa346a08dd9be3ef7942efb048d789cbee1;p=platform%2Fupstream%2Fpytorch.git

Add module attributes (#17309)

Summary: Similar to `nn.Parameter`s, this PR lets you store any `IValue` as an attribute on a `ScriptModule` (currently only from the Python front-end). To mark something as an attribute, it should be wrapped in `jit.Attribute(value, type)` (e.g. `self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])`). A usage sketch is appended after the diff below.

Followup Work:
* (de)serialization for use in C++
* change `self.training` to be a `bool` attribute instead of a buffer
* mutable attributes
* string frontend support
* documentation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17309
Differential Revision: D14354316
Pulled By: driazati

fbshipit-source-id: 67e08ab5229366b67fbc837e67b58831a4fb3318
---
diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index 3f783e0..4d6f7bf 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -15,7 +15,7 @@ void check_all_parameters( const torch::jit::script::Module& module, Predicate predicate) { for (const auto& parameter : module.get_parameters()) { - AT_ASSERT(predicate(*parameter->slot())); + AT_ASSERT(predicate(parameter->slot()->toTensor())); } for (const auto& child : module.get_modules()) { check_all_parameters(*child->module, predicate); } } diff --git a/test/test_jit.py b/test/test_jit.py index 9879eeb..83846e8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -278,8 +278,10 @@ class JitTestCase(TestCase): def emitModuleHook(self, module): def copy_structure_and_params(m): c = torch.jit.ScriptModule() - for name, v, buffer in m._get_parameters(): - c._register_parameter(name, v, buffer) + for name, v in m._get_parameters(): + c._register_parameter(name, v, False) + for name, the_type, v in m._get_attributes(): + c._register_attribute(name, the_type, v) for name, s in m._get_modules(): c._register_module(name, copy_structure_and_params(s)) return c @@ -2015,10 +2017,8 @@ class TestJit(JitTestCase): torch.nn.BatchNorm2d(100), torch.nn.BatchNorm2d(100, affine=False)]: getattr(clazz, mode)() - input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ torch.randn(20, 100, 35, 45) - traced = torch.jit.trace(clazz, (input,)) imported = self.getExportImportCopy(traced) x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ @@ -10570,6 +10570,24 @@ a") a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} self.checkScript(fn, (a_dict, ('a', 'c'))) + def test_module_attrs(self): + class M(torch.jit.ScriptModule): + def __init__(self, table): + super(M, self).__init__() + self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor]) + self.x = torch.nn.Parameter(torch.tensor([100.0])) + + @torch.jit.script_method + def forward(self, key): + # type: (str) -> Tensor + return self.table[key] + self.x + + with self.disableModuleHook(): + # TODO: re-enable module hook when Python printing of attributes is + # supported + m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) + self.assertEqual(m("c"), torch.tensor([103])) + def test_tensor_import_export(self): @torch.jit.script def foo(x): diff --git a/torch/csrc/api/src/serialize/input-archive.cpp
b/torch/csrc/api/src/serialize/input-archive.cpp index 4cd934a..5f5b308 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -23,23 +23,30 @@ void InputArchive::read( const std::string& key, Tensor& tensor, bool is_buffer) { - auto* read_tensor = module_->find_parameter(key); - AT_CHECK(read_tensor != nullptr, "No such serialized tensor '", key, "'"); + auto param = module_->find_parameter(key); + auto buffer = module_->find_buffer(key); + AT_CHECK( + param != nullptr || buffer != nullptr, + "No such serialized tensor '", + key, + "'"); // clang-format off + auto read_param = is_buffer ? buffer : param; + auto read_tensor = read_param->slot()->toTensor(); AT_CHECK( - read_tensor->is_buffer == is_buffer, + bool(buffer) == is_buffer, "Expected deserialized tensor for key '", key, "' to ", is_buffer ? "not " : "", "be a buffer, but it was not"); // clang-format on if (tensor.defined()) { torch::NoGradGuard guard; - if (tensor.device() != read_tensor->slot()->device()) { - tensor.set_data(autograd::Variable(*read_tensor->slot()).data()); + if (tensor.device() != read_tensor.device()) { + tensor.set_data(autograd::Variable(read_tensor).data()); } else { - tensor.set_(*read_tensor->slot()); + tensor.set_(read_tensor); } } else { - tensor = std::move(*read_tensor->slot()); + tensor = std::move(read_tensor); } } diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 39417ba..9590646 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -502,8 +502,9 @@ class ScriptModuleSerializer final { torch::ModuleDef* module_def); void convertParameter( - const script::NamedParameter& param, - torch::ParameterDef* param_def); + const script::NamedIValue& param, + torch::ParameterDef* param_def, + bool is_parameter); std::ofstream ofs_; caffe2::serialize::PyTorchStreamWriter writer_; @@ -646,7 +647,13 @@ void ScriptModuleSerializer::convertModule( module_def->set_optimize(module.is_optimized()); for (const auto& elem : module.get_parameters()) { torch::ParameterDef* param_def = module_def->add_parameters(); - convertParameter(elem.value(), param_def); + convertParameter(elem.value(), param_def, /*is_buffer=*/false); + } + for (const auto& elem : module.get_attributes()) { + if (elem.value().type->isSubtypeOf(TensorType::get())) { + torch::ParameterDef* param_def = module_def->add_parameters(); + convertParameter(elem.value(), param_def, /*is_buffer=*/true); + } } std::stringstream module_name; @@ -675,11 +682,12 @@ void ScriptModuleSerializer::convertModule( } void ScriptModuleSerializer::convertParameter( - const script::NamedParameter& param, - torch::ParameterDef* param_def) { - param_def->set_name(param.name); - param_def->set_is_buffer(param.is_buffer); - param_def->set_tensor_id(addTensor(*param.slot())); + const script::NamedIValue& param, + torch::ParameterDef* param_def, + bool is_parameter) { + param_def->set_name(param.name_); + param_def->set_is_buffer(is_parameter); + param_def->set_tensor_id(addTensor(param.slot()->toTensor())); } // Pretty printing for ONNX diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 1abcc8c..58ef1d1 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -44,7 +44,7 @@ class ScriptModuleDeserializer final { ScriptModuleDeserializer(std::istream* is); explicit ScriptModuleDeserializer(std::unique_ptr rai); void deserialize( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, c10::optional device, 
script::ExtraFilesMap& extra_files); @@ -60,7 +60,7 @@ class ScriptModuleDeserializer final { caffe2::serialize::PyTorchStreamReader reader_; // this is a hack to make sure the script module created in C++ is the // same as created in Python - ModuleLookup moduleLookup_; + script::ModuleLookup moduleLookup_; c10::optional device_; std::vector moduleStack_; @@ -80,7 +80,7 @@ ScriptModuleDeserializer::ScriptModuleDeserializer( : reader_(std::move(rai)) {} void ScriptModuleDeserializer::deserialize( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, c10::optional device, script::ExtraFilesMap& extra_files) { torch::ModelDef model_def; @@ -222,7 +222,11 @@ void ScriptModuleDeserializer::convertModule( for (int i = 0; i < module_def.parameters_size(); ++i) { const torch::ParameterDef& param_def = module_def.parameters(i); at::Tensor tensor = tensor_table_.at(param_def.tensor_id()); - module->register_parameter(param_def.name(), tensor, param_def.is_buffer()); + if (param_def.is_buffer()) { + module->register_buffer(param_def.name(), tensor); + } else { + module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false); + } } if (module_def.has_torchscript_arena()) { at::DataPtr data; @@ -237,7 +241,7 @@ void ScriptModuleDeserializer::convertModule( } // namespace void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, std::istream& in, c10::optional device, script::ExtraFilesMap& extra_files) { @@ -246,7 +250,7 @@ void import_ir_module( } void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, const std::string& filename, c10::optional device, script::ExtraFilesMap& extra_files) { @@ -255,7 +259,7 @@ void import_ir_module( } void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, std::unique_ptr rai, c10::optional device, script::ExtraFilesMap& extra_files) { diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h index 0cac71d..a229c8c 100644 --- a/torch/csrc/jit/import.h +++ b/torch/csrc/jit/import.h @@ -14,25 +14,22 @@ class ReadAdapterInterface; namespace torch { namespace jit { -using ModuleLookup = std::function( - const std::vector&)>; - static script::ExtraFilesMap default_extra_files; TORCH_API void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, const std::string& filename, c10::optional device = c10::nullopt, script::ExtraFilesMap& extra_files = default_extra_files); TORCH_API void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, std::istream& in, c10::optional device = c10::nullopt, script::ExtraFilesMap& extra_files = default_extra_files); TORCH_API void import_ir_module( - ModuleLookup module_lookup, + script::ModuleLookup module_lookup, std::unique_ptr rai, c10::optional device = c10::nullopt, script::ExtraFilesMap& extra_files = default_extra_files); diff --git a/torch/csrc/jit/import_method.cpp b/torch/csrc/jit/import_method.cpp index 029988b..763f7a7 100644 --- a/torch/csrc/jit/import_method.cpp +++ b/torch/csrc/jit/import_method.cpp @@ -19,7 +19,10 @@ struct ModuleAccessorValue : public script::SugaredValue { const std::string& field) override { if (script::NamedModule* v = module->find_module(field)) { return std::make_shared(v->module); - } else if (script::NamedParameter* v = module->find_parameter(field)) { + } else if (script::NamedIValue* v = module->find_parameter(field)) { + return std::make_shared( + m.get_or_add_parameter(v->slot())); + } else if 
(script::NamedIValue* v = module->find_buffer(field)) { return std::make_shared( m.get_or_add_parameter(v->slot())); } else if (script::Method* m = module->find_method(field)) { diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 22061fd..69d1cd0 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -130,10 +130,14 @@ struct QualifiedName : c10::intrusive_ptr_target { void createTensorToParameterNameMap( const script::Module& module, const QualifiedNamePtr& prefix, - std::unordered_map& result) { + std::unordered_map& result) { for (const auto& elem : module.get_parameters()) { - const script::NamedParameter& param = elem.value(); - result[param.slot()] = QualifiedName::create(prefix, param.name); + const script::NamedIValue& param = elem.value(); + result[param.slot()] = QualifiedName::create(prefix, param.name_); + } + for (const auto& elem : module.get_attributes()) { + const script::NamedIValue& param = elem.value(); + result[param.slot()] = QualifiedName::create(prefix, param.name_); } for (const auto& elem : module.get_modules()) { createTensorToParameterNameMap( @@ -1038,27 +1042,27 @@ struct PythonPrintPass { } } void printMethod(script::Method& method) { - std::unordered_map parameter_names; + std::unordered_map parameter_names; createTensorToParameterNameMap( method.owner(), QualifiedName::create("self"), parameter_names); printMethod(method, parameter_names); } void printMethod( script::Method& method, - const std::unordered_map& - parameter_names) { - std::vector param_names = fmap( - method.params(), - [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); }); + const std::unordered_map& + extra_ivalue_names) { + std::vector ivalue_names = fmap( + method.initial_ivalues(), + [&](IValue* slot) { return extra_ivalue_names.at(slot)->str(); }); const std::string& name = method.name(); Graph& graph = *method.graph(); auto defaults = fmap( method.getSchema().arguments(), [](const Argument& arg) { return arg.default_value(); }); - printFunction(graph, name, defaults, param_names); + printFunction(graph, name, defaults, ivalue_names); } void printModule(script::Module& module) { - std::unordered_map parameter_names; + std::unordered_map parameter_names; createTensorToParameterNameMap( module, QualifiedName::create("self"), parameter_names); for (auto& method : module.get_methods()) { diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index f5a08cb..5ff0447 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -323,7 +323,6 @@ inline py::object toPyObject(IValue&& ivalue) { py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second}); } return std::move(py_dict); - } else { AT_ERROR("Missing cases in 'toPyObject'! 
File a bug report."); } diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 9ddd5db..5425219 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -258,9 +258,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { const auto& param_list = module_->get_parameters().items(); for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) { auto& param = *it; - if (!param->is_buffer) { - params.push_back(caller.get_or_add_parameter(param->slot())); - } + params.push_back(caller.get_or_add_parameter(param->slot())); } auto list = caller.graph()->createList(TensorType::get(), params); caller.graph()->insertNode(list); @@ -271,6 +269,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { std::shared_ptr module_; }; + // defines how modules/methods behave inside the script subset. // for now this does not have any interaction with python. // in the future, we will add the ability to resolve `self.foo` to python @@ -294,14 +293,14 @@ struct ModuleValue : public SugaredValue { // it adds a buffer 'training' to the model if one doesn't exist // and then loads that parameter, casting it to bool if (field == "training") { - NamedParameter* v = module->find_parameter(field); + NamedIValue* v = module->find_buffer(field); if (!v) { py::object py_module = py::cast(module); bool training = py::cast(py::getattr(py_module, "training")); auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong)); - module->register_parameter("training", std::move(t), true); - v = module->find_parameter(field); + module->register_buffer("training", std::move(t)); + v = module->find_buffer(field); } Value* the_tensor = m.get_or_add_parameter(v->slot()); Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor}); @@ -312,9 +311,13 @@ struct ModuleValue : public SugaredValue { return std::make_shared(v->module); } else if (Method* v = module->find_method(field)) { return std::make_shared(shared_from_this(), *v); - } else if (NamedParameter* v = module->find_parameter(field)) { + } else if (NamedIValue* v = module->find_parameter(field)) { return std::make_shared(m.get_or_add_parameter(v->slot())); + } else if (NamedIValue* v = module->find_attribute(field)) { + return std::make_shared( + m.get_or_add_attribute(v->type, v->slot())); } + // This can also be a call to a non-script module, or a plain // python method. If so return this as a python value. 
py::object py_module = py::cast(module); @@ -592,11 +595,16 @@ py::object unpackVariableTensorList(std::vector outputs) { } static void gatherParametersAndBuffers( - std::vector& values, + std::vector& values, const Module& m) { for (auto& param : m.get_parameters()) { values.push_back(param->slot()); } + for (auto& param : m.get_attributes()) { + if (param->type->isSubtypeOf(TensorType::get())) { + values.push_back(param->slot()); + } + } for (const auto& sub : m.get_modules()) { gatherParametersAndBuffers(values, *sub->module); } @@ -732,9 +740,16 @@ void initJitScriptBindings(PyObject* module) { }, py::return_value_policy::reference_internal) .def("_register_parameter", &Module::register_parameter) + .def( + "_register_attribute", + [](Module& self, std::string name, TypePtr type, py::object value) { + self.register_attribute(name, type, toIValue(value, type)); + }) .def("_register_module", &Module::register_module) + .def("_register_buffer", &Module::register_buffer) .def("_set_parameter", &Module::set_parameter) .def("_get_parameter", &Module::get_parameter) + .def("_get_buffer", &Module::get_buffer) .def("_get_module", &Module::get_module) .def( "_get_modules", @@ -754,27 +769,38 @@ void initJitScriptBindings(PyObject* module) { py::tuple result(parameters.size()); for (size_t i = 0; i < parameters.size(); ++i) { auto& p = parameters[i]; + py::tuple r(2); + result[i] = std::make_tuple( + p.key(), + autograd::as_variable_ref(p->slot()->toTensor())); + } + return result; + }) + .def( + "_get_attributes", + [](Module& self) -> py::tuple { + auto& attributes = self.get_attributes(); + py::tuple result(attributes.size()); + for (size_t i = 0; i < attributes.size(); ++i) { + auto& buffer = attributes[i]; py::tuple r(3); + IValue v = *buffer->slot(); result[i] = std::make_tuple( - p.key(), autograd::as_variable_ref(*p->slot()), p->is_buffer); + buffer.key(), + buffer->type, + toPyObject(std::move(v))); } return result; }) .def( "_has_parameter", - [](Module& self, const std::string& name) { - if (auto r = self.find_parameter(name)) { - return !r->is_buffer; - } - return false; + [](Module& self, const std::string& name) -> bool { + return self.find_parameter(name); }) .def( "_has_buffer", - [](Module& self, const std::string& name) { - if (auto r = self.find_parameter(name)) { - return r->is_buffer; - } - return false; + [](Module& self, const std::string& name) -> bool { + return self.find_buffer(name); }) .def( "_has_module", @@ -812,10 +838,10 @@ void initJitScriptBindings(PyObject* module) { bool force_outplace) { // prereq: Module's buffers and parameters are unique // this was ensured in python before calling this function - std::vector parameters; + std::vector parameters; gatherParametersAndBuffers(parameters, *self); Stack inputs = toStack(input_tuple); - for (at::Tensor* param : parameters) { + for (IValue* param : parameters) { inputs.emplace_back(*param); } auto graph = tracer::createGraphByTracing( @@ -904,16 +930,20 @@ void initJitScriptBindings(PyObject* module) { std::vector, std::string>> params, std::shared_ptr orig) { - std::vector member_inputs; + std::vector member_inputs; for (auto& p : params) { - NamedParameter* np = + NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p)); + if (np == nullptr) { + np = std::get<0>(p)->find_buffer(std::get<1>(p)); + } AT_ASSERT(np != nullptr); member_inputs.push_back(np->slot()); } Method* orig_method = orig->find_method(name); - m->create_method(name, orig_method->graph()->copy(), member_inputs); + m->create_method( + 
name, orig_method->graph()->copy(), member_inputs); }); py::class_(m, "ScriptMethod", py::dynamic_attr()) @@ -931,7 +961,7 @@ void initJitScriptBindings(PyObject* module) { .def( "propagate_and_assign_input_and_output_shapes", &Method::propagate_and_assign_input_and_output_shapes) - .def("params", &Method::params) + .def("initial_ivalues", &Method::initial_ivalues) .def( "graph_for", [](py::args args, py::kwargs kwargs) { diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index a735481..fe9196e 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -48,10 +48,11 @@ Value* try_emit_call_to( // parameters to callee method (which become parameters to _this_ method // if they were not already) - for (at::Tensor* member : callee.params()) { + for (auto member : callee.initial_ivalues()) { if (!caller) { throw ErrorReport(loc) - << " attempting to call a method with parameters from a raw graph. File a bug report"; + << " attempting to call a method with parameters/attributes" + " from a raw graph. File a bug report"; } matched_schema->inputs.push_back(caller->get_or_add_parameter(member)); } @@ -123,7 +124,7 @@ void Module::to_impl( // Then convert every of our parameters. for (auto& parameter : parameters) { // Need to access the `at::Tensor` as a `Variable` here. - autograd::Variable variable = *parameter->slot(); + autograd::Variable variable = parameter.value().slot()->toTensor(); at::Tensor data = variable.data(); // Use the data's original device or dtype if not supplied here. auto new_data = data.to( diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 592b096..a993687 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -49,30 +49,33 @@ using ExtraFilesMap = std::unordered_map; struct Module; +using ModuleLookup = std::function( + const std::vector&)>; + struct Method { Method( Module* owner, std::string name, bool optimize, std::shared_ptr graph, - std::vector initial_members, + std::vector initial_members, std::function method_creator) : owner_(owner), name_(std::move(name)), graph_(std::move(graph)), optimize(optimize), - member_inputs(std::move(initial_members)), + initial_ivalues_(std::move(initial_members)), method_creator(std::move(method_creator)) { - AT_ASSERT(graph_->inputs().size() >= member_inputs.size()); - int i = graph_->inputs().size() - member_inputs.size(); - for (at::Tensor* member : member_inputs) { - member_input_index[member] = i++; + AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size()); + int i = graph_->inputs().size() - initial_ivalues_.size(); + for (auto member : initial_ivalues_) { + initial_ivalue_index[member] = i++; } } void run(Stack& stack) { - for (at::Tensor* tp : member_inputs) { - stack.emplace_back(*tp); + for (auto input : initial_ivalues_) { + push(stack, *input); } get_executor().run(stack); } @@ -88,7 +91,7 @@ struct Method { } std::shared_ptr graph_for(Stack inputs) { - for (at::Tensor* tp : member_inputs) { + for (auto tp : initial_ivalues_) { inputs.emplace_back(*tp); } return get_executor().graphFor(inputs); @@ -115,17 +118,21 @@ struct Method { TORCH_API void ensure_defined(); size_t num_inputs() const { - return graph()->inputs().size() - member_inputs.size(); + return graph()->inputs().size() - initial_ivalues_.size(); + } + TORCH_API Value* get_or_add_parameter(IValue* slot) { + AT_ASSERT(slot->isTensor()); + return get_or_add_attribute(TensorType::get(), slot); } - TORCH_API Value* 
get_or_add_parameter(at::Tensor* slot) { - auto it = member_input_index.find(slot); - if (it != member_input_index.end()) { + + TORCH_API Value* get_or_add_attribute(TypePtr type, IValue* slot) { + auto it = initial_ivalue_index.find(slot); + if (it != initial_ivalue_index.end()) { return graph()->inputs().at(it->second); } - // add it as a new parameter - member_inputs.push_back(slot); - member_input_index[slot] = graph()->inputs().size(); - return graph()->addInput(); + initial_ivalues_.push_back(slot); + initial_ivalue_index[slot] = graph()->inputs().size(); + return graph()->addInput()->setType(type); } std::shared_ptr propagate_shapes( @@ -133,11 +140,11 @@ struct Method { bool with_grad = false) { auto retval = graph_->copy(); Stack stack; - stack.reserve(inputs.size() + member_inputs.size()); + stack.reserve(inputs.size() + initial_ivalues_.size()); for (at::Tensor& i : inputs) { stack.emplace_back(std::move(i)); } - for (at::Tensor* inp : member_inputs) { + for (IValue* inp : initial_ivalues_) { stack.push_back(*inp); } const auto size = stack.size(); @@ -152,8 +159,10 @@ struct Method { bool with_grad = false, bool propagate = true) { auto retval = graph_->copy(); - for (auto inp : member_inputs) { - inputs.push_back(*inp); + for (auto inp : initial_ivalues_) { + if (inp->isTensor()) { + inputs.push_back(inp->toTensor()); + } } if (propagate) { setInputTypes( @@ -186,8 +195,8 @@ struct Method { return retval; } - std::vector params() const { - return member_inputs; + const std::vector& initial_ivalues() const { + return initial_ivalues_; } Method& setSchema(FunctionSchema schema_) { @@ -310,16 +319,16 @@ struct Method { bool optimize; GraphExecutor executor; // for execution - // member_inputs are a list of additional arguments appended to graph that are - // inputs that come from the members of the Module or its submodules. + // initial_ivalues are a list of additional arguments appended to graph + // that are inputs that come from the members of the Module or its submodules. // each is a pointer to a slot in the module that owns this parameter // parameters and submodules can only be _added_ to script Modules to ensure // these pointers always stay valid - std::vector member_inputs; + std::vector initial_ivalues_; - // map from a at::Tensor* in member_inputs to the offset it appears at + // map from a IValue* in initial_ivalues to the offset it appears at // in graph. used to accelerate get_or_add_parameter - std::unordered_map member_input_index; + std::unordered_map initial_ivalue_index; // TODO: support that case where we allow _writes_ to parameters from // compiled functions. 
@@ -349,24 +358,18 @@ struct NamedModule { std::shared_ptr module; }; -struct NamedParameter { - NamedParameter(std::string name, at::Tensor tensor, bool is_buffer) - : name(std::move(name)), - is_buffer(is_buffer), - parameter(torch::make_unique(std::move(tensor))) {} +struct NamedIValue { + NamedIValue(std::string name, TypePtr type, IValue ivalue) + : name_(name), + type(type), + ivalue(torch::make_unique(std::move(ivalue))) {} - const std::string name; - bool is_buffer; // buffers are part of the module state but - // are not modified by optimizers during SGD - at::Tensor* slot() const { - return parameter.get(); + IValue* slot() const { + return ivalue.get(); } - - private: - // the extra level of indirection allows Methods to safely store pointers - // to the slots where parameters are kept while also allow parameters - // to be reassigned - std::unique_ptr parameter; + const std::string name_; + const TypePtr type; + std::unique_ptr ivalue; }; struct Module { @@ -374,6 +377,7 @@ struct Module { Module() : modules("Module"), parameters("Parameter"), + attributes("Attributes"), methods("Method"), optimize(true) {} @@ -391,16 +395,33 @@ struct Module { return get_method("forward")(std::move(inputs)); } + void register_buffer(const std::string& name, autograd::Variable v) { + if (auto b = attributes.find(name)) { + AT_ASSERT(b->type->isSubtypeOf(TensorType::get())); + *b->slot() = v; + return; + } + attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v))); + } void register_parameter( const std::string& name, autograd::Variable v, bool is_buffer) { + if (is_buffer) { + register_buffer(name, std::move(v)); + return; + } if (auto p = parameters.find(name)) { *p->slot() = v; - p->is_buffer = is_buffer; return; } - parameters.insert(name, NamedParameter(name, std::move(v), is_buffer)); + parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v))); + } + void register_attribute( + const std::string& name, + const TypePtr type, + IValue ivalue) { + attributes.insert(name, NamedIValue(name, type, ivalue)); } void register_module( const std::string& name, @@ -411,7 +432,7 @@ struct Module { Method& create_method( const std::string& name, std::shared_ptr graph, - std::vector member_inputs) { + std::vector member_inputs) { AT_ASSERT(graph); std::unique_ptr method(new Method( this, @@ -436,7 +457,7 @@ struct Module { return *methods.insert(name, std::move(method)); } - at::Tensor* parameter_slot(const std::string& name) const { + IValue* parameter_slot(const std::string& name) const { return parameters[name].slot(); } @@ -445,7 +466,10 @@ struct Module { } autograd::Variable get_parameter(const std::string& name) const { - return autograd::as_variable_ref(*parameter_slot(name)); + return autograd::as_variable_ref(parameter_slot(name)->toTensor()); + } + autograd::Variable get_buffer(const std::string& name) const { + return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor()); } // each module owns its method. 
The reference returned here @@ -461,18 +485,32 @@ struct Module { const torch::OrderedDict& get_modules() const { return modules; } - const torch::OrderedDict& get_parameters() + const torch::OrderedDict& get_parameters() const { return parameters; } + const torch::OrderedDict& get_attributes() + const { + return attributes; + } const torch::OrderedDict>& get_methods() const { return methods; } - NamedParameter* find_parameter(const std::string& name) { + NamedIValue* find_parameter(const std::string& name) { return parameters.find(name); } + NamedIValue* find_attribute(const std::string& name) { + return attributes.find(name); + } + NamedIValue* find_buffer(const std::string& name) { + auto b = attributes.find(name); + if (b && b->type->isSubtypeOf(TensorType::get())) { + return b; + } + return nullptr; + } NamedModule* find_module(const std::string& name) { return modules.find(name); } @@ -493,7 +531,7 @@ struct Module { for (auto& submod : get_modules()) { submod->module->train(on); } - register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true); + register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong)); } /// Calls train(false) to enable "eval" mode. /// Do not override this method, override `train()` instead. @@ -502,8 +540,8 @@ struct Module { } /// True if the module is in training mode. bool is_training() { - if (auto p = find_parameter("training")) { - return p->slot()->item() == 1; + if (auto p = find_buffer("training")) { + return p->slot()->toTensor().item() == 1; } // We are in training mode by default return true; @@ -563,18 +601,28 @@ struct Module { const ExtraFilesMap& extra_files = ExtraFilesMap()); void copy_into( - std::function(std::vector)> - module_lookup, + ModuleLookup module_lookup, // parameter_remap is needed when a parent module uses a parameter of a // submodule - std::unordered_map& parameter_remap, + std::unordered_map& parameter_remap, std::vector names = {}) const { auto curr = module_lookup(names); for (auto& kv : parameters) { curr->register_parameter( - kv.key(), *kv.value().slot(), kv.value().is_buffer); + kv.key(), + kv.value().slot()->toTensor(), + /*is_buffer=*/false); parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key()); } + for (auto& kv : attributes) { + if (!kv.value().type->isSubtypeOf(TensorType::get())) { + continue; + } + curr->register_buffer( + kv.key(), + kv.value().slot()->toTensor()); + parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot(); + } for (auto& kv : modules) { names.push_back(kv.key()); // Submodules must be translated first, otherwise parameter_remap entries @@ -583,11 +631,12 @@ struct Module { names.pop_back(); } for (auto& kv : methods) { - std::vector params; - for (auto& p : kv.value()->params()) { - params.push_back(parameter_remap.at(p)); + std::vector initial_ivalues; + for (auto& p : kv.value()->initial_ivalues()) { + initial_ivalues.push_back(parameter_remap.at(p)); } - curr->create_method(kv.key(), kv.value()->graph()->copy(), params); + curr->create_method( + kv.key(), kv.value()->graph()->copy(), initial_ivalues); } } @@ -597,12 +646,13 @@ struct Module { const c10::optional& dtype, bool non_blocking); - // invariant: to ensure member_inputs of Methods stay valid, + // invariant: to ensure initial_ivalues of Methods stay valid, // it is only legal to _add_ new modules and parameters. 
- // removing them will allow member_inputs to point to invalid parameters + // removing them will allow initial_ivalues to point to invalid parameters // no such restriction exists for methods torch::OrderedDict modules; - torch::OrderedDict parameters; + torch::OrderedDict parameters; + torch::OrderedDict attributes; torch::OrderedDict> methods; bool optimize; }; diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 1983d9a..fd1cfec 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -899,9 +899,7 @@ class OrderedParameterDict(OrderedDictWrapper): super(OrderedParameterDict, self).__init__(module) def items(self): - return [(name, param) for name, param, is_buffer - in self.module._get_parameters() - if not is_buffer] + return [(name, param) for name, param in self.module._get_parameters()] def __setitem__(self, k, v): self.module._register_parameter(k, v, False) @@ -920,12 +918,11 @@ class OrderedBufferDict(OrderedDictWrapper): super(OrderedBufferDict, self).__init__(module) def items(self): - return [(name, param) for name, param, is_buffer - in self.module._get_parameters() - if is_buffer] + return [(name, param) for name, _, param in + self.module._get_attributes() if isinstance(param, torch.Tensor)] def __setitem__(self, k, v): - self.module._register_parameter(k, v, True) + self.module._register_buffer(k, v) def __contains__(self, k): return self.module._has_buffer(k) @@ -933,7 +930,7 @@ class OrderedBufferDict(OrderedDictWrapper): def __getitem__(self, k): if k not in self: raise KeyError(k) - return self.module._get_parameter(k) + return self.module._get_buffer(k) # base types that can be constants # in addition, tuples and lists of these base types are also considered constants @@ -1161,8 +1158,12 @@ if _enabled: if attr == 'training': if self._has_buffer('training'): self.__dict__['training'] = value - self._get_parameter('training').fill_(int(value)) + self._get_buffer('training').fill_(int(value)) return + if isinstance(value, Attribute): + the_type = torch.jit.annotations.ann_to_type(value.type) + self._register_attribute(attr, the_type, value.value) + return return super(ScriptModule, self).__setattr__(attr, value) if hasattr(self, attr): @@ -1552,5 +1553,12 @@ def annotate(the_type, the_value): # noop in python return the_value + +class Attribute(object): + def __init__(self, value, the_type): + self.value = value + self.type = the_type + + if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index e4c2ae9..cc270a1 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -78,7 +78,7 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices): for method_name in module._method_names(): method = module._get_method(method_name) param_list = [] - for param in method.params(): + for param in method.initial_ivalues(): param_list.append(param_dict[param]) replica._copy_method(method_name, param_list, module) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0dea267..b7b89c2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -225,7 +225,7 @@ def _model_to_graph(model, args, f, verbose=False, training=False, graph = method.propagate_and_assign_input_and_output_shapes( args, example_outputs, False, propagate) # Erase number types to bring the graph to a pre-NumberType state - params = method.params() + params = method.initial_ivalues() except AttributeError: # TODO: just trace it raise 
RuntimeError('\'forward\' method must be a script method')
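
Usage sketch: below is a minimal, self-contained example of the attribute API this PR adds, adapted from the `test_module_attrs` case in `test/test_jit.py` above. It assumes a build that includes this change and that `Dict` is imported from `typing`; the module name and the dictionary contents are illustrative only.

```python
from typing import Dict

import torch


class M(torch.jit.ScriptModule):
    def __init__(self, table):
        super(M, self).__init__()
        # Wrapping a value in torch.jit.Attribute stores it on the
        # ScriptModule as a typed IValue (here a Dict[str, Tensor]),
        # alongside ordinary nn.Parameters.
        self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
        self.x = torch.nn.Parameter(torch.tensor([100.0]))

    @torch.jit.script_method
    def forward(self, key):
        # type: (str) -> Tensor
        return self.table[key] + self.x


m = M({char: torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
print(m("c"))  # tensor([103.]) == (1 + ord('c') - ord('a')) + 100
```

Per the follow-up list in the summary, attributes registered this way are not yet (de)serializable from C++, are not yet mutable, and are not yet supported by the string frontend.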