Unify namespace of script::Module (#18378)
authorZachary DeVito <zdevito@fb.com>
Wed, 3 Apr 2019 22:58:08 +0000 (15:58 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 23:04:17 +0000 (16:04 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18378
ghimport-source-id: 55c29bb436a2153d29ff2f4488d99d8863c187b1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18379 Enforce single parent for script submodules
* **#18378 Unify namespace of script::Module**
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

This removes individual OrderedDicts in favor of a single unified
namespace for all things in a script::Module. This removes a whole
class of bugs where both a method and an parameter could get the
same name, for instance.

Since we no longer have to expose OrderedDict::Item objects, a lot of
downstream code can be simplified.

We no longer now double-store names (both in the key of the dictionary,
and in the object itself).

Differential Revision: D14603723

fbshipit-source-id: b5f7551b3074679623edd6ea70269830353b4d4c

test/custom_operator/test_custom_ops.cpp
test/test_jit.py
torch/csrc/jit/export.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/builtin_functions.cpp
torch/csrc/jit/script/class_type.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/symbolic_script.cpp

index 4d6f7bf..a0387f4 100644 (file)
@@ -15,10 +15,10 @@ void check_all_parameters(
     const torch::jit::script::Module& module,
     Predicate predicate) {
   for (const auto& parameter : module.get_parameters()) {
-    AT_ASSERT(predicate(parameter->slot()->toTensor()));
+    AT_ASSERT(predicate(parameter.slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(*child->module, predicate);
+    check_all_parameters(*child.module, predicate);
   }
 }
 } // namespace helpers
index 61331ca..70d01df 100644 (file)
@@ -8330,7 +8330,7 @@ a")
             ''')
 
     def test_duplicate(self):
-        with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
+        with self.assertRaisesRegex(RuntimeError, 'method \'test\' already defined'):
             cu = torch.jit.CompilationUnit('''
             def test():
                 return 1
index 7e2a63f..84ac515 100644 (file)
@@ -454,8 +454,7 @@ void GraphEncoder::EncodeTensor(
   } else {
     AT_ASSERT(t.is_contiguous());
     tensor_proto->set_raw_data(std::string(
-        static_cast<char*>(t.data_ptr()),
-        t.element_size() * t.numel()));
+        static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
   }
 }
 
@@ -665,8 +664,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
 
   tensor_proto->set_requires_grad(tensor.requires_grad());
 
-  uint64_t record_size =
-      tensor.element_size() * tensor.storage().size();
+  uint64_t record_size = tensor.element_size() * tensor.storage().size();
   auto* key = tensor.storage().unsafeGetStorageImpl();
 
   auto storage_it = storageMap.find(key);
@@ -686,8 +684,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
                                /* stride = */ {1})
                            .cpu();
       AT_ASSERT(
-          storage_tensor.element_size() *
-              storage_tensor.storage().size() ==
+          storage_tensor.element_size() * storage_tensor.storage().size() ==
           record_size);
     }
     std::string name = "tensors/" + std::to_string(tensor_id);
@@ -733,11 +730,10 @@ void ScriptModuleSerializer::convertModule(
   module_def->set_optimize(module.is_optimized());
   for (const auto& elem : module.get_parameters()) {
     torch::ParameterDef* param_def = module_def->add_parameters();
-    convertParameter(elem.value(), param_def, /*is_buffer=*/false);
+    convertParameter(elem, param_def, /*is_buffer=*/false);
   }
 
-  for (const auto& item : module.get_attributes()) {
-    auto& attribute = item.value();
+  for (const auto& attribute : module.get_attributes()) {
     // Add attribute to ModuleDef
     torch::AttributeDef* attribute_def = module_def->add_attributes();
     attribute_def->set_name(attribute.name());
@@ -773,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
 
   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
-    convertModule(*elem->module, module_name.str(), elem.key(), sub_def);
+    convertModule(*elem.module, module_name.str(), elem.name, sub_def);
   }
 }
 
index d515014..0844a96 100644 (file)
@@ -1,9 +1,9 @@
+#include <torch/csrc/jit/passes/python_print.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/attributes.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/ir_views.h>
-#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/module.h>
@@ -131,17 +131,15 @@ void createTensorToParameterNameMap(
     const script::Module& module,
     const QualifiedNamePtr& prefix,
     std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
-  for (const auto& elem : module.get_parameters()) {
-    const script::NamedIValue& param = elem.value();
+  for (const auto& param : module.get_parameters()) {
     result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
-  for (const auto& elem : module.get_attributes()) {
-    const script::NamedIValue& param = elem.value();
+  for (const auto& param : module.get_attributes()) {
     result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
-        *elem->module, QualifiedName::create(prefix, elem.key()), result);
+        *elem.module, QualifiedName::create(prefix, elem.name), result);
   }
 }
 
@@ -1122,10 +1120,12 @@ struct PythonPrintPass {
   void printMethod(
       script::Method& method,
       bool is_class,
-      const std::unordered_map<script::Slot, QualifiedNamePtr>& extra_ivalue_names) {
-    std::vector<std::string> ivalue_names = fmap(
-        method.initial_ivalues(),
-        [&](const script::Slot& slot) { return extra_ivalue_names.at(slot)->str(); });
+      const std::unordered_map<script::Slot, QualifiedNamePtr>&
+          extra_ivalue_names) {
+    std::vector<std::string> ivalue_names =
+        fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
+          return extra_ivalue_names.at(slot)->str();
+        });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
     auto defaults = fmap(
@@ -1138,14 +1138,14 @@ struct PythonPrintPass {
     createTensorToParameterNameMap(
         module, QualifiedName::create("self"), extra_ivalue_names);
     for (auto& method : module.get_methods()) {
-      const std::string& name = method.value()->name();
+      const std::string& name = method->name();
       // we skip __forked_functions because they actually get inlined into their
       // callers, exporting them again will lead to more code generated on each
       // export
       if (name.find("__forked_function") == 0) {
         continue;
       }
-      printMethod(*method.value(), /*is_class=*/false, extra_ivalue_names);
+      printMethod(*method, /*is_class=*/false, extra_ivalue_names);
     }
   }
 
index 5a04545..a1ed46c 100644 (file)
@@ -1,6 +1,6 @@
+#include <torch/csrc/jit/script/builtin_functions.h>
 #include <torch/csrc/api/include/torch/jit.h>
 #include <torch/csrc/jit/code_template.h>
-#include <torch/csrc/jit/script/builtin_functions.h>
 
 namespace torch {
 namespace jit {
@@ -67,8 +67,8 @@ struct BuiltinFunctionRegistry {
         module, source, script::nativeResolver, /*self=*/c10::nullopt);
     modules.push_back(module);
     for (auto& method : module->get_methods()) {
-      builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
-          .push_back(method->get());
+      builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
+          .push_back(method.get());
     }
   }
   void loadBuiltinFunctions() {
index 2258b99..be1919c 100644 (file)
@@ -10,10 +10,9 @@ Method* ClassType::getMethod(const std::string& name) const {
 }
 
 std::vector<Method*> ClassType::methods() const {
-  const auto& methods = module_->get_methods();
   std::vector<Method*> ret;
-  for (const auto& pr : methods.items()) {
-    ret.push_back(pr.value().get());
+  for (const auto& pr : module_->get_methods()) {
+    ret.push_back(pr.get());
   }
   return ret;
 }
index 1221cdd..54dbb98 100644 (file)
@@ -273,10 +273,10 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
       size_t n_binders) override {
     // Add all module parameters as inputs to the graph
     std::vector<Value*> params;
-    const auto& param_list = module_->get_parameters().items();
+    const auto& param_list = module_->get_parameters();
     for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
       auto& param = *it;
-      params.push_back(caller.get_or_add_parameter(param->slot()));
+      params.push_back(caller.get_or_add_parameter(param.slot()));
     }
     auto list = caller.graph()->createList(TensorType::get(), params);
     caller.graph()->insertNode(list);
@@ -606,15 +606,15 @@ static void gatherParametersAndBuffers(
     std::vector<Slot>& values,
     const Module& m) {
   for (auto& param : m.get_parameters()) {
-    values.push_back(param->slot());
+    values.push_back(param.slot());
   }
   for (auto& param : m.get_attributes()) {
-    if (param->type()->isSubtypeOf(TensorType::get())) {
-      values.push_back(param->slot());
+    if (param.type()->isSubtypeOf(TensorType::get())) {
+      values.push_back(param.slot());
     }
   }
   for (const auto& sub : m.get_modules()) {
-    gatherParametersAndBuffers(values, *sub->module);
+    gatherParametersAndBuffers(values, *sub.module);
   }
 }
 
@@ -767,38 +767,38 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "_get_modules",
           [](Module& self) -> py::tuple {
-            auto& modules = self.get_modules();
+            auto modules = self.get_modules();
             py::tuple result(modules.size());
             for (size_t i = 0; i < modules.size(); ++i) {
               auto& item = modules[i];
-              result[i] = std::make_pair(item.key(), item.value().module);
+              result[i] = std::make_pair(item.name, item.module);
             }
             return result;
           })
       .def(
           "_get_parameters",
           [](Module& self) -> py::tuple {
-            auto& parameters = self.get_parameters();
+            auto parameters = self.get_parameters();
             py::tuple result(parameters.size());
             for (size_t i = 0; i < parameters.size(); ++i) {
               auto& p = parameters[i];
               py::tuple r(2);
               result[i] = std::make_tuple(
-                  p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
+                  p.name(), autograd::as_variable_ref(p.slot()->toTensor()));
             }
             return result;
           })
       .def(
           "_get_attributes",
           [](Module& self) -> py::tuple {
-            auto& attributes = self.get_attributes();
+            auto attributes = self.get_attributes();
             py::tuple result(attributes.size());
             for (size_t i = 0; i < attributes.size(); ++i) {
               auto& buffer = attributes[i];
               py::tuple r(3);
-              IValue v = *buffer->slot();
+              IValue v = *buffer.slot();
               result[i] = std::make_tuple(
-                  buffer.key(), buffer->type(), toPyObject(std::move(v)));
+                  buffer.name(), buffer.type(), toPyObject(std::move(v)));
             }
             return result;
           })
@@ -830,11 +830,10 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "_method_names",
           [](Module& self) {
-            using Item =
-                torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
-            return fmap(self.get_methods(), [](const Item& item) {
-              return (*item)->name();
-            });
+            return fmap(
+                self.get_methods(), [](const std::unique_ptr<Method>& method) {
+                  return method->name();
+                });
           })
       .def(
           "_create_method_from_graph",
@@ -976,13 +975,15 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
-      .def("initial_ivalues",[](Method& m) {
-        std::vector<at::Tensor> tensors;
-        for (auto& t : m.initial_ivalues()) {
-          tensors.push_back(t->toTensor());
-        }
-        return tensors;
-      })
+      .def(
+          "initial_ivalues",
+          [](Method& m) {
+            std::vector<at::Tensor> tensors;
+            for (auto& t : m.initial_ivalues()) {
+              tensors.push_back(t->toTensor());
+            }
+            return tensors;
+          })
       .def(
           "graph_for",
           [](py::args args, py::kwargs kwargs) {
@@ -996,22 +997,22 @@ void initJitScriptBindings(PyObject* module) {
           &Method::debugDisableAutodiffSubgraphInlining)
       .def("schema", &Method::getSchema)
       .def("pretty_print_schema", &Method::pretty_print_schema)
-      .def("python_print", [](Method& m) {
-        std::ostringstream oss;
-        std::vector<at::Tensor> constants;
-        std::vector<ClassTypePtr> classes;
-        PythonPrint(oss, m, constants, classes, true);
-        return std::make_pair(oss.str(), std::move(constants));
-      })
-      .def_property_readonly(
-          "code",
-          [](Method& self) {
-            std::ostringstream ss;
-            std::vector<at::Tensor> tensors;
+      .def(
+          "python_print",
+          [](Method& m) {
+            std::ostringstream oss;
+            std::vector<at::Tensor> constants;
             std::vector<ClassTypePtr> classes;
-            PythonPrint(ss, self, tensors, classes, false);
-            return ss.str();
-          });
+            PythonPrint(oss, m, constants, classes, true);
+            return std::make_pair(oss.str(), std::move(constants));
+          })
+      .def_property_readonly("code", [](Method& self) {
+        std::ostringstream ss;
+        std::vector<at::Tensor> tensors;
+        std::vector<ClassTypePtr> classes;
+        PythonPrint(ss, self, tensors, classes, false);
+        return ss.str();
+      });
 
   m.def(
       "_jit_script_compile",
@@ -1127,9 +1128,10 @@ void initJitScriptBindings(PyObject* module) {
           py::arg("checks_file"),
           py::arg("graph"));
 
-  m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
-    return logging::setLogger(logger);
-  }, py::return_value_policy::reference);
+  m.def(
+      "_logging_set_logger",
+      [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
+      py::return_value_policy::reference);
   py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
       m, "LoggerBase");
   py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
@@ -1148,7 +1150,6 @@ void initJitScriptBindings(PyObject* module) {
       logging::LoggerBase,
       std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
       .def(py::init<>());
-
 }
 } // namespace script
 } // namespace jit
index 94b1068..933cf40 100644 (file)
@@ -122,13 +122,13 @@ void Module::to_impl(
     const c10::optional<at::ScalarType>& dtype,
     bool non_blocking) {
   // First call `to()` on every child module.
-  for (auto& child : modules) {
-    child->module->to_impl(device, dtype, non_blocking);
+  for (auto& child : get_modules()) {
+    child.module->to_impl(device, dtype, non_blocking);
   }
   // Then convert every of our parameters.
-  for (auto& parameter : parameters) {
+  for (auto& parameter : get_parameters()) {
     // Need to access the `at::Tensor` as a `Variable` here.
-    autograd::Variable variable = parameter.value().slot()->toTensor();
+    autograd::Variable variable = parameter.slot()->toTensor();
     at::Tensor data = variable.data();
     // Use the data's original device or dtype if not supplied here.
     auto new_data = data.to(
index ba99b2c..304857d 100644 (file)
@@ -388,12 +388,7 @@ struct NamedIValue {
 
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
-  Module()
-      : modules("Module"),
-        parameters("Parameter"),
-        attributes("Attributes"),
-        methods("Method"),
-        optimize(true) {}
+  Module() : optimize(true) {}
 
   // note this doesn't change the flags of existing methods just ones
   // added afterward.
@@ -410,12 +405,16 @@ struct Module {
   }
 
   void register_buffer(const std::string& name, autograd::Variable v) {
-    if (auto b = attributes.find(name)) {
+    if (auto b = find_attribute(name)) {
       AT_ASSERT(b->type()->isSubtypeOf(TensorType::get()));
       *b->slot() = v;
       return;
     }
-    attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+    insert(
+        name,
+        attributes_,
+        EntityType::ATTRIBUTE,
+        NamedIValue(name, TensorType::get(), std::move(v)));
   }
   void register_parameter(
       const std::string& name,
@@ -425,22 +424,30 @@ struct Module {
       register_buffer(name, std::move(v));
       return;
     }
-    if (auto p = parameters.find(name)) {
+    if (auto p = find_parameter(name)) {
       *p->slot() = v;
       return;
     }
-    parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+    insert(
+        name,
+        parameters_,
+        EntityType::PARAMETER,
+        NamedIValue(name, TensorType::get(), std::move(v)));
   }
   void register_attribute(
       const std::string& name,
       const TypePtr type,
       IValue ivalue) {
-    attributes.insert(name, NamedIValue(name, type, ivalue));
+    insert(
+        name,
+        attributes_,
+        EntityType::ATTRIBUTE,
+        NamedIValue(name, type, ivalue));
   }
   void register_module(
       const std::string& name,
       std::shared_ptr<Module> module) {
-    modules.insert(name, {name, std::move(module)});
+    insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
   }
 
   Method& create_method(
@@ -455,7 +462,7 @@ struct Module {
         std::move(graph),
         std::move(member_inputs),
         nullptr));
-    return *methods.insert(name, std::move(method));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
   }
 
   Method& create_method(
@@ -468,11 +475,11 @@ struct Module {
         std::make_shared<Graph>(),
         {},
         std::move(creator)));
-    return *methods.insert(name, std::move(method));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
   }
 
   Slot parameter_slot(const std::string& name) const {
-    return parameters[name].slot();
+    return parameters_[get_offset(name, EntityType::PARAMETER)].slot();
   }
 
   void set_parameter(const std::string& name, at::Tensor v) {
@@ -482,69 +489,71 @@ struct Module {
   autograd::Variable get_parameter(const std::string& name) const {
     return autograd::as_variable_ref(parameter_slot(name)->toTensor());
   }
-  autograd::Variable get_buffer(const std::string& name) const {
-    return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
-  }
+
   IValue get_attribute(const std::string& name) const {
-    return *attributes.find(name)->slot();
+    return *attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot();
+  }
+
+  autograd::Variable get_buffer(const std::string& name) const {
+    return autograd::as_variable_ref(get_attribute(name).toTensor());
   }
 
   // each module owns its method. The reference returned here
   // is guarenteed to stay valid until this module has been destroyed
   Method& get_method(const std::string& name) const {
-    return *methods[name];
+    return *methods_[get_offset(name, EntityType::METHOD)];
   }
 
   std::shared_ptr<Module> get_module(const std::string& name) const {
-    return modules[name].module;
+    return modules_[get_offset(name, EntityType::MODULE)].module;
   }
 
-  const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
-    return modules;
+  c10::ArrayRef<NamedModule> get_modules() const {
+    return modules_;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
-    return parameters;
+  c10::ArrayRef<NamedIValue> get_parameters() const {
+    return parameters_;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
-    return attributes;
+  c10::ArrayRef<NamedIValue> get_attributes() const {
+    return attributes_;
   }
-  const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
-      const {
-    return methods;
+  c10::ArrayRef<std::unique_ptr<Method>> get_methods() const {
+    return methods_;
   }
 
   NamedIValue* find_parameter(const std::string& name) {
-    return parameters.find(name);
+    auto offset = find_offset(name, EntityType::PARAMETER);
+    return offset ? &parameters_[*offset] : nullptr;
   }
   NamedIValue* find_attribute(const std::string& name) {
-    return attributes.find(name);
+    auto offset = find_offset(name, EntityType::ATTRIBUTE);
+    return offset ? &attributes_[*offset] : nullptr;
   }
   NamedIValue* find_buffer(const std::string& name) {
-    auto b = attributes.find(name);
-    if (b && b->type()->isSubtypeOf(TensorType::get())) {
-      return b;
+    auto iv = find_attribute(name);
+    if (iv && iv->type()->isSubtypeOf(TensorType::get())) {
+      return iv;
     }
     return nullptr;
   }
   NamedModule* find_module(const std::string& name) {
-    return modules.find(name);
+    auto offset = find_offset(name, EntityType::MODULE);
+    return offset ? &modules_[*offset] : nullptr;
   }
   Method* find_method(const std::string& name) {
-    if (auto* pm = methods.find(name)) {
-      return pm->get();
-    }
-    return nullptr;
+    auto offset = find_offset(name, EntityType::METHOD);
+    return offset ? methods_[*offset].get() : nullptr;
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
-      submod.value().module->apply(fn);
+      submod.module->apply(fn);
     }
     fn(*this);
   }
   /// Enables "training" mode.
   void train(bool on = true) {
     for (auto& submod : get_modules()) {
-      submod->module->train(on);
+      submod.module->train(on);
     }
     register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
@@ -622,51 +631,139 @@ struct Module {
       std::unordered_map<Slot, Slot>& parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
-    for (auto& kv : parameters) {
+    for (auto& param : get_parameters()) {
       curr->register_parameter(
-          kv.key(),
-          kv.value().slot()->toTensor(),
+          param.name(),
+          param.slot()->toTensor(),
           /*is_buffer=*/false);
-      parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
+      parameter_remap[param.slot()] = curr->parameter_slot(param.name());
     }
-    for (auto& kv : attributes) {
-      if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
+    for (auto& attr : get_attributes()) {
+      if (!attr.type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
-      curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
-      parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
+      curr->register_buffer(attr.name(), attr.slot()->toTensor());
+      parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
-    for (auto& kv : modules) {
-      names.push_back(kv.key());
+    for (auto& mod : get_modules()) {
+      names.push_back(mod.name);
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
-      kv.value().module->copy_into(module_lookup, parameter_remap, names);
+      mod.module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
-    for (auto& kv : methods) {
+    for (auto& method : get_methods()) {
       std::vector<Slot> initial_ivalues;
-      for (auto& p : kv.value()->initial_ivalues()) {
+      for (auto& p : method->initial_ivalues()) {
         initial_ivalues.push_back(parameter_remap.at(p));
       }
       curr->create_method(
-          kv.key(), kv.value()->graph()->copy(), initial_ivalues);
+          method->name(), method->graph()->copy(), initial_ivalues);
     }
   }
 
+  enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };
+
+  at::optional<EntityType> kind_of(const std::string& name) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end())
+      return at::nullopt;
+    return it->second.type;
+  }
+
  private:
   void to_impl(
       const c10::optional<at::Device>& device,
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
 
+  // modules have a single namespace, but spread over 4 different concepts:
+  // parameters, attributes, methods, and sub-modules
+  // we store individual lists of each concept, and a single map to
+  // unify the namespace and ensure fast lookup
+
   // invariant: to ensure initial_ivalues of Methods stay valid,
   // it is only legal to _add_ new modules and parameters.
   // removing them will allow initial_ivalues to point to invalid parameters
   // no such restriction exists for methods
-  torch::OrderedDict<std::string, NamedModule> modules;
-  torch::OrderedDict<std::string, NamedIValue> parameters;
-  torch::OrderedDict<std::string, NamedIValue> attributes;
-  torch::OrderedDict<std::string, std::unique_ptr<Method>> methods;
+  std::vector<NamedModule> modules_;
+  std::vector<NamedIValue> parameters_;
+  std::vector<NamedIValue> attributes_;
+  std::vector<std::unique_ptr<Method>> methods_;
+
+  static const char* toString(EntityType t) {
+    switch (t) {
+      case EntityType::MODULE:
+        return "module";
+      case EntityType::PARAMETER:
+        return "parameter";
+      case EntityType::ATTRIBUTE:
+        return "attribute";
+      case EntityType::METHOD:
+        return "method";
+    }
+    return nullptr;
+  }
+
+  struct Entry {
+    EntityType type;
+    size_t offset;
+  };
+
+  size_t get_offset(const std::string& name, EntityType expected_type) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end()) {
+      AT_ERROR(toString(expected_type), " '", name, "' is not defined.");
+    }
+    if (it->second.type != expected_type) {
+      AT_ERROR(
+          "The field '",
+          name,
+          "' is a ",
+          toString(it->second.type),
+          " but this call is"
+          " trying to use it as a ",
+          toString(expected_type));
+    }
+    return it->second.offset;
+  }
+  at::optional<size_t> find_offset(
+      const std::string& name,
+      EntityType expected_type) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end() || it->second.type != expected_type) {
+      return at::nullopt;
+    }
+    return it->second.offset;
+  }
+
+  template <typename T>
+  T& insert(
+      const std::string& name,
+      std::vector<T>& list,
+      EntityType type,
+      T value) {
+    auto it = dict_.find(name);
+    if (it != dict_.end()) {
+      if (type != it->second.type) {
+        AT_ERROR(
+            "attempting to add ",
+            toString(type),
+            " '",
+            name,
+            "' but it already exists as a ",
+            toString(it->second.type));
+      } else {
+        AT_ERROR(toString(type), " '", name, "' already defined.");
+      }
+    }
+    dict_[name] = Entry{type, list.size()};
+    list.emplace_back(std::move(value));
+    return list.back();
+  }
+
+  std::unordered_map<std::string, Entry> dict_;
+
   bool optimize;
 };
 
index 8b697e7..c897d73 100644 (file)
@@ -929,11 +929,10 @@ bool isHelperFunction(const std::string& method_name) {
 }
 
 void loadModule(const std::shared_ptr<script::Module>& module) {
-  for (const auto& method_ : module->get_methods()) {
-    if (isHelperFunction(method_.key()))
+  for (const auto& method : module->get_methods()) {
+    if (isHelperFunction(method->name()))
       continue;
 
-    const auto& method = method_.value();
     GradientPair pair;
     pair.forward = method->graph();