Enforce single parent for script submodules (#18379)
authorZachary DeVito <zdevito@fb.com>
Thu, 4 Apr 2019 03:21:27 +0000 (20:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 03:26:58 +0000 (20:26 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18379
ghimport-source-id: 9895ecc1ff7897e98853dc00675341f36726e7c7

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18379 Enforce single parent for script submodules**
* #18378 Unify namespace of script::Module
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

The assumption that a ScriptModule has a single parent is present in
our serialization format, and likely a few other places. It is not
enforced on creation of script module hierarchies though, meaning that
problems associated with (e.g. replicating a module twice in the output
format) will not be caught until much later in the development cycle.

This patch enforces the property when a submodule is registered.
It also removes NamedModule since it is no longer necessary in this regime.
This will also allow the easy discover of a modules fully-qualified name
without needing to traverse the Module hierarchy.

Differential Revision: D14603722

fbshipit-source-id: 63ab5d0cccf7d66c7833e0adf9023024ca9607cb

test/custom_operator/test_custom_ops.cpp
test/test_jit.py
torch/csrc/api/src/serialize/input-archive.cpp
torch/csrc/jit/export.cpp
torch/csrc/jit/import_source.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h

index a0387f4..1b64171 100644 (file)
@@ -18,7 +18,7 @@ void check_all_parameters(
     AT_ASSERT(predicate(parameter.slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(*child.module, predicate);
+    check_all_parameters(module, predicate);
   }
 }
 } // namespace helpers
index 18bf956..c2b6b6d 100644 (file)
@@ -340,6 +340,7 @@ class JitTestCase(TestCase):
 
             self.assertMultiLineEqual(main_module_code, main_module_2_code)
 
+
     def getExportImportCopy(self, m, also_test_file=True, map_location=None):
         buffer = io.BytesIO()
         torch.jit.save(m, buffer)
@@ -9929,7 +9930,7 @@ a")
             def __init__(self):
                 super(OtherStrong, self).__init__()
                 self.weak = weak
-                self.weak2 = weak
+                self.weak2 = Weak()
 
             @torch.jit.script_method
             def forward(self, x):
@@ -9946,7 +9947,7 @@ a")
 
         other_strong_mod = OtherStrong()
 
-        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
+        self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
 
         with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
             strong_mod = Strong()
index 5f5b308..e8243cd 100644 (file)
@@ -51,9 +51,8 @@ void InputArchive::read(
 }
 
 void InputArchive::read(const std::string& key, InputArchive& archive) {
-  if (auto* named_module = module_->find_module(key)) {
-    AT_ASSERT(named_module->module != nullptr);
-    archive.module_ = std::move(named_module->module);
+  if (auto named_module = module_->find_module(key)) {
+    archive.module_ = std::move(named_module);
   } else {
     AT_ERROR("No such serialized submodule: '", key, "'");
   }
index 84ac515..4f27838 100644 (file)
@@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
 
   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
-    convertModule(*elem.module, module_name.str(), elem.name, sub_def);
+    convertModule(*elem, module_name.str(), elem->name(), sub_def);
   }
 }
 
index 6364947..ed305c7 100644 (file)
@@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
       const SourceRange& loc,
       Method& m,
       const std::string& field) override {
-    if (NamedModule* v = module->find_module(field)) {
-      return std::make_shared<ModuleAccessorValue>(v->module);
+    if (std::shared_ptr<Module> v = module->find_module(field)) {
+      return std::make_shared<ModuleAccessorValue>(std::move(v));
     } else if (NamedIValue* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if (NamedIValue* v = module->find_buffer(field)) {
index 0844a96..20fce2b 100644 (file)
@@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
-        *elem.module, QualifiedName::create(prefix, elem.name), result);
+        *elem, QualifiedName::create(prefix, elem->name()), result);
   }
 }
 
index 54dbb98..ea69b11 100644 (file)
@@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<SimpleValue>(the_bool);
     }
 
-    if (NamedModule* v = module->find_module(field)) {
-      return std::make_shared<ModuleValue>(v->module);
+    if (std::shared_ptr<Module> v = module->find_module(field)) {
+      return std::make_shared<ModuleValue>(v);
     } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *v);
     } else if (NamedIValue* v = module->find_parameter(field)) {
@@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
     }
   }
   for (const auto& sub : m.get_modules()) {
-    gatherParametersAndBuffers(values, *sub.module);
+    gatherParametersAndBuffers(values, *sub);
   }
 }
 
@@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
             py::tuple result(modules.size());
             for (size_t i = 0; i < modules.size(); ++i) {
               auto& item = modules[i];
-              result[i] = std::make_pair(item.name, item.module);
+              result[i] = std::make_pair(item->name(), item);
             }
             return result;
           })
