From: Will Feng Date: Fri, 29 Mar 2019 19:59:29 +0000 (-0700) Subject: Add named submodule support to nn::Sequential (#17552) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~553 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6ebfbdf4c66586d0c3237bc29b1267a9fe0a13cd;p=platform%2Fupstream%2Fpytorch.git Add named submodule support to nn::Sequential (#17552) Summary: Previously, we were not able to assign names to `nn::Sequential`'s submodules. This PR adds this feature to match the Python API. Example use: ```cpp Sequential sequential(named_submodule({ {"linear", Linear(10, 3)}, {"conv2d", Conv2d(1, 2, 3)}, {"dropout", Dropout(0.5)}, {"batchnorm", BatchNorm(5)}, {"embedding", Embedding(4, 10)}, {"lstm", LSTM(4, 5)} })); ``` It also enables loading parameters of Python `nn.Sequential` module with custom submodules names into C++ frontend, unblocking https://github.com/pytorch/vision/pull/728#issuecomment-466661344. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17552 Differential Revision: D14246834 Pulled By: yf225 fbshipit-source-id: 3030b5c5d68f6dd5d3e37ac4b4f98dc6d6d9ba72 --- diff --git a/docs/cpp/source/check-doxygen.sh b/docs/cpp/source/check-doxygen.sh index 18863e2..9959a3f 100755 --- a/docs/cpp/source/check-doxygen.sh +++ b/docs/cpp/source/check-doxygen.sh @@ -39,6 +39,7 @@ cat original-doxygen-log.txt # Filter out some warnings. ignore_warning "warning: no uniquely matching class member found for" +ignore_warning "warning: explicit link request to 'Item' could not be resolved" # Count the number of remaining warnings. 
warnings="$(grep 'warning:' doxygen-log.txt | wc -l)" diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index 1f7cf55..747bbf1 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -32,20 +33,47 @@ TEST_F(SequentialTest, ConstructsFromSharedPointer) { Sequential sequential( std::make_shared(1), std::make_shared(2), std::make_shared(3)); ASSERT_EQ(sequential->size(), 3); + + Sequential sequential_named(modules_ordered_dict({ + {"m1", std::make_shared(1)}, + {std::string("m2"), std::make_shared(2)}, + {"m3", std::make_shared(3)} + })); + ASSERT_EQ(sequential_named->size(), 3); } TEST_F(SequentialTest, ConstructsFromConcreteType) { + static int copy_count; + struct M : torch::nn::Module { explicit M(int value_) : value(value_) {} + M(const M& other) : torch::nn::Module(other) { + copy_count++; + } int value; int forward() { return value; } }; + copy_count = 0; Sequential sequential(M(1), M(2), M(3)); ASSERT_EQ(sequential->size(), 3); + // NOTE: The current implementation expects each module to be copied exactly once, + // which happens when the module is passed into `std::make_shared()`. + // TODO: Find a way to avoid copying, and then delete the copy constructor of `M`. 
+ ASSERT_EQ(copy_count, 3); + + copy_count = 0; + Sequential sequential_named(modules_ordered_dict({ + {"m1", M(1)}, + {std::string("m2"), M(2)}, + {"m3", M(3)} + })); + ASSERT_EQ(sequential_named->size(), 3); + ASSERT_EQ(copy_count, 3); } + TEST_F(SequentialTest, ConstructsFromModuleHolder) { struct MImpl : torch::nn::Module { explicit MImpl(int value_) : value(value_) {} @@ -62,6 +90,13 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) { Sequential sequential(M(1), M(2), M(3)); ASSERT_EQ(sequential->size(), 3); + + Sequential sequential_named(modules_ordered_dict({ + {"m1", M(1)}, + {std::string("m2"), M(2)}, + {"m3", M(3)} + })); + ASSERT_EQ(sequential_named->size(), 3); } TEST_F(SequentialTest, PushBackAddsAnElement) { @@ -72,6 +107,8 @@ } int value; }; + + // Test unnamed submodules Sequential sequential; ASSERT_EQ(sequential->size(), 0); ASSERT_TRUE(sequential->is_empty()); @@ -81,6 +118,32 @@ ASSERT_EQ(sequential->size(), 2); sequential->push_back(M(2)); ASSERT_EQ(sequential->size(), 3); + + // Mix named and unnamed submodules + Sequential sequential_named; + ASSERT_EQ(sequential_named->size(), 0); + ASSERT_TRUE(sequential_named->is_empty()); + + sequential_named->push_back(Linear(3, 4)); + ASSERT_EQ(sequential_named->size(), 1); + ASSERT_EQ(sequential_named->named_children()[0].key(), "0"); + sequential_named->push_back(std::string("linear2"), Linear(3, 4)); + ASSERT_EQ(sequential_named->size(), 2); + ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2"); + + sequential_named->push_back("shared_m1", std::make_shared(1)); + ASSERT_EQ(sequential_named->size(), 3); + ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1"); + sequential_named->push_back(std::make_shared(1)); + ASSERT_EQ(sequential_named->size(), 4); + ASSERT_EQ(sequential_named->named_children()[3].key(), "3"); + + sequential_named->push_back(M(1)); + 
ASSERT_EQ(sequential_named->size(), 5); + ASSERT_EQ(sequential_named->named_children()[4].key(), "4"); + sequential_named->push_back(std::string("m2"), M(1)); + ASSERT_EQ(sequential_named->size(), 6); + ASSERT_EQ(sequential_named->named_children()[5].key(), "m2"); } TEST_F(SequentialTest, AccessWithAt) { @@ -342,4 +405,23 @@ TEST_F(SequentialTest, PrettyPrintSequential) { " (4): torch::nn::Embedding(count=4, dimension=10)\n" " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n" ")"); + + Sequential sequential_named(modules_ordered_dict({ + {"linear", Linear(10, 3)}, + {"conv2d", Conv2d(1, 2, 3)}, + {"dropout", Dropout(0.5)}, + {"batchnorm", BatchNorm(5)}, + {"embedding", Embedding(4, 10)}, + {"lstm", LSTM(4, 5)} + })); + ASSERT_EQ( + c10::str(sequential_named), + "torch::nn::Sequential(\n" + " (linear): torch::nn::Linear(in=10, out=3, with_bias=true)\n" + " (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n" + " (dropout): torch::nn::Dropout(rate=0.5)\n" + " (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n" + " (embedding): torch::nn::Embedding(count=4, dimension=10)\n" + " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n" + ")"); } diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4e6fe9a..deff903 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -243,6 +243,7 @@ if (NOT NO_API) ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp ${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp ${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/any.h index 0715fae..08deb0b 100644 --- 
a/torch/csrc/api/include/torch/nn/modules/any.h +++ b/torch/csrc/api/include/torch/nn/modules/any.h @@ -419,7 +419,20 @@ template AnyModule::AnyModule(std::shared_ptr module) : content_(make_holder( std::move(module), - &std::remove_reference::type::forward)) {} + &std::remove_reference::type::forward)) { + // `AnyModule` can only store an `nn::Module` subclass object that provides + // a `forward()` method that has a non-templatized return type. + // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s + // `forward()` method has a templatized return type.) + static_assert( + torch::detail::is_module::value, + "Can only store object derived from nn::Module into AnyModule"); + static_assert( + torch::detail::has_forward::value, + "Can only store module with a forward() method that has a non-templatized" + "return type into AnyModule (e.g. we cannot store nn::Sequential" + "into AnyModule, because its forward() method's return type is templatized)"); +} template AnyModule::AnyModule(ModuleType&& module) diff --git a/torch/csrc/api/include/torch/nn/modules/named_any.h b/torch/csrc/api/include/torch/nn/modules/named_any.h new file mode 100644 index 0000000..d4fa54b --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/named_any.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Stores a type erased `Module` with name. +/// +/// The `NamedAnyModule` class and the `modules_ordered_dict(...)` function enables +/// the following API for constructing `nn::Sequential` with named submodules: +/// \rst +/// .. 
code-block:: cpp +/// +/// struct M : torch::nn::Module { +/// explicit M(int value_) : value(value_) {} +/// int value; +/// int forward() { +/// return value; +/// } +/// }; +/// +/// Sequential sequential(modules_ordered_dict({ +/// {"m1", std::make_shared(1)}, // shared pointer to `Module` is supported +/// {std::string("m2"), M(2)}, // `Module` is supported +/// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported +/// })); +/// \endrst +/// +/// Specifically, we design the signature of `modules_ordered_dict(...)` to be +/// `modules_ordered_dict(std::initializer_list named_modules)`, as +/// a result of evaluating the following possible approaches: +/// +/// Approach 1: +/// `modules_ordered_dict(std::initializer_list< +/// torch::OrderedDict::Item> named_modules)` +/// +/// Why it doesn't work: +/// When we pass in a braced-init list such as +/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, at the template argument +/// deduction step the compiler is not able to deduce the type of `ModuleType` to +/// the type of `M(1)` or `M(2)`, since the compiler doesn't actually look into the +/// braced-init list `{"m1", M(1)}` and figure out what the types of its elements are. +/// +/// Approach 2: +/// `modules_ordered_dict(std::initializer_list< +/// std::pair named_modules)` +/// +/// Why it doesn't work: +/// When we pass in a braced-init list such as +/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is not able to +/// match `std::initializer_list>` to the nested +/// braced-init list `{{"m1", M(1)}, {"m2", M(2)}}`, and results in a "could not +/// convert" error. 
+/// +/// Approach 3: +/// `modules_ordered_dict(std::initializer_list named_modules)` +/// +/// Why it works: +/// When we pass in a braced-init list such as +/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is passing the +/// braced-init lists {"m1", M(1)} and {"m2", M(2)} to the `NamedAnyModule` +/// constructors, and the constructors are able to figure out the types of the +/// braced-init lists' elements and match to the correct module type. + +class NamedAnyModule { + public: + /// Creates a `NamedAnyModule` from a (boxed) `Module`. + template + NamedAnyModule(std::string name, std::shared_ptr module_ptr) + : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {} + + /// Creates a `NamedAnyModule` from a `Module`, moving or copying it + /// into a `shared_ptr` internally. + // NOTE: We need to use `std::remove_reference::type` to get rid of + // any reference components for make_unique. + template > + NamedAnyModule(std::string name, M&& module) + : NamedAnyModule( + std::move(name), + std::make_shared::type>( + std::forward(module))) {} + + /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from + /// a `ModuleHolder`. + template + NamedAnyModule(std::string name, const ModuleHolder& module_holder) + : NamedAnyModule(std::move(name), module_holder.ptr()) {} + + /// Returns a reference to the name. + const std::string& name() const noexcept { + return name_; + } + + /// Returns a reference to the module. + AnyModule& module() noexcept { + return module_; + } + + private: + /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. 
+ NamedAnyModule(std::string name, AnyModule any_module) + : name_(std::move(name)), module_(std::move(any_module)) {} + + std::string name_; + AnyModule module_; +}; + +TORCH_API torch::OrderedDict modules_ordered_dict( + std::initializer_list named_modules); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h index 976e0e3..7a4818c 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/sequential.h @@ -102,6 +102,16 @@ class SequentialImpl : public Cloneable { push_back(std::forward(modules)...); } + /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s. + /// Combining with `modules_ordered_dict()`, it enables the following use case: + /// `Sequential sequential(modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}}))` + explicit SequentialImpl(torch::OrderedDict&& ordered_dict) { + modules_.reserve(ordered_dict.size()); + for (auto& item : ordered_dict) { + push_back(std::move(item.key()), std::move(item.value())); + } + } + /// Special cloning function for `Sequential` because it does not use /// `reset()`. std::shared_ptr clone( @@ -175,18 +185,13 @@ class SequentialImpl : public Cloneable { /// Adds a new (boxed) `Module` to the `Sequential` container. template void push_back(std::shared_ptr module_ptr) { - // Nesting Sequential doesn't work because `forward()`'s return type is - // templatized, so it'll give a nasty compiler error. 
- static_assert( - !std::is_same::value, - "Sequential is not nestable"); - static_assert( - torch::detail::is_module::value, - "Can only add objects derived from nn::Module to Sequential"); - static_assert( - torch::detail::has_forward::value, - "Can only add modules with a forward() method to Sequential"); - push_back(AnyModule(std::move(module_ptr))); + push_back(std::to_string(modules_.size()), std::move(module_ptr)); + } + + /// Adds a new named (boxed) `Module` to the `Sequential` container. + template + void push_back(std::string name, std::shared_ptr module_ptr) { + push_back(std::move(name), AnyModule(std::move(module_ptr))); } /// Adds a new `Module` to the `Sequential` container, moving or copying it @@ -196,17 +201,30 @@ class SequentialImpl : public Cloneable { /// `Sequential(std::make_shared(3, 4))`. template > void push_back(M&& module) { - // Need to get rid of any reference components for make_unique. + push_back(std::to_string(modules_.size()), std::forward(module)); + } + + /// Adds a new named `Module` to the `Sequential` container, moving or copying it + /// into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(std::string name, M&& module) { using Type = typename std::remove_reference::type; - // Here we move (or copy) the module into a new shared_ptr. - push_back(std::make_shared(std::forward(module))); + push_back(std::move(name), std::make_shared(std::forward(module))); } /// Unwraps the contained module of a `ModuleHolder` and adds it to the /// `Sequential`. template void push_back(const ModuleHolder& module_holder) { - push_back(module_holder.ptr()); + push_back(std::to_string(modules_.size()), module_holder); + } + + /// Unwraps the contained named module of a `ModuleHolder` and adds it to the + /// `Sequential`. 
+ template + void push_back(std::string name, const ModuleHolder& module_holder) { + push_back(std::move(name), module_holder.ptr()); } /// Iterates over the container and calls `push_back()` on each value. @@ -302,7 +320,12 @@ class SequentialImpl : public Cloneable { /// pack has only one type, in which case the template would be preferred, /// even if the other `push_back` functions are better fits (e.g. `unique_ptr` /// -> `shared_ptr` overload). - template + /// NOTE: We explicitly avoid matching this template with `push_back(std::string("name"), module)` + /// or `push_back("name", module)`, since they should be handled by their respective + /// `push_back` functions. + template ::value || + std::is_same::type, std::decay::type>::value>> void push_back(First&& first, Second&& second, Rest&&... rest) { push_back(std::forward(first)); // Recursively calls this method, until the parameter pack only thas this @@ -312,9 +335,13 @@ class SequentialImpl : public Cloneable { /// Adds a type-erased `AnyModule` to the `Sequential`. void push_back(AnyModule any_module) { + push_back(std::to_string(modules_.size()), std::move(any_module)); + } + + void push_back(std::string name, AnyModule any_module) { modules_.push_back(std::move(any_module)); const auto index = modules_.size() - 1; - register_module(std::to_string(index), modules_[index].ptr()); + register_module(std::move(name), modules_[index].ptr()); } /// The base case, when the list of modules is empty. 
diff --git a/torch/csrc/api/src/nn/modules/named_any.cpp b/torch/csrc/api/src/nn/modules/named_any.cpp new file mode 100644 index 0000000..85c7656 --- /dev/null +++ b/torch/csrc/api/src/nn/modules/named_any.cpp @@ -0,0 +1,16 @@ +#include + +namespace torch { +namespace nn { + +torch::OrderedDict modules_ordered_dict( + std::initializer_list named_modules) { + torch::OrderedDict dict; + for (auto named_module : named_modules) { + dict.insert(named_module.name(), std::move(named_module.module())); + } + return dict; +} + +} // namespace nn +} // namespace torch