Add module attributes (#17309)
authorDavid Riazati <davidriazati@fb.com>
Thu, 7 Mar 2019 18:41:13 +0000 (10:41 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 18:44:10 +0000 (10:44 -0800)
Summary:
Similar to `nn.Parameter`s, this PR lets you store any `IValue` on a module as an attribute on a `ScriptModule` (only from the Python front-end currently). To mark something as an attribute, it should wrapped in `jit.Attribute(value, type)` (ex. `self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])`)

Followup Work:
* (de)serializing for use in C++
* change `self.training` to be a `bool` attribute instead of a buffer
* mutable attributes
* string frontend support
* documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17309

Differential Revision: D14354316

Pulled By: driazati

fbshipit-source-id: 67e08ab5229366b67fbc837e67b58831a4fb3318

15 files changed:
test/custom_operator/test_custom_ops.cpp
test/test_jit.py
torch/csrc/api/src/serialize/input-archive.cpp
torch/csrc/jit/export.cpp
torch/csrc/jit/import.cpp
torch/csrc/jit/import.h
torch/csrc/jit/import_method.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/jit/__init__.py
torch/nn/parallel/replicate.py
torch/onnx/utils.py

index 3f783e0..4d6f7bf 100644 (file)
@@ -15,7 +15,7 @@ void check_all_parameters(
     const torch::jit::script::Module& module,
     Predicate predicate) {
   for (const auto& parameter : module.get_parameters()) {
-    AT_ASSERT(predicate(*parameter->slot()));
+    AT_ASSERT(predicate(parameter->slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
     check_all_parameters(*child->module, predicate);
index 9879eeb..83846e8 100644 (file)
@@ -278,8 +278,10 @@ class JitTestCase(TestCase):
     def emitModuleHook(self, module):
         def copy_structure_and_params(m):
             c = torch.jit.ScriptModule()
-            for name, v, buffer in m._get_parameters():
-                c._register_parameter(name, v, buffer)
+            for name, v in m._get_parameters():
+                c._register_parameter(name, v, False)
+            for name, the_type, v in m._get_attributes():
+                c._register_attribute(name, the_type, v)
             for name, s in m._get_modules():
                 c._register_module(name, copy_structure_and_params(s))
             return c
@@ -2015,10 +2017,8 @@ class TestJit(JitTestCase):
                     torch.nn.BatchNorm2d(100),
                     torch.nn.BatchNorm2d(100, affine=False)]:
                 getattr(clazz, mode)()
-
                 input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
                     torch.randn(20, 100, 35, 45)
-
                 traced = torch.jit.trace(clazz, (input,))
                 imported = self.getExportImportCopy(traced)
                 x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
@@ -10570,6 +10570,24 @@ a")
         a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
         self.checkScript(fn, (a_dict, ('a', 'c')))
 
+    def test_module_attrs(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self, table):
+                super(M, self).__init__()
+                self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
+                self.x = torch.nn.Parameter(torch.tensor([100.0]))
+
+            @torch.jit.script_method
+            def forward(self, key):
+                # type: (str) -> Tensor
+                return self.table[key] + self.x
+
+        with self.disableModuleHook():
+            # TODO: re-enable module hook when Python printing of attributes is
+            # supported
+            m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
+            self.assertEqual(m("c"), torch.tensor([103]))
+
     def test_tensor_import_export(self):
         @torch.jit.script
         def foo(x):
index 4cd934a..5f5b308 100644 (file)
@@ -23,23 +23,30 @@ void InputArchive::read(
     const std::string& key,
     Tensor& tensor,
     bool is_buffer) {
-  auto* read_tensor = module_->find_parameter(key);
-  AT_CHECK(read_tensor != nullptr, "No such serialized tensor '", key, "'");
+  auto param = module_->find_parameter(key);
+  auto buffer = module_->find_buffer(key);
+  AT_CHECK(
+      param != nullptr || buffer != nullptr,
+      "No such serialized tensor '",
+      key,
+      "'");
   // clang-format off
+  auto read_param = is_buffer ? buffer : param;
+  auto read_tensor = read_param->slot()->toTensor();
   AT_CHECK(
-      read_tensor->is_buffer == is_buffer,
+      bool(buffer) == is_buffer,
       "Expected deserialized tensor for key '", key,
       "' to ", is_buffer ? "not " : "", "be a buffer, but it was not");
   // clang-format on
   if (tensor.defined()) {
     torch::NoGradGuard guard;
-    if (tensor.device() != read_tensor->slot()->device()) {
-      tensor.set_data(autograd::Variable(*read_tensor->slot()).data());
+    if (tensor.device() != read_tensor.device()) {
+      tensor.set_data(autograd::Variable(read_tensor).data());
     } else {
-      tensor.set_(*read_tensor->slot());
+      tensor.set_(read_tensor);
     }
   } else {
-    tensor = std::move(*read_tensor->slot());
+    tensor = std::move(read_tensor);
   }
 }
 
index 39417ba..9590646 100644 (file)
@@ -502,8 +502,9 @@ class ScriptModuleSerializer final {
       torch::ModuleDef* module_def);
 
   void convertParameter(
-      const script::NamedParameter& param,
-      torch::ParameterDef* param_def);
+      const script::NamedIValue& param,
+      torch::ParameterDef* param_def,
+      bool is_parameter);
 
   std::ofstream ofs_;
   caffe2::serialize::PyTorchStreamWriter writer_;
@@ -646,7 +647,13 @@ void ScriptModuleSerializer::convertModule(
   module_def->set_optimize(module.is_optimized());
   for (const auto& elem : module.get_parameters()) {
     torch::ParameterDef* param_def = module_def->add_parameters();
-    convertParameter(elem.value(), param_def);
+    convertParameter(elem.value(), param_def, /*is_buffer=*/false);
+  }
+  for (const auto& elem : module.get_attributes()) {
+    if (elem.value().type->isSubtypeOf(TensorType::get())) {
+      torch::ParameterDef* param_def = module_def->add_parameters();
+      convertParameter(elem.value(), param_def, /*is_buffer=*/true);
+    }
   }
 
   std::stringstream module_name;
@@ -675,11 +682,12 @@ void ScriptModuleSerializer::convertModule(
 }
 
 void ScriptModuleSerializer::convertParameter(
-    const script::NamedParameter& param,
-    torch::ParameterDef* param_def) {
-  param_def->set_name(param.name);
-  param_def->set_is_buffer(param.is_buffer);
-  param_def->set_tensor_id(addTensor(*param.slot()));
+    const script::NamedIValue& param,
+    torch::ParameterDef* param_def,
+    bool is_parameter) {
+  param_def->set_name(param.name_);
+  param_def->set_is_buffer(is_parameter);
+  param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
 }
 
 // Pretty printing for ONNX
index 1abcc8c..58ef1d1 100644 (file)
@@ -44,7 +44,7 @@ class ScriptModuleDeserializer final {
   ScriptModuleDeserializer(std::istream* is);
   explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
   void deserialize(
-      ModuleLookup module_lookup,
+      script::ModuleLookup module_lookup,
       c10::optional<at::Device> device,
       script::ExtraFilesMap& extra_files);
 
@@ -60,7 +60,7 @@ class ScriptModuleDeserializer final {
   caffe2::serialize::PyTorchStreamReader reader_;
   // this is a hack to make sure the script module created in C++ is the
   // same as created in Python
-  ModuleLookup moduleLookup_;
+  script::ModuleLookup moduleLookup_;
   c10::optional<at::Device> device_;
   std::vector<std::string> moduleStack_;
 
@@ -80,7 +80,7 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(
     : reader_(std::move(rai)) {}
 
 void ScriptModuleDeserializer::deserialize(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     c10::optional<at::Device> device,
     script::ExtraFilesMap& extra_files) {
   torch::ModelDef model_def;
@@ -222,7 +222,11 @@ void ScriptModuleDeserializer::convertModule(
   for (int i = 0; i < module_def.parameters_size(); ++i) {
     const torch::ParameterDef& param_def = module_def.parameters(i);
     at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
-    module->register_parameter(param_def.name(), tensor, param_def.is_buffer());
+    if (param_def.is_buffer()) {
+      module->register_buffer(param_def.name(), tensor);
+    } else {
+      module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false);
+    }
   }
   if (module_def.has_torchscript_arena()) {
     at::DataPtr data;
@@ -237,7 +241,7 @@ void ScriptModuleDeserializer::convertModule(
 } // namespace
 
 void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     std::istream& in,
     c10::optional<at::Device> device,
     script::ExtraFilesMap& extra_files) {
@@ -246,7 +250,7 @@ void import_ir_module(
 }
 
 void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     const std::string& filename,
     c10::optional<at::Device> device,
     script::ExtraFilesMap& extra_files) {
@@ -255,7 +259,7 @@ void import_ir_module(
 }
 
 void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     std::unique_ptr<ReadAdapterInterface> rai,
     c10::optional<at::Device> device,
     script::ExtraFilesMap& extra_files) {
index 0cac71d..a229c8c 100644 (file)
@@ -14,25 +14,22 @@ class ReadAdapterInterface;
 namespace torch {
 namespace jit {
 
-using ModuleLookup = std::function<std::shared_ptr<script::Module>(
-    const std::vector<std::string>&)>;
-
 static script::ExtraFilesMap default_extra_files;
 
 TORCH_API void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     const std::string& filename,
     c10::optional<c10::Device> device = c10::nullopt,
     script::ExtraFilesMap& extra_files = default_extra_files);
 
 TORCH_API void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     std::istream& in,
     c10::optional<c10::Device> device = c10::nullopt,
     script::ExtraFilesMap& extra_files = default_extra_files);
 
 TORCH_API void import_ir_module(
-    ModuleLookup module_lookup,
+    script::ModuleLookup module_lookup,
     std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
     c10::optional<c10::Device> device = c10::nullopt,
     script::ExtraFilesMap& extra_files = default_extra_files);
index 029988b..763f7a7 100644 (file)
@@ -19,7 +19,10 @@ struct ModuleAccessorValue : public script::SugaredValue {
       const std::string& field) override {
     if (script::NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleAccessorValue>(v->module);
-    } else if (script::NamedParameter* v = module->find_parameter(field)) {
+    } else if (script::NamedIValue* v = module->find_parameter(field)) {
+      return std::make_shared<script::SimpleValue>(
+          m.get_or_add_parameter(v->slot()));
+    } else if (script::NamedIValue* v = module->find_buffer(field)) {
       return std::make_shared<script::SimpleValue>(
           m.get_or_add_parameter(v->slot()));
     } else if (script::Method* m = module->find_method(field)) {
index 22061fd..69d1cd0 100644 (file)
@@ -130,10 +130,14 @@ struct QualifiedName : c10::intrusive_ptr_target {
 void createTensorToParameterNameMap(
     const script::Module& module,
     const QualifiedNamePtr& prefix,
-    std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
+    std::unordered_map<IValue*, QualifiedNamePtr>& result) {
   for (const auto& elem : module.get_parameters()) {
-    const script::NamedParameter& param = elem.value();
-    result[param.slot()] = QualifiedName::create(prefix, param.name);
+    const script::NamedIValue& param = elem.value();
+    result[param.slot()] = QualifiedName::create(prefix, param.name_);
+  }
+  for (const auto& elem : module.get_attributes()) {
+    const script::NamedIValue& param = elem.value();
+    result[param.slot()] = QualifiedName::create(prefix, param.name_);
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
@@ -1038,27 +1042,27 @@ struct PythonPrintPass {
     }
   }
   void printMethod(script::Method& method) {
-    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+    std::unordered_map<IValue*, QualifiedNamePtr> parameter_names;
     createTensorToParameterNameMap(
         method.owner(), QualifiedName::create("self"), parameter_names);
     printMethod(method, parameter_names);
   }
   void printMethod(
       script::Method& method,
-      const std::unordered_map<at::Tensor*, QualifiedNamePtr>&
-          parameter_names) {
-    std::vector<std::string> param_names = fmap(
-        method.params(),
-        [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
+      const std::unordered_map<IValue*, QualifiedNamePtr>&
+          extra_ivalue_names) {
+    std::vector<std::string> ivalue_names = fmap(
+        method.initial_ivalues(),
+        [&](IValue* slot) { return extra_ivalue_names.at(slot)->str(); });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
     auto defaults = fmap(
         method.getSchema().arguments(),
         [](const Argument& arg) { return arg.default_value(); });
-    printFunction(graph, name, defaults, param_names);
+    printFunction(graph, name, defaults, ivalue_names);
   }
   void printModule(script::Module& module) {
-    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+    std::unordered_map<IValue*, QualifiedNamePtr> parameter_names;
     createTensorToParameterNameMap(
         module, QualifiedName::create("self"), parameter_names);
     for (auto& method : module.get_methods()) {
index f5a08cb..5ff0447 100644 (file)
@@ -323,7 +323,6 @@ inline py::object toPyObject(IValue&& ivalue) {
       py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second});
     }
     return std::move(py_dict);
-
   } else {
     AT_ERROR("Missing cases in 'toPyObject'! File a bug report.");
   }
index 9ddd5db..5425219 100644 (file)
@@ -258,9 +258,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
     const auto& param_list = module_->get_parameters().items();
     for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
       auto& param = *it;
-      if (!param->is_buffer) {
-        params.push_back(caller.get_or_add_parameter(param->slot()));
-      }
+      params.push_back(caller.get_or_add_parameter(param->slot()));
     }
     auto list = caller.graph()->createList(TensorType::get(), params);
     caller.graph()->insertNode(list);
@@ -271,6 +269,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
   std::shared_ptr<Module> module_;
 };
 
+
 // defines how modules/methods behave inside the script subset.
 // for now this does not have any interaction with python.
 // in the future, we will add the ability to resolve `self.foo` to python
@@ -294,14 +293,14 @@ struct ModuleValue : public SugaredValue {
     // it adds a buffer 'training' to the model if one doesn't exist
     // and then loads that parameter, casting it to bool
     if (field == "training") {
-      NamedParameter* v = module->find_parameter(field);
+      NamedIValue* v = module->find_buffer(field);
       if (!v) {
         py::object py_module = py::cast(module);
         bool training = py::cast<bool>(py::getattr(py_module, "training"));
         auto t =
             autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
-        module->register_parameter("training", std::move(t), true);
-        v = module->find_parameter(field);
+        module->register_buffer("training", std::move(t));
+        v = module->find_buffer(field);
       }
       Value* the_tensor = m.get_or_add_parameter(v->slot());
       Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
@@ -312,9 +311,13 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<ModuleValue>(v->module);
     } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *v);
-    } else if (NamedParameter* v = module->find_parameter(field)) {
+    } else if (NamedIValue* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+    } else if (NamedIValue* v = module->find_attribute(field)) {
+      return std::make_shared<SimpleValue>(
+          m.get_or_add_attribute(v->type, v->slot()));
     }
+
     // This can also be a call to a non-script module, or a plain
     // python method. If so return this as a python value.
     py::object py_module = py::cast(module);
@@ -592,11 +595,16 @@ py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
 }
 
 static void gatherParametersAndBuffers(
-    std::vector<at::Tensor*>& values,
+    std::vector<IValue*>& values,
     const Module& m) {
   for (auto& param : m.get_parameters()) {
     values.push_back(param->slot());
   }
+  for (auto& param : m.get_attributes()) {
+    if (param->type->isSubtypeOf(TensorType::get())) {
+      values.push_back(param->slot());
+    }
+  }
   for (const auto& sub : m.get_modules()) {
     gatherParametersAndBuffers(values, *sub->module);
   }
@@ -732,9 +740,16 @@ void initJitScriptBindings(PyObject* module) {
           },
           py::return_value_policy::reference_internal)
       .def("_register_parameter", &Module::register_parameter)
+      .def(
+          "_register_attribute",
+          [](Module& self, std::string name, TypePtr type, py::object value) {
+            self.register_attribute(name, type, toIValue(value, type));
+          })
       .def("_register_module", &Module::register_module)
+      .def("_register_buffer", &Module::register_buffer)
       .def("_set_parameter", &Module::set_parameter)
       .def("_get_parameter", &Module::get_parameter)
+      .def("_get_buffer", &Module::get_buffer)
       .def("_get_module", &Module::get_module)
       .def(
           "_get_modules",
@@ -754,27 +769,38 @@ void initJitScriptBindings(PyObject* module) {
             py::tuple result(parameters.size());
             for (size_t i = 0; i < parameters.size(); ++i) {
               auto& p = parameters[i];
+              py::tuple r(2);
+              result[i] = std::make_tuple(
+                  p.key(),
+                  autograd::as_variable_ref(p->slot()->toTensor()));
+            }
+            return result;
+          })
+      .def(
+          "_get_attributes",
+          [](Module& self) -> py::tuple {
+            auto& attributes = self.get_attributes();
+            py::tuple result(attributes.size());
+            for (size_t i = 0; i < attributes.size(); ++i) {
+              auto& buffer = attributes[i];
               py::tuple r(3);
+              IValue v = *buffer->slot();
               result[i] = std::make_tuple(
-                  p.key(), autograd::as_variable_ref(*p->slot()), p->is_buffer);
+                  buffer.key(),
+                  buffer->type,
+                  toPyObject(std::move(v)));
             }
             return result;
           })
       .def(
           "_has_parameter",
-          [](Module& self, const std::string& name) {
-            if (auto r = self.find_parameter(name)) {
-              return !r->is_buffer;
-            }
-            return false;
+          [](Module& self, const std::string& name) -> bool {
+            return self.find_parameter(name);
           })
       .def(
           "_has_buffer",
-          [](Module& self, const std::string& name) {
-            if (auto r = self.find_parameter(name)) {
-              return r->is_buffer;
-            }
-            return false;
+          [](Module& self, const std::string& name) -> bool {
+            return self.find_buffer(name);
           })
       .def(
           "_has_module",
@@ -812,10 +838,10 @@ void initJitScriptBindings(PyObject* module) {
              bool force_outplace) {
             // prereq: Module's buffers and parameters are unique
             // this was ensured in python before calling this function
-            std::vector<at::Tensor*> parameters;
+            std::vector<IValue*> parameters;
             gatherParametersAndBuffers(parameters, *self);
             Stack inputs = toStack(input_tuple);
-            for (at::Tensor* param : parameters) {
+            for (IValue* param : parameters) {
               inputs.emplace_back(*param);
             }
             auto graph = tracer::createGraphByTracing(
@@ -904,16 +930,20 @@ void initJitScriptBindings(PyObject* module) {
              std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
                  params,
              std::shared_ptr<Module> orig) {
-            std::vector<at::Tensor*> member_inputs;
+            std::vector<IValue*> member_inputs;
             for (auto& p : params) {
-              NamedParameter* np =
+              NamedIValue* np =
                   std::get<0>(p)->find_parameter(std::get<1>(p));
+              if (np == nullptr) {
+                np = std::get<0>(p)->find_buffer(std::get<1>(p));
+              }
               AT_ASSERT(np != nullptr);
               member_inputs.push_back(np->slot());
             }
 
             Method* orig_method = orig->find_method(name);
-            m->create_method(name, orig_method->graph()->copy(), member_inputs);
+            m->create_method(
+                name, orig_method->graph()->copy(), member_inputs);
           });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
@@ -931,7 +961,7 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
-      .def("params", &Method::params)
+      .def("initial_ivalues", &Method::initial_ivalues)
       .def(
           "graph_for",
           [](py::args args, py::kwargs kwargs) {
index a735481..fe9196e 100644 (file)
@@ -48,10 +48,11 @@ Value* try_emit_call_to(
 
   // parameters to callee method (which become parameters to _this_ method
   // if they were not already)
-  for (at::Tensor* member : callee.params()) {
+  for (auto member : callee.initial_ivalues()) {
     if (!caller) {
       throw ErrorReport(loc)
-          << " attempting to call a method with parameters from a raw graph. File a bug report";
+          << " attempting to call a method with parameters/attributes"
+             " from a raw graph. File a bug report";
     }
     matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
   }
@@ -123,7 +124,7 @@ void Module::to_impl(
   // Then convert every of our parameters.
   for (auto& parameter : parameters) {
     // Need to access the `at::Tensor` as a `Variable` here.
-    autograd::Variable variable = *parameter->slot();
+    autograd::Variable variable = parameter.value().slot()->toTensor();
     at::Tensor data = variable.data();
     // Use the data's original device or dtype if not supplied here.
     auto new_data = data.to(
index 592b096..a993687 100644 (file)
@@ -49,30 +49,33 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 
 struct Module;
 
+using ModuleLookup = std::function<std::shared_ptr<Module>(
+    const std::vector<std::string>&)>;
+
 struct Method {
   Method(
       Module* owner,
       std::string name,
       bool optimize,
       std::shared_ptr<Graph> graph,
-      std::vector<at::Tensor*> initial_members,
+      std::vector<IValue*> initial_members,
       std::function<void(Method&)> method_creator)
       : owner_(owner),
         name_(std::move(name)),
         graph_(std::move(graph)),
         optimize(optimize),
-        member_inputs(std::move(initial_members)),
+        initial_ivalues_(std::move(initial_members)),
         method_creator(std::move(method_creator)) {
-    AT_ASSERT(graph_->inputs().size() >= member_inputs.size());
-    int i = graph_->inputs().size() - member_inputs.size();
-    for (at::Tensor* member : member_inputs) {
-      member_input_index[member] = i++;
+    AT_ASSERT(graph_->inputs().size() >= initial_ivalues_.size());
+    int i = graph_->inputs().size() - initial_ivalues_.size();
+    for (auto member : initial_ivalues_) {
+      initial_ivalue_index[member] = i++;
     }
   }
 
   void run(Stack& stack) {
-    for (at::Tensor* tp : member_inputs) {
-      stack.emplace_back(*tp);
+    for (auto input : initial_ivalues_) {
+      push(stack, *input);
     }
     get_executor().run(stack);
   }
@@ -88,7 +91,7 @@ struct Method {
   }
 
   std::shared_ptr<Graph> graph_for(Stack inputs) {
-    for (at::Tensor* tp : member_inputs) {
+    for (auto tp : initial_ivalues_) {
       inputs.emplace_back(*tp);
     }
     return get_executor().graphFor(inputs);
@@ -115,17 +118,21 @@ struct Method {
   TORCH_API void ensure_defined();
 
   size_t num_inputs() const {
-    return graph()->inputs().size() - member_inputs.size();
+    return graph()->inputs().size() - initial_ivalues_.size();
+  }
+  TORCH_API Value* get_or_add_parameter(IValue* slot) {
+    AT_ASSERT(slot->isTensor());
+    return get_or_add_attribute(TensorType::get(), slot);
   }
-  TORCH_API Value* get_or_add_parameter(at::Tensor* slot) {
-    auto it = member_input_index.find(slot);
-    if (it != member_input_index.end()) {
+
+  TORCH_API Value* get_or_add_attribute(TypePtr type, IValue* slot) {
+    auto it = initial_ivalue_index.find(slot);
+    if (it != initial_ivalue_index.end()) {
       return graph()->inputs().at(it->second);
     }
-    // add it as a new parameter
-    member_inputs.push_back(slot);
-    member_input_index[slot] = graph()->inputs().size();
-    return graph()->addInput();
+    initial_ivalues_.push_back(slot);
+    initial_ivalue_index[slot] = graph()->inputs().size();
+    return graph()->addInput()->setType(type);
   }
 
   std::shared_ptr<Graph> propagate_shapes(
@@ -133,11 +140,11 @@ struct Method {
       bool with_grad = false) {
     auto retval = graph_->copy();
     Stack stack;
-    stack.reserve(inputs.size() + member_inputs.size());
+    stack.reserve(inputs.size() + initial_ivalues_.size());
     for (at::Tensor& i : inputs) {
       stack.emplace_back(std::move(i));
     }
-    for (at::Tensor* inp : member_inputs) {
+    for (IValue* inp : initial_ivalues_) {
       stack.push_back(*inp);
     }
     const auto size = stack.size();
@@ -152,8 +159,10 @@ struct Method {
       bool with_grad = false,
       bool propagate = true) {
     auto retval = graph_->copy();
-    for (auto inp : member_inputs) {
-      inputs.push_back(*inp);
+    for (auto inp : initial_ivalues_) {
+      if (inp->isTensor()) {
+        inputs.push_back(inp->toTensor());
+      }
     }
     if (propagate) {
       setInputTypes(
@@ -186,8 +195,8 @@ struct Method {
     return retval;
   }
 
-  std::vector<at::Tensor*> params() const {
-    return member_inputs;
+  const std::vector<IValue*>& initial_ivalues() const {
+    return initial_ivalues_;
   }
 
   Method& setSchema(FunctionSchema schema_) {
@@ -310,16 +319,16 @@ struct Method {
   bool optimize;
 
   GraphExecutor executor; // for execution
-  // member_inputs are a list of additional arguments appended to graph that are
-  // inputs that come from the members of the Module or its submodules.
+  // initial_ivalues are a list of additional arguments appended to graph
+  // that are inputs that come from the members of the Module or its submodules.
   // each is a pointer to a slot in the module that owns this parameter
   // parameters and submodules can only be _added_ to script Modules to ensure
   // these pointers always stay valid
-  std::vector<at::Tensor*> member_inputs;
+  std::vector<IValue*> initial_ivalues_;
 
-  // map from a at::Tensor* in member_inputs to the offset it appears at
+  // map from a IValue* in initial_ivalues to the offset it appears at
   // in graph. used to accelerate get_or_add_parameter
-  std::unordered_map<at::Tensor*, size_t> member_input_index;
+  std::unordered_map<IValue*, size_t> initial_ivalue_index;
 
   // TODO: support that case where we allow _writes_ to parameters from
   // compiled functions.
@@ -349,24 +358,18 @@ struct NamedModule {
   std::shared_ptr<Module> module;
 };
 
-struct NamedParameter {
-  NamedParameter(std::string name, at::Tensor tensor, bool is_buffer)
-      : name(std::move(name)),
-        is_buffer(is_buffer),
-        parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
+struct NamedIValue {
+  NamedIValue(std::string name, TypePtr type, IValue ivalue)
+      : name_(name),
+        type(type),
+        ivalue(torch::make_unique<IValue>(std::move(ivalue))) {}
 
-  const std::string name;
-  bool is_buffer; // buffers are part of the module state but
-                  // are not modified by optimizers during SGD
-  at::Tensor* slot() const {
-    return parameter.get();
+  IValue* slot() const {
+    return ivalue.get();
   }
-
- private:
-  // the extra level of indirection allows Methods to safely store pointers
-  // to the slots where parameters are kept while also allow parameters
-  // to be reassigned
-  std::unique_ptr<at::Tensor> parameter;
+  const std::string name_;
+  const TypePtr type;
+  std::unique_ptr<IValue> ivalue;
 };
 
 struct Module {
@@ -374,6 +377,7 @@ struct Module {
   Module()
       : modules("Module"),
         parameters("Parameter"),
+        attributes("Attributes"),
         methods("Method"),
         optimize(true) {}
 
@@ -391,16 +395,33 @@ struct Module {
     return get_method("forward")(std::move(inputs));
   }
 
+  void register_buffer(const std::string& name, autograd::Variable v) {
+    if (auto b = attributes.find(name)) {
+      AT_ASSERT(b->type->isSubtypeOf(TensorType::get()));
+      *b->slot() = v;
+      return;
+    }
+    attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+  }
   void register_parameter(
       const std::string& name,
       autograd::Variable v,
       bool is_buffer) {
+    if (is_buffer) {
+      register_buffer(name, std::move(v));
+      return;
+    }
     if (auto p = parameters.find(name)) {
       *p->slot() = v;
-      p->is_buffer = is_buffer;
       return;
     }
-    parameters.insert(name, NamedParameter(name, std::move(v), is_buffer));
+    parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+  }
+  void register_attribute(
+      const std::string& name,
+      const TypePtr type,
+      IValue ivalue) {
+    attributes.insert(name, NamedIValue(name, type, ivalue));
   }
   void register_module(
       const std::string& name,
@@ -411,7 +432,7 @@ struct Module {
   Method& create_method(
       const std::string& name,
       std::shared_ptr<Graph> graph,
-      std::vector<at::Tensor*> member_inputs) {
+      std::vector<IValue*> member_inputs) {
     AT_ASSERT(graph);
     std::unique_ptr<Method> method(new Method(
         this,
@@ -436,7 +457,7 @@ struct Module {
     return *methods.insert(name, std::move(method));
   }
 
-  at::Tensor* parameter_slot(const std::string& name) const {
+  IValue* parameter_slot(const std::string& name) const {
     return parameters[name].slot();
   }
 
@@ -445,7 +466,10 @@ struct Module {
   }
 
   autograd::Variable get_parameter(const std::string& name) const {
-    return autograd::as_variable_ref(*parameter_slot(name));
+    return autograd::as_variable_ref(parameter_slot(name)->toTensor());
+  }
+  autograd::Variable get_buffer(const std::string& name) const {
+    return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
   }
 
   // each module owns its method. The reference returned here
@@ -461,18 +485,32 @@ struct Module {
   const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
     return modules;
   }
-  const torch::OrderedDict<std::string, NamedParameter>& get_parameters()
+  const torch::OrderedDict<std::string, NamedIValue>& get_parameters()
       const {
     return parameters;
   }
+  const torch::OrderedDict<std::string, NamedIValue>& get_attributes()
+      const {
+    return attributes;
+  }
   const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
       const {
     return methods;
   }
 
-  NamedParameter* find_parameter(const std::string& name) {
+  NamedIValue* find_parameter(const std::string& name) {
     return parameters.find(name);
   }
+  NamedIValue* find_attribute(const std::string& name) {
+    return attributes.find(name);
+  }
+  NamedIValue* find_buffer(const std::string& name) {
+    auto b = attributes.find(name);
+    if (b && b->type->isSubtypeOf(TensorType::get())) {
+      return b;
+    }
+    return nullptr;
+  }
   NamedModule* find_module(const std::string& name) {
     return modules.find(name);
   }
@@ -493,7 +531,7 @@ struct Module {
     for (auto& submod : get_modules()) {
       submod->module->train(on);
     }
-    register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true);
+    register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
   /// Calls train(false) to enable "eval" mode.
   /// Do not override this method, override `train()` instead.
@@ -502,8 +540,8 @@ struct Module {
   }
   /// True if the module is in training mode.
   bool is_training() {
-    if (auto p = find_parameter("training")) {
-      return p->slot()->item<int64_t>() == 1;
+    if (auto p = find_buffer("training")) {
+      return p->slot()->toTensor().item<int64_t>() == 1;
     }
     // We are in training mode by default
     return true;
@@ -563,18 +601,28 @@ struct Module {
       const ExtraFilesMap& extra_files = ExtraFilesMap());
 
   void copy_into(
-      std::function<std::shared_ptr<Module>(std::vector<std::string>)>
-          module_lookup,
+      ModuleLookup module_lookup,
       // parameter_remap is needed when a parent module uses a parameter of a
       // submodule
-      std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
+      std::unordered_map<IValue*, IValue*>& parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
     for (auto& kv : parameters) {
       curr->register_parameter(
-          kv.key(), *kv.value().slot(), kv.value().is_buffer);
+          kv.key(),
+          kv.value().slot()->toTensor(),
+          /*is_buffer=*/false);
       parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
     }
+    for (auto& kv : attributes) {
+      if (!kv.value().type->isSubtypeOf(TensorType::get())) {
+        continue;
+      }
+      curr->register_buffer(
+          kv.key(),
+          kv.value().slot()->toTensor());
+      parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
+    }
     for (auto& kv : modules) {
       names.push_back(kv.key());
       // Submodules must be translated first, otherwise parameter_remap entries
@@ -583,11 +631,12 @@ struct Module {
       names.pop_back();
     }
     for (auto& kv : methods) {
-      std::vector<at::Tensor*> params;
-      for (auto& p : kv.value()->params()) {
-        params.push_back(parameter_remap.at(p));
+      std::vector<IValue*> initial_ivalues;
+      for (auto& p : kv.value()->initial_ivalues()) {
+        initial_ivalues.push_back(parameter_remap.at(p));
       }
-      curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
+      curr->create_method(
+          kv.key(), kv.value()->graph()->copy(), initial_ivalues);
     }
   }
 
@@ -597,12 +646,13 @@ struct Module {
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
 
-  // invariant: to ensure member_inputs of Methods stay valid,
+  // invariant: to ensure initial_ivalues of Methods stay valid,
   // it is only legal to _add_ new modules and parameters.
-  // removing them will allow member_inputs to point to invalid parameters
+  // removing them will allow initial_ivalues to point to invalid parameters
   // no such restriction exists for methods
   torch::OrderedDict<std::string, NamedModule> modules;
-  torch::OrderedDict<std::string, NamedParameter> parameters;
+  torch::OrderedDict<std::string, NamedIValue> parameters;
+  torch::OrderedDict<std::string, NamedIValue> attributes;
   torch::OrderedDict<std::string, std::unique_ptr<Method>> methods;
   bool optimize;
 };
index 1983d9a..fd1cfec 100644 (file)
@@ -899,9 +899,7 @@ class OrderedParameterDict(OrderedDictWrapper):
         super(OrderedParameterDict, self).__init__(module)
 
     def items(self):
-        return [(name, param) for name, param, is_buffer
-                in self.module._get_parameters()
-                if not is_buffer]
+        return [(name, param) for name, param in self.module._get_parameters()]
 
     def __setitem__(self, k, v):
         self.module._register_parameter(k, v, False)
@@ -920,12 +918,11 @@ class OrderedBufferDict(OrderedDictWrapper):
         super(OrderedBufferDict, self).__init__(module)
 
     def items(self):
-        return [(name, param) for name, param, is_buffer
-                in self.module._get_parameters()
-                if is_buffer]
+        return [(name, param) for name, _, param in
+                self.module._get_attributes() if isinstance(param, torch.Tensor)]
 
     def __setitem__(self, k, v):
-        self.module._register_parameter(k, v, True)
+        self.module._register_buffer(k, v)
 
     def __contains__(self, k):
         return self.module._has_buffer(k)
@@ -933,7 +930,7 @@ class OrderedBufferDict(OrderedDictWrapper):
     def __getitem__(self, k):
         if k not in self:
             raise KeyError(k)
-        return self.module._get_parameter(k)
+        return self.module._get_buffer(k)
 
 # base types that can be constants
 # in addition, tuples and lists of these base types are also considered constants
@@ -1161,8 +1158,12 @@ if _enabled:
                 if attr == 'training':
                     if self._has_buffer('training'):
                         self.__dict__['training'] = value
-                        self._get_parameter('training').fill_(int(value))
+                        self._get_buffer('training').fill_(int(value))
                         return
+                if isinstance(value, Attribute):
+                    the_type = torch.jit.annotations.ann_to_type(value.type)
+                    self._register_attribute(attr, the_type, value.value)
+                    return
                 return super(ScriptModule, self).__setattr__(attr, value)
 
             if hasattr(self, attr):
@@ -1552,5 +1553,12 @@ def annotate(the_type, the_value):
     # noop in python
     return the_value
 
+
+class Attribute(object):
+    def __init__(self, value, the_type):
+        self.value = value
+        self.type = the_type
+
+
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")
index e4c2ae9..cc270a1 100644 (file)
@@ -78,7 +78,7 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices):
         for method_name in module._method_names():
             method = module._get_method(method_name)
             param_list = []
-            for param in method.params():
+            for param in method.initial_ivalues():
                 param_list.append(param_dict[param])
             replica._copy_method(method_name, param_list, module)
 
index 0dea267..b7b89c2 100644 (file)
@@ -225,7 +225,7 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
             graph = method.propagate_and_assign_input_and_output_shapes(
                 args, example_outputs, False, propagate)
             # Erase number types to bring the graph to a pre-NumberType state
-            params = method.params()
+            params = method.initial_ivalues()
         except AttributeError:
             # TODO: just trace it
             raise RuntimeError('\'forward\' method must be a script method')