index 933cf40..cb9cdc7 100644 (file)
@@ -123,7 +123,7 @@ void Module::to_impl(
     bool non_blocking) {
   // First call `to()` on every child module.
   for (auto& child : get_modules()) {
-    child.module->to_impl(device, dtype, non_blocking);
+    child->to_impl(device, dtype, non_blocking);
   }
   // Then convert every of our parameters.
   for (auto& parameter : get_parameters()) {
index 304857d..4d88521 100644 (file)
@@ -359,11 +359,6 @@ struct Method {
 
 struct Module;
 
-struct NamedModule {
-  std::string name;
-  std::shared_ptr<Module> module;
-};
-
 struct NamedIValue {
   NamedIValue(std::string name, TypePtr type, IValue ivalue)
       : name_(name),
@@ -388,16 +383,20 @@ struct NamedIValue {
 
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
-  Module() : optimize(true) {}
+  Module() : name_("__main__"), optimize_(true) {}
+
+  const std::string& name() const {
+    return name_;
+  }
 
   // note this doesn't change the flags of existing methods just ones
   // added afterward.
   void set_optimized(bool o) {
-    optimize = o;
+    optimize_ = o;
   }
 
   bool is_optimized() const {
-    return optimize;
+    return optimize_;
   }
 
   IValue forward(std::vector<IValue> inputs) {
@@ -447,7 +446,15 @@ struct Module {
   void register_module(
       const std::string& name,
       std::shared_ptr<Module> module) {
-    insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
+    if (module->parent_) {
+      AT_ERROR(
+          "Attempting to assign submodule '",
+          name,
+          "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
+    }
+    module->parent_ = this;
+    module->name_ = name;
+    insert(name, modules_, EntityType::MODULE, std::move(module));
   }
 
   Method& create_method(
@@ -458,7 +465,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize,
+        optimize_,
         std::move(graph),
         std::move(member_inputs),
         nullptr));
@@ -471,7 +478,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize,
+        optimize_,
         std::make_shared<Graph>(),
         {},
         std::move(creator)));
@@ -505,10 +512,10 @@ struct Module {
   }
 
   std::shared_ptr<Module> get_module(const std::string& name) const {
-    return modules_[get_offset(name, EntityType::MODULE)].module;
+    return modules_[get_offset(name, EntityType::MODULE)];
   }
 
-  c10::ArrayRef<NamedModule> get_modules() const {
+  c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
     return modules_;
   }
   c10::ArrayRef<NamedIValue> get_parameters() const {
@@ -536,9 +543,9 @@ struct Module {
     }
     return nullptr;
   }
-  NamedModule* find_module(const std::string& name) {
+  std::shared_ptr<Module> find_module(const std::string& name) {
     auto offset = find_offset(name, EntityType::MODULE);
-    return offset ? &modules_[*offset] : nullptr;
+    return offset ? modules_[*offset] : nullptr;
   }
   Method* find_method(const std::string& name) {
     auto offset = find_offset(name, EntityType::METHOD);
@@ -546,14 +553,14 @@ struct Module {
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
-      submod.module->apply(fn);
+      submod->apply(fn);
     }
     fn(*this);
   }
   /// Enables "training" mode.
   void train(bool on = true) {
     for (auto& submod : get_modules()) {
-      submod.module->train(on);
+      submod->train(on);
     }
     register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
@@ -646,10 +653,10 @@ struct Module {
       parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
     for (auto& mod : get_modules()) {
-      names.push_back(mod.name);
+      names.push_back(mod->name());
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
-      mod.module->copy_into(module_lookup, parameter_remap, names);
+      mod->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
     for (auto& method : get_methods()) {
@@ -677,20 +684,6 @@ struct Module {
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
 
-  // modules have a single namespace, but spread over 4 different concepts:
-  // parameters, attributes, methods, and sub-modules
-  // we store individual lists of each concept, and a single map to
-  // unify the namespace and ensure fast lookup
-
-  // invariant: to ensure initial_ivalues of Methods stay valid,
-  // it is only legal to _add_ new modules and parameters.
-  // removing them will allow initial_ivalues to point to invalid parameters
-  // no such restriction exists for methods
-  std::vector<NamedModule> modules_;
-  std::vector<NamedIValue> parameters_;
-  std::vector<NamedIValue> attributes_;
-  std::vector<std::unique_ptr<Method>> methods_;
-
   static const char* toString(EntityType t) {
     switch (t) {
       case EntityType::MODULE:
@@ -762,9 +755,26 @@ struct Module {
     return list.back();
   }
 
+  // modules have a single namespace, but spread over 4 different concepts:
+  // parameters, attributes, methods, and sub-modules
+  // we store individual lists of each concept, and a single map to
+  // unify the namespace and ensure fast lookup
+
+  // invariant: to ensure initial_ivalues of Methods stay valid,
+  // it is only legal to _add_ new modules and parameters.
+  // removing them will allow initial_ivalues to point to invalid parameters
+  // no such restriction exists for methods
+  std::vector<std::shared_ptr<Module>> modules_;
+  std::vector<NamedIValue> parameters_;
+  std::vector<NamedIValue> attributes_;
+  std::vector<std::unique_ptr<Method>> methods_;
+
   std::unordered_map<std::string, Entry> dict_;
+  std::string name_;
 
-  bool optimize;
+  // back reference to parent of this Module if present
+  Module* parent_ = nullptr;
+  bool optimize_;
 };
 
 // returns nullptr and fills in failure_messages if the callee does not