Revert D14603722: Enforce single parent for script submodules
authorZachary DeVito <zdevito@fb.com>
Thu, 4 Apr 2019 17:22:27 +0000 (10:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 17:32:36 +0000 (10:32 -0700)
Differential Revision:
D14603722

Original commit changeset: 63ab5d0cccf7

fbshipit-source-id: 2c4174def102eda4589e08c4dbd67ce8af975199

test/custom_operator/test_custom_ops.cpp
test/test_jit.py
torch/csrc/api/src/serialize/input-archive.cpp
torch/csrc/jit/export.cpp
torch/csrc/jit/import_source.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h

index 1b64171..a0387f4 100644 (file)
@@ -18,7 +18,7 @@ void check_all_parameters(
     AT_ASSERT(predicate(parameter.slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(module, predicate);
+    check_all_parameters(*child.module, predicate);
   }
 }
 } // namespace helpers
index 9348447..931f6e8 100644 (file)
@@ -9956,7 +9956,7 @@ a")
             def __init__(self):
                 super(OtherStrong, self).__init__()
                 self.weak = weak
-                self.weak2 = Weak()
+                self.weak2 = weak
 
             @torch.jit.script_method
             def forward(self, x):
@@ -9973,7 +9973,7 @@ a")
 
         other_strong_mod = OtherStrong()
 
-        self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
+        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
 
         with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
             strong_mod = Strong()
index e8243cd..5f5b308 100644 (file)
@@ -51,8 +51,9 @@ void InputArchive::read(
 }
 
 void InputArchive::read(const std::string& key, InputArchive& archive) {
-  if (auto named_module = module_->find_module(key)) {
-    archive.module_ = std::move(named_module);
+  if (auto* named_module = module_->find_module(key)) {
+    AT_ASSERT(named_module->module != nullptr);
+    archive.module_ = std::move(named_module->module);
   } else {
     AT_ERROR("No such serialized submodule: '", key, "'");
   }
index 4f27838..84ac515 100644 (file)
@@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
 
   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
-    convertModule(*elem, module_name.str(), elem->name(), sub_def);
+    convertModule(*elem.module, module_name.str(), elem.name, sub_def);
   }
 }
 
index ed305c7..6364947 100644 (file)
@@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
       const SourceRange& loc,
       Method& m,
       const std::string& field) override {
-    if (std::shared_ptr<Module> v = module->find_module(field)) {
-      return std::make_shared<ModuleAccessorValue>(std::move(v));
+    if (NamedModule* v = module->find_module(field)) {
+      return std::make_shared<ModuleAccessorValue>(v->module);
     } else if (NamedIValue* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if (NamedIValue* v = module->find_buffer(field)) {
index 20fce2b..0844a96 100644 (file)
@@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
-        *elem, QualifiedName::create(prefix, elem->name()), result);
+        *elem.module, QualifiedName::create(prefix, elem.name), result);
   }
 }
 
index ea69b11..54dbb98 100644 (file)
@@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<SimpleValue>(the_bool);
     }
 
-    if (std::shared_ptr<Module> v = module->find_module(field)) {
-      return std::make_shared<ModuleValue>(v);
+    if (NamedModule* v = module->find_module(field)) {
+      return std::make_shared<ModuleValue>(v->module);
     } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *v);
     } else if (NamedIValue* v = module->find_parameter(field)) {
@@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
     }
   }
   for (const auto& sub : m.get_modules()) {
-    gatherParametersAndBuffers(values, *sub);
+    gatherParametersAndBuffers(values, *sub.module);
   }
 }
 
@@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
             py::tuple result(modules.size());
             for (size_t i = 0; i < modules.size(); ++i) {
               auto& item = modules[i];
-              result[i] = std::make_pair(item->name(), item);
+              result[i] = std::make_pair(item.name, item.module);
             }
             return result;
           })
index cb9cdc7..933cf40 100644 (file)
@@ -123,7 +123,7 @@ void Module::to_impl(
     bool non_blocking) {
   // First call `to()` on every child module.
   for (auto& child : get_modules()) {
-    child->to_impl(device, dtype, non_blocking);
+    child.module->to_impl(device, dtype, non_blocking);
   }
   // Then convert every of our parameters.
   for (auto& parameter : get_parameters()) {
index 4d88521..304857d 100644 (file)
@@ -359,6 +359,11 @@ struct Method {
 
 struct Module;
 
+struct NamedModule {
+  std::string name;
+  std::shared_ptr<Module> module;
+};
+
 struct NamedIValue {
   NamedIValue(std::string name, TypePtr type, IValue ivalue)
       : name_(name),
@@ -383,20 +388,16 @@ struct NamedIValue {
 
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
-  Module() : name_("__main__"), optimize_(true) {}
-
-  const std::string& name() const {
-    return name_;
-  }
+  Module() : optimize(true) {}
 
   // note this doesn't change the flags of existing methods just ones
   // added afterward.
   void set_optimized(bool o) {
-    optimize_ = o;
+    optimize = o;
   }
 
   bool is_optimized() const {
-    return optimize_;
+    return optimize;
   }
 
   IValue forward(std::vector<IValue> inputs) {
@@ -446,15 +447,7 @@ struct Module {
   void register_module(
       const std::string& name,
       std::shared_ptr<Module> module) {
-    if (module->parent_) {
-      AT_ERROR(
-          "Attempting to assign submodule '",
-          name,
-          "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
-    }
-    module->parent_ = this;
-    module->name_ = name;
-    insert(name, modules_, EntityType::MODULE, std::move(module));
+    insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
   }
 
   Method& create_method(
@@ -465,7 +458,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize_,
+        optimize,
         std::move(graph),
         std::move(member_inputs),
         nullptr));
@@ -478,7 +471,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize_,
+        optimize,
         std::make_shared<Graph>(),
         {},
         std::move(creator)));
@@ -512,10 +505,10 @@ struct Module {
   }
 
   std::shared_ptr<Module> get_module(const std::string& name) const {
-    return modules_[get_offset(name, EntityType::MODULE)];
+    return modules_[get_offset(name, EntityType::MODULE)].module;
   }
 
-  c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
+  c10::ArrayRef<NamedModule> get_modules() const {
     return modules_;
   }
   c10::ArrayRef<NamedIValue> get_parameters() const {
@@ -543,9 +536,9 @@ struct Module {
     }
     return nullptr;
   }
-  std::shared_ptr<Module> find_module(const std::string& name) {
+  NamedModule* find_module(const std::string& name) {
     auto offset = find_offset(name, EntityType::MODULE);
-    return offset ? modules_[*offset] : nullptr;
+    return offset ? &modules_[*offset] : nullptr;
   }
   Method* find_method(const std::string& name) {
     auto offset = find_offset(name, EntityType::METHOD);
@@ -553,14 +546,14 @@ struct Module {
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
-      submod->apply(fn);
+      submod.module->apply(fn);
     }
     fn(*this);
   }
   /// Enables "training" mode.
   void train(bool on = true) {
     for (auto& submod : get_modules()) {
-      submod->train(on);
+      submod.module->train(on);
     }
     register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
@@ -653,10 +646,10 @@ struct Module {
       parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
     for (auto& mod : get_modules()) {
-      names.push_back(mod->name());
+      names.push_back(mod.name);
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
-      mod->copy_into(module_lookup, parameter_remap, names);
+      mod.module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
     for (auto& method : get_methods()) {
@@ -684,6 +677,20 @@ struct Module {
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
 
+  // modules have a single namespace, but spread over 4 different concepts:
+  // parameters, attributes, methods, and sub-modules
+  // we store individual lists of each concept, and a single map to
+  // unify the namespace and ensure fast lookup
+
+  // invariant: to ensure initial_ivalues of Methods stay valid,
+  // it is only legal to _add_ new modules and parameters.
+  // removing them will allow initial_ivalues to point to invalid parameters
+  // no such restriction exists for methods
+  std::vector<NamedModule> modules_;
+  std::vector<NamedIValue> parameters_;
+  std::vector<NamedIValue> attributes_;
+  std::vector<std::unique_ptr<Method>> methods_;
+
   static const char* toString(EntityType t) {
     switch (t) {
       case EntityType::MODULE:
@@ -755,26 +762,9 @@ struct Module {
     return list.back();
   }
 
-  // modules have a single namespace, but spread over 4 different concepts:
-  // parameters, attributes, methods, and sub-modules
-  // we store individual lists of each concept, and a single map to
-  // unify the namespace and ensure fast lookup
-
-  // invariant: to ensure initial_ivalues of Methods stay valid,
-  // it is only legal to _add_ new modules and parameters.
-  // removing them will allow initial_ivalues to point to invalid parameters
-  // no such restriction exists for methods
-  std::vector<std::shared_ptr<Module>> modules_;
-  std::vector<NamedIValue> parameters_;
-  std::vector<NamedIValue> attributes_;
-  std::vector<std::unique_ptr<Method>> methods_;
-
   std::unordered_map<std::string, Entry> dict_;
-  std::string name_;
 
-  // back reference to parent of this Module if present
-  Module* parent_ = nullptr;
-  bool optimize_;
+  bool optimize;
 };
 
 // returns nullptr and fills in failure_messages if the callee does not