Add named submodule support to nn::Sequential (#17552)
authorWill Feng <willfeng@fb.com>
Fri, 29 Mar 2019 19:59:29 +0000 (12:59 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 20:06:29 +0000 (13:06 -0700)
Summary:
Previously, we were not able to assign names to `nn::Sequential`'s submodules. This PR adds this feature to match the Python API. Example use:
```cpp
Sequential sequential(named_submodule({
      {"linear", Linear(10, 3)},
      {"conv2d", Conv2d(1, 2, 3)},
      {"dropout", Dropout(0.5)},
      {"batchnorm", BatchNorm(5)},
      {"embedding", Embedding(4, 10)},
      {"lstm", LSTM(4, 5)}
}));
```

It also enables loading parameters of Python `nn.Sequential` module with custom submodules names into C++ frontend, unblocking https://github.com/pytorch/vision/pull/728#issuecomment-466661344.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17552

Differential Revision: D14246834

Pulled By: yf225

fbshipit-source-id: 3030b5c5d68f6dd5d3e37ac4b4f98dc6d6d9ba72

docs/cpp/source/check-doxygen.sh
test/cpp/api/sequential.cpp
torch/CMakeLists.txt
torch/csrc/api/include/torch/nn/modules/any.h
torch/csrc/api/include/torch/nn/modules/named_any.h [new file with mode: 0644]
torch/csrc/api/include/torch/nn/modules/sequential.h
torch/csrc/api/src/nn/modules/named_any.cpp [new file with mode: 0644]

index 18863e2..9959a3f 100755 (executable)
@@ -39,6 +39,7 @@ cat original-doxygen-log.txt
 
 # Filter out some warnings.
 ignore_warning "warning: no uniquely matching class member found for"
+ignore_warning "warning: explicit link request to 'Item' could not be resolved"
 
 # Count the number of remaining warnings.
 warnings="$(grep 'warning:' doxygen-log.txt | wc -l)"
index 1f7cf55..747bbf1 100644 (file)
@@ -5,6 +5,7 @@
 #include <torch/nn/modules/conv.h>
 #include <torch/nn/modules/dropout.h>
 #include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/named_any.h>
 #include <torch/nn/modules/rnn.h>
 #include <torch/nn/modules/sequential.h>
 #include <torch/types.h>
@@ -32,20 +33,47 @@ TEST_F(SequentialTest, ConstructsFromSharedPointer) {
   Sequential sequential(
       std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
   ASSERT_EQ(sequential->size(), 3);
+
+  Sequential sequential_named(modules_ordered_dict({
+    {"m1", std::make_shared<M>(1)},
+    {std::string("m2"), std::make_shared<M>(2)},
+    {"m3", std::make_shared<M>(3)}
+  }));
+  ASSERT_EQ(sequential->size(), 3);
 }
 
 TEST_F(SequentialTest, ConstructsFromConcreteType) {
+  static int copy_count;
+
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
+    M(const M& other) : torch::nn::Module(other) {
+      copy_count++;
+    }
     int value;
     int forward() {
       return value;
     }
   };
 
+  copy_count = 0;
   Sequential sequential(M(1), M(2), M(3));
   ASSERT_EQ(sequential->size(), 3);
+  // NOTE: The current implementation expects each module to be copied exactly once,
+  // which happens when the module is passed into `std::make_shared<T>()`.
+  // TODO: Find a way to avoid copying, and then delete the copy constructor of `M`.
+  ASSERT_EQ(copy_count, 3);
+
+  copy_count = 0;
+  Sequential sequential_named(modules_ordered_dict({
+    {"m1", M(1)},
+    {std::string("m2"), M(2)},
+    {"m3", M(3)}
+  }));
+  ASSERT_EQ(sequential->size(), 3);
+  ASSERT_EQ(copy_count, 3);
 }
+
 TEST_F(SequentialTest, ConstructsFromModuleHolder) {
   struct MImpl : torch::nn::Module {
     explicit MImpl(int value_) : value(value_) {}
@@ -62,6 +90,13 @@ TEST_F(SequentialTest, ConstructsFromModuleHolder) {
 
   Sequential sequential(M(1), M(2), M(3));
   ASSERT_EQ(sequential->size(), 3);
+
+  Sequential sequential_named(modules_ordered_dict({
+    {"m1", M(1)},
+    {std::string("m2"), M(2)},
+    {"m3", M(3)}
+  }));
+  ASSERT_EQ(sequential->size(), 3);
 }
 
 TEST_F(SequentialTest, PushBackAddsAnElement) {
@@ -72,6 +107,8 @@ TEST_F(SequentialTest, PushBackAddsAnElement) {
     }
     int value;
   };
+
+  // Test unnamed submodules
   Sequential sequential;
   ASSERT_EQ(sequential->size(), 0);
   ASSERT_TRUE(sequential->is_empty());
@@ -81,6 +118,32 @@ TEST_F(SequentialTest, PushBackAddsAnElement) {
   ASSERT_EQ(sequential->size(), 2);
   sequential->push_back(M(2));
   ASSERT_EQ(sequential->size(), 3);
+
+  // Mix named and unnamed submodules
+  Sequential sequential_named;
+  ASSERT_EQ(sequential_named->size(), 0);
+  ASSERT_TRUE(sequential_named->is_empty());
+
+  sequential_named->push_back(Linear(3, 4));
+  ASSERT_EQ(sequential_named->size(), 1);
+  ASSERT_EQ(sequential_named->named_children()[0].key(), "0");
+  sequential_named->push_back(std::string("linear2"), Linear(3, 4));
+  ASSERT_EQ(sequential_named->size(), 2);
+  ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2");
+
+  sequential_named->push_back("shared_m1", std::make_shared<M>(1));
+  ASSERT_EQ(sequential_named->size(), 3);
+  ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1");
+  sequential_named->push_back(std::make_shared<M>(1));
+  ASSERT_EQ(sequential_named->size(), 4);
+  ASSERT_EQ(sequential_named->named_children()[3].key(), "3");
+
+  sequential_named->push_back(M(1));
+  ASSERT_EQ(sequential_named->size(), 5);
+  ASSERT_EQ(sequential_named->named_children()[4].key(), "4");
+  sequential_named->push_back(std::string("m2"), M(1));
+  ASSERT_EQ(sequential_named->size(), 6);
+  ASSERT_EQ(sequential_named->named_children()[5].key(), "m2");
 }
 
 TEST_F(SequentialTest, AccessWithAt) {
@@ -342,4 +405,23 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       "  (4): torch::nn::Embedding(count=4, dimension=10)\n"
       "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
       ")");
+
+  Sequential sequential_named(modules_ordered_dict({
+      {"linear", Linear(10, 3)},
+      {"conv2d", Conv2d(1, 2, 3)},
+      {"dropout", Dropout(0.5)},
+      {"batchnorm", BatchNorm(5)},
+      {"embedding", Embedding(4, 10)},
+      {"lstm", LSTM(4, 5)}
+  }));
+  ASSERT_EQ(
+      c10::str(sequential_named),
+      "torch::nn::Sequential(\n"
+      "  (linear): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
+      "  (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
+      "  (dropout): torch::nn::Dropout(rate=0.5)\n"
+      "  (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+      "  (embedding): torch::nn::Embedding(count=4, dimension=10)\n"
+      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      ")");
 }
index 4e6fe9a..deff903 100644 (file)
@@ -243,6 +243,7 @@ if (NOT NO_API)
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
+    ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
     ${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
index 0715fae..08deb0b 100644 (file)
@@ -419,7 +419,20 @@ template <typename ModuleType>
 AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
     : content_(make_holder(
           std::move(module),
-          &std::remove_reference<ModuleType>::type::forward)) {}
+          &std::remove_reference<ModuleType>::type::forward)) {
+  // `AnyModule` can only store an `nn::Module` subclass object that provides
+  // a `forward()` method that has a non-templatized return type.
+  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
+  // `forward()` method has a templatized return type.)
+  static_assert(
+      torch::detail::is_module<ModuleType>::value,
+      "Can only store object derived from nn::Module into AnyModule");
+  static_assert(
+      torch::detail::has_forward<ModuleType>::value,
+      "Can only store module with a forward() method that has a non-templatized"
+      "return type into AnyModule (e.g. we cannot store nn::Sequential"
+      "into AnyModule, because its forward() method's return type is templatized)");
+}
 
 template <typename ModuleType, typename>
 AnyModule::AnyModule(ModuleType&& module)
diff --git a/torch/csrc/api/include/torch/nn/modules/named_any.h b/torch/csrc/api/include/torch/nn/modules/named_any.h
new file mode 100644 (file)
index 0000000..d4fa54b
--- /dev/null
@@ -0,0 +1,130 @@
+#pragma once
+
+#include <torch/detail/static.h>
+#include <torch/nn/module.h>
+#include <torch/nn/modules/any.h>
+#include <torch/nn/pimpl.h>
+#include <torch/types.h>
+
+#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/utils/memory.h>
+#include <torch/csrc/utils/variadic.h>
+
+#include <ATen/Device.h>
+
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <typeinfo>
+#include <utility>
+#include <vector>
+
+namespace torch {
+namespace nn {
+
+/// Stores a type erased `Module` with name.
+///
+/// The `NamedAnyModule` class and the `modules_ordered_dict(...)` function enables
+/// the following API for constructing `nn::Sequential` with named submodules:
+/// \rst
+/// .. code-block:: cpp
+///
+///   struct M : torch::nn::Module {
+///     explicit M(int value_) : value(value_) {}
+///     int value;
+///     int forward() {
+///       return value;
+///     }
+///   };
+///
+///   Sequential sequential(modules_ordered_dict({
+///     {"m1", std::make_shared<M>(1)},  // shared pointer to `Module` is supported
+///     {std::string("m2"), M(2)},  // `Module` is supported
+///     {"linear1", Linear(10, 3)}  // `ModuleHolder` is supported
+///   }));
+/// \endrst
+///
+/// Specifically, we design the signature of `modules_ordered_dict(...)` to be
+/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`, as
+/// a result of evaluating the following possible approaches:
+///
+/// Approach 1:
+/// `modules_ordered_dict(std::initializer_list<
+///   torch::OrderedDict<std::string, ModuleType>::Item> named_modules)`
+///
+/// Why it doens't work:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, at the template argument
+/// deduction step the compiler is not able to deduce the type of `ModuleType` to
+/// the type of `M(1)` or `M(2)`, since the compiler doesn't actually look into the
+/// braced-init list `{"m1", M(1)}` and figure out what the types of its elements are.
+///
+/// Approach 2:
+/// `modules_ordered_dict(std::initializer_list<
+///   std::pair<std::string, AnyModule> named_modules)`
+///
+/// Why it doens't work:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is not able to
+/// match `std::initializer_list<std::pair<std::string, AnyModule>>` to the nested
+/// braced-init list `{{"m1", M(1)}, {"m2", M(2)}}`, and results in a "could not
+/// convert" error.
+///
+/// Approach 3:
+/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`
+///
+/// Why it works:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is passing the
+/// braced-init lists {"m1", M(1)} and {"m2", M(2)} to the `NamedAnyModule`
+/// constructors, and the constructors are able to figure out the types of the
+/// braced-init lists' elements and match to the correct module type.
+
+class NamedAnyModule {
+ public:
+  /// Creates a `NamedAnyModule` from a (boxed) `Module`.
+  template <typename ModuleType>
+  NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
+      : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
+
+  /// Creates a `NamedAnyModule` from a `Module`, moving or copying it
+  /// into a `shared_ptr` internally.
+  // NOTE: We need to use `std::remove_reference<M>::type` to get rid of
+  // any reference components for make_unique.
+  template <typename M, typename = torch::detail::enable_if_module_t<M>>
+  NamedAnyModule(std::string name, M&& module)
+      : NamedAnyModule(
+          std::move(name),
+          std::make_shared<typename std::remove_reference<M>::type>(
+            std::forward<M>(module))) {}
+
+  /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
+  /// a `ModuleHolder`.
+  template <typename M>
+  NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
+      : NamedAnyModule(std::move(name), module_holder.ptr()) {}
+
+  /// Returns a reference to the name.
+  const std::string& name() const noexcept {
+    return name_;
+  }
+
+  /// Returns a reference to the module.
+  AnyModule& module() noexcept {
+    return module_;
+  }
+
+ private:
+  /// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
+  NamedAnyModule(std::string name, AnyModule any_module)
+    : name_(std::move(name)), module_(std::move(any_module)) {}
+
+  std::string name_;
+  AnyModule module_;
+};
+
+TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
+  std::initializer_list<NamedAnyModule> named_modules);
+
+} // namespace nn
+} // namespace torch
index 976e0e3..7a4818c 100644 (file)
@@ -102,6 +102,16 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
     push_back(std::forward<Modules>(modules)...);
   }
 
+  /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
+  /// Combining with `modules_ordered_dict()`, it enables the following use case:
+  /// `Sequential sequential(modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}}))`
+  explicit SequentialImpl(torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
+    modules_.reserve(ordered_dict.size());
+    for (auto& item : ordered_dict) {
+      push_back(std::move(item.key()), std::move(item.value()));
+    }
+  }
+
   /// Special cloning function for `Sequential` because it does not use
   /// `reset()`.
   std::shared_ptr<Module> clone(
@@ -175,18 +185,13 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
   /// Adds a new (boxed) `Module` to the `Sequential` container.
   template <typename ModuleType>
   void push_back(std::shared_ptr<ModuleType> module_ptr) {
-    // Nesting Sequential doesn't work because `forward()`'s return type is
-    // templatized, so it'll give a nasty compiler error.
-    static_assert(
-        !std::is_same<SequentialImpl, ModuleType>::value,
-        "Sequential is not nestable");
-    static_assert(
-        torch::detail::is_module<ModuleType>::value,
-        "Can only add objects derived from nn::Module to Sequential");
-    static_assert(
-        torch::detail::has_forward<ModuleType>::value,
-        "Can only add modules with a forward() method to Sequential");
-    push_back(AnyModule(std::move(module_ptr)));
+    push_back(std::to_string(modules_.size()), std::move(module_ptr));
+  }
+
+  /// Adds a new named (boxed) `Module` to the `Sequential` container.
+  template <typename ModuleType>
+  void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
+    push_back(std::move(name), AnyModule(std::move(module_ptr)));
   }
 
   /// Adds a new `Module` to the `Sequential` container, moving or copying it
@@ -196,17 +201,30 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
   /// `Sequential(std::make_shared<Module>(3, 4))`.
   template <typename M, typename = torch::detail::enable_if_module_t<M>>
   void push_back(M&& module) {
-    // Need to get rid of any reference components for make_unique.
+    push_back(std::to_string(modules_.size()), std::forward<M>(module));
+  }
+
+  /// Adds a new named `Module` to the `Sequential` container, moving or copying it
+  /// into a `shared_ptr` internally. This method allows passing value types,
+  /// and letting the container deal with the boxing.
+  template <typename M, typename = torch::detail::enable_if_module_t<M>>
+  void push_back(std::string name, M&& module) {
     using Type = typename std::remove_reference<M>::type;
-    // Here we move (or copy) the module into a new shared_ptr.
-    push_back(std::make_shared<Type>(std::forward<M>(module)));
+    push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
   }
 
   /// Unwraps the contained module of a `ModuleHolder` and adds it to the
   /// `Sequential`.
   template <typename M>
   void push_back(const ModuleHolder<M>& module_holder) {
-    push_back(module_holder.ptr());
+    push_back(std::to_string(modules_.size()), module_holder);
+  }
+
+  /// Unwraps the contained named module of a `ModuleHolder` and adds it to the
+  /// `Sequential`.
+  template <typename M>
+  void push_back(std::string name, const ModuleHolder<M>& module_holder) {
+    push_back(std::move(name), module_holder.ptr());
   }
 
   /// Iterates over the container and calls `push_back()` on each value.
@@ -302,7 +320,12 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
   /// pack has only one type, in which case the template would be preferred,
   /// even if the other `push_back` functions are better fits (e.g. `unique_ptr`
   /// -> `shared_ptr` overload).
-  template <typename First, typename Second, typename... Rest>
+  /// NOTE: We explicitly avoid matching this template with `push_back(std::string("name"), module)`
+  /// or `push_back("name", module)`, since they should be handled by their respective
+  /// `push_back` functions.
+  template <typename First, typename Second, typename... Rest,
+    typename = torch::disable_if_t<std::is_same<First, std::string>::value ||
+      std::is_same<typename std::decay<First>::type, std::decay<const char (&)[]>::type>::value>>
   void push_back(First&& first, Second&& second, Rest&&... rest) {
     push_back(std::forward<First>(first));
     // Recursively calls this method, until the parameter pack only thas this
@@ -312,9 +335,13 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
 
   /// Adds a type-erased `AnyModule` to the `Sequential`.
   void push_back(AnyModule any_module) {
+    push_back(std::to_string(modules_.size()), std::move(any_module));
+  }
+
+  void push_back(std::string name, AnyModule any_module) {
     modules_.push_back(std::move(any_module));
     const auto index = modules_.size() - 1;
-    register_module(std::to_string(index), modules_[index].ptr());
+    register_module(std::move(name), modules_[index].ptr());
   }
 
   /// The base case, when the list of modules is empty.
diff --git a/torch/csrc/api/src/nn/modules/named_any.cpp b/torch/csrc/api/src/nn/modules/named_any.cpp
new file mode 100644 (file)
index 0000000..85c7656
--- /dev/null
@@ -0,0 +1,16 @@
+#include <torch/nn/modules/named_any.h>
+
+namespace torch {
+namespace nn {
+
+torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
+  std::initializer_list<NamedAnyModule> named_modules) {
+  torch::OrderedDict<std::string, AnyModule> dict;
+  for (auto named_module : named_modules) {
+    dict.insert(named_module.name(), std::move(named_module.module()));
+  }
+  return dict;
+}
+
+} // namespace nn
+} // namespace torch