From: Zachary DeVito
Date: Thu, 4 Apr 2019 17:22:27 +0000 (-0700)
Subject: Revert D14603722: Enforce single parent for script submodules
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~418
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f97eb8d9e455911bbd2010527bed29d2ab47e0a8;p=platform%2Fupstream%2Fpytorch.git

Revert D14603722: Enforce single parent for script submodules

Differential Revision: D14603722

Original commit changeset: 63ab5d0cccf7

fbshipit-source-id: 2c4174def102eda4589e08c4dbd67ce8af975199
---

diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp
index 1b64171..a0387f4 100644
--- a/test/custom_operator/test_custom_ops.cpp
+++ b/test/custom_operator/test_custom_ops.cpp
@@ -18,7 +18,7 @@ void check_all_parameters(
     AT_ASSERT(predicate(parameter.slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(module, predicate);
+    check_all_parameters(*child.module, predicate);
   }
 }
 } // namespace helpers
diff --git a/test/test_jit.py b/test/test_jit.py
index 9348447..931f6e8 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9956,7 +9956,7 @@ a")
             def __init__(self):
                 super(OtherStrong, self).__init__()
                 self.weak = weak
-                self.weak2 = Weak()
+                self.weak2 = weak

             @torch.jit.script_method
             def forward(self, x):
@@ -9973,7 +9973,7 @@ a")

         other_strong_mod = OtherStrong()

-        self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
+        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)

         with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
             strong_mod = Strong()
diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp
index e8243cd..5f5b308 100644
--- a/torch/csrc/api/src/serialize/input-archive.cpp
+++ b/torch/csrc/api/src/serialize/input-archive.cpp
@@ -51,8 +51,9 @@ void InputArchive::read(
 }

 void InputArchive::read(const std::string& key, InputArchive& archive) {
-  if (auto named_module = module_->find_module(key)) {
-    archive.module_ = std::move(named_module);
+  if (auto* named_module = module_->find_module(key)) {
+    AT_ASSERT(named_module->module != nullptr);
+    archive.module_ = std::move(named_module->module);
   } else {
     AT_ERROR("No such serialized submodule: '", key, "'");
   }
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 4f27838..84ac515 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(

   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
-    convertModule(*elem, module_name.str(), elem->name(), sub_def);
+    convertModule(*elem.module, module_name.str(), elem.name, sub_def);
   }
 }

diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp
index ed305c7..6364947 100644
--- a/torch/csrc/jit/import_source.cpp
+++ b/torch/csrc/jit/import_source.cpp
@@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
       const SourceRange& loc,
       Method& m,
       const std::string& field) override {
-    if (std::shared_ptr<Module> v = module->find_module(field)) {
-      return std::make_shared<ModuleAccessorValue>(std::move(v));
+    if (NamedModule* v = module->find_module(field)) {
+      return std::make_shared<ModuleAccessorValue>(v->module);
     } else if (NamedIValue* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if (NamedIValue* v = module->find_buffer(field)) {
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 20fce2b..0844a96 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
-        *elem, QualifiedName::create(prefix, elem->name()), result);
+        *elem.module, QualifiedName::create(prefix, elem.name), result);
   }
 }

diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index ea69b11..54dbb98 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<SimpleValue>(the_bool);
     }

-    if (std::shared_ptr<Module> v = module->find_module(field)) {
-      return std::make_shared<ModuleValue>(v);
+    if (NamedModule* v = module->find_module(field)) {
+      return std::make_shared<ModuleValue>(v->module);
     } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *v);
     } else if (NamedIValue* v = module->find_parameter(field)) {
@@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
     }
   }
   for (const auto& sub : m.get_modules()) {
-    gatherParametersAndBuffers(values, *sub);
+    gatherParametersAndBuffers(values, *sub.module);
   }
 }

@@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
           py::tuple result(modules.size());
           for (size_t i = 0; i < modules.size(); ++i) {
             auto& item = modules[i];
-            result[i] = std::make_pair(item->name(), item);
+            result[i] = std::make_pair(item.name, item.module);
           }
           return result;
         })
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index cb9cdc7..933cf40 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -123,7 +123,7 @@ void Module::to_impl(
     bool non_blocking) {
   // First call `to()` on every child module.
   for (auto& child : get_modules()) {
-    child->to_impl(device, dtype, non_blocking);
+    child.module->to_impl(device, dtype, non_blocking);
   }
   // Then convert every of our parameters.
   for (auto& parameter : get_parameters()) {
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index 4d88521..304857d 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -359,6 +359,11 @@ struct Method {

 struct Module;

+struct NamedModule {
+  std::string name;
+  std::shared_ptr<Module> module;
+};
+
 struct NamedIValue {
   NamedIValue(std::string name, TypePtr type, IValue ivalue)
       : name_(name),
@@ -383,20 +388,16 @@ struct NamedIValue {
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);

-  Module() : name_("__main__"), optimize_(true) {}
-
-  const std::string& name() const {
-    return name_;
-  }
+  Module() : optimize(true) {}

   // note this doesn't change the flags of existing methods just ones
   // added afterward.
   void set_optimized(bool o) {
-    optimize_ = o;
+    optimize = o;
   }

   bool is_optimized() const {
-    return optimize_;
+    return optimize;
   }

   IValue forward(std::vector<IValue> inputs) {
@@ -446,15 +447,7 @@ struct Module {
   void register_module(
       const std::string& name,
       std::shared_ptr<Module> module) {
-    if (module->parent_) {
-      AT_ERROR(
-          "Attempting to assign submodule '",
-          name,
-          "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
-    }
-    module->parent_ = this;
-    module->name_ = name;
-    insert(name, modules_, EntityType::MODULE, std::move(module));
+    insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
   }

   Method& create_method(
@@ -465,7 +458,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize_,
+        optimize,
         std::move(graph),
         std::move(member_inputs),
         nullptr));
@@ -478,7 +471,7 @@ struct Module {
     std::unique_ptr<Method> method(new Method(
         this,
         name,
-        optimize_,
+        optimize,
         std::make_shared<Graph>(),
         {},
         std::move(creator)));
@@ -512,10 +505,10 @@ struct Module {
   }

   std::shared_ptr<Module> get_module(const std::string& name) const {
-    return modules_[get_offset(name, EntityType::MODULE)];
+    return modules_[get_offset(name, EntityType::MODULE)].module;
   }

-  c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
+  c10::ArrayRef<NamedModule> get_modules() const {
     return modules_;
   }
   c10::ArrayRef<NamedIValue> get_parameters() const {
@@ -543,9 +536,9 @@ struct Module {
     }
     return nullptr;
   }
-  std::shared_ptr<Module> find_module(const std::string& name) {
+  NamedModule* find_module(const std::string& name) {
     auto offset = find_offset(name, EntityType::MODULE);
-    return offset ? modules_[*offset] : nullptr;
+    return offset ? &modules_[*offset] : nullptr;
   }
   Method* find_method(const std::string& name) {
     auto offset = find_offset(name, EntityType::METHOD);
@@ -553,14 +546,14 @@ struct Module {
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
-      submod->apply(fn);
+      submod.module->apply(fn);
     }
     fn(*this);
   }
   /// Enables "training" mode.
   void train(bool on = true) {
     for (auto& submod : get_modules()) {
-      submod->train(on);
+      submod.module->train(on);
     }
     register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
@@ -653,10 +646,10 @@ struct Module {
       parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
     for (auto& mod : get_modules()) {
-      names.push_back(mod->name());
+      names.push_back(mod.name);
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
-      mod->copy_into(module_lookup, parameter_remap, names);
+      mod.module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
     for (auto& method : get_methods()) {
@@ -684,6 +677,20 @@ struct Module {
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);

+  // modules have a single namespace, but spread over 4 different concepts:
+  // parameters, attributes, methods, and sub-modules
+  // we store individual lists of each concept, and a single map to
+  // unify the namespace and ensure fast lookup
+
+  // invariant: to ensure initial_ivalues of Methods stay valid,
+  // it is only legal to _add_ new modules and parameters.
+  // removing them will allow initial_ivalues to point to invalid parameters
+  // no such restriction exists for methods
+  std::vector<NamedModule> modules_;
+  std::vector<NamedIValue> parameters_;
+  std::vector<NamedIValue> attributes_;
+  std::vector<std::unique_ptr<Method>> methods_;
+
   static const char* toString(EntityType t) {
     switch (t) {
       case EntityType::MODULE:
@@ -755,26 +762,9 @@ struct Module {
     return list.back();
   }

-  // modules have a single namespace, but spread over 4 different concepts:
-  // parameters, attributes, methods, and sub-modules
-  // we store individual lists of each concept, and a single map to
-  // unify the namespace and ensure fast lookup
-
-  // invariant: to ensure initial_ivalues of Methods stay valid,
-  // it is only legal to _add_ new modules and parameters.
-  // removing them will allow initial_ivalues to point to invalid parameters
-  // no such restriction exists for methods
-  std::vector<std::shared_ptr<Module>> modules_;
-  std::vector<NamedIValue> parameters_;
-  std::vector<NamedIValue> attributes_;
-  std::vector<std::unique_ptr<Method>> methods_;
-
   std::unordered_map<std::string, Entry> dict_;
-  std::string name_;
-  // back reference to parent of this Module if present
-  Module* parent_ = nullptr;
-  bool optimize_;
+  bool optimize;
 };

 // returns nullptr and fills in failure_messages if the callee does not
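
The call sites above show the shape of the API this revert restores: get_modules() once again
yields named entries rather than bare std::shared_ptr<Module> handles, so traversals read
child.module (and child.name) instead of dereferencing the child directly. A minimal sketch of
such a traversal against the restored torch/csrc/jit/script/module.h follows; count_submodules
is a hypothetical helper written only for illustration, not part of the patch:

    #include <cstddef>

    #include <torch/csrc/jit/script/module.h>

    using torch::jit::script::Module;

    // Recursively count submodules. After this revert, get_modules() returns
    // c10::ArrayRef<NamedModule>, so each entry carries both the submodule's
    // name and a shared_ptr to the submodule itself.
    std::size_t count_submodules(const Module& module) {
      std::size_t n = 0;
      for (const auto& child : module.get_modules()) {
        n += 1 + count_submodules(*child.module);
      }
      return n;
    }

Because NamedModule holds a plain shared_ptr with no back reference to a parent, the same
submodule instance may appear under several parents again, which is exactly the sharing
behavior the updated test_jit.py assertions (assertIs on weak and weak2) exercise.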