# Filter out some warnings.
ignore_warning "warning: no uniquely matching class member found for"
+ignore_warning "warning: explicit link request to 'Item' could not be resolved"
# Count the number of remaining warnings.
warnings="$(grep 'warning:' doxygen-log.txt | wc -l)"
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/named_any.h>
#include <torch/nn/modules/rnn.h>
#include <torch/nn/modules/sequential.h>
#include <torch/types.h>
Sequential sequential(
std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
ASSERT_EQ(sequential->size(), 3);
+
+ Sequential sequential_named(modules_ordered_dict({
+ {"m1", std::make_shared<M>(1)},
+ {std::string("m2"), std::make_shared<M>(2)},
+ {"m3", std::make_shared<M>(3)}
+ }));
+ ASSERT_EQ(sequential->size(), 3);
}
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // Counts copies so the test can verify how many times `Sequential`
    // copies a module that is passed in by value.
    M(const M& other) : torch::nn::Module(other) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(modules_ordered_dict({
    {"m1", M(1)},
    {std::string("m2"), M(2)},
    {"m3", M(3)}
  }));
  // Fixed: this previously re-checked `sequential` (copy-paste slip); the
  // named container is what this half of the test constructs.
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(copy_count, 3);
}
+
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
struct MImpl : torch::nn::Module {
explicit MImpl(int value_) : value(value_) {}
Sequential sequential(M(1), M(2), M(3));
ASSERT_EQ(sequential->size(), 3);
+
+ Sequential sequential_named(modules_ordered_dict({
+ {"m1", M(1)},
+ {std::string("m2"), M(2)},
+ {"m3", M(3)}
+ }));
+ ASSERT_EQ(sequential->size(), 3);
}
TEST_F(SequentialTest, PushBackAddsAnElement) {
}
int value;
};
+
+ // Test unnamed submodules
Sequential sequential;
ASSERT_EQ(sequential->size(), 0);
ASSERT_TRUE(sequential->is_empty());
ASSERT_EQ(sequential->size(), 2);
sequential->push_back(M(2));
ASSERT_EQ(sequential->size(), 3);
+
+ // Mix named and unnamed submodules
+ Sequential sequential_named;
+ ASSERT_EQ(sequential_named->size(), 0);
+ ASSERT_TRUE(sequential_named->is_empty());
+
+ sequential_named->push_back(Linear(3, 4));
+ ASSERT_EQ(sequential_named->size(), 1);
+ ASSERT_EQ(sequential_named->named_children()[0].key(), "0");
+ sequential_named->push_back(std::string("linear2"), Linear(3, 4));
+ ASSERT_EQ(sequential_named->size(), 2);
+ ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2");
+
+ sequential_named->push_back("shared_m1", std::make_shared<M>(1));
+ ASSERT_EQ(sequential_named->size(), 3);
+ ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1");
+ sequential_named->push_back(std::make_shared<M>(1));
+ ASSERT_EQ(sequential_named->size(), 4);
+ ASSERT_EQ(sequential_named->named_children()[3].key(), "3");
+
+ sequential_named->push_back(M(1));
+ ASSERT_EQ(sequential_named->size(), 5);
+ ASSERT_EQ(sequential_named->named_children()[4].key(), "4");
+ sequential_named->push_back(std::string("m2"), M(1));
+ ASSERT_EQ(sequential_named->size(), 6);
+ ASSERT_EQ(sequential_named->named_children()[5].key(), "m2");
}
TEST_F(SequentialTest, AccessWithAt) {
" (4): torch::nn::Embedding(count=4, dimension=10)\n"
" (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
")");
+
+ Sequential sequential_named(modules_ordered_dict({
+ {"linear", Linear(10, 3)},
+ {"conv2d", Conv2d(1, 2, 3)},
+ {"dropout", Dropout(0.5)},
+ {"batchnorm", BatchNorm(5)},
+ {"embedding", Embedding(4, 10)},
+ {"lstm", LSTM(4, 5)}
+ }));
+ ASSERT_EQ(
+ c10::str(sequential_named),
+ "torch::nn::Sequential(\n"
+ " (linear): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
+ " (conv2d): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
+ " (dropout): torch::nn::Dropout(rate=0.5)\n"
+ " (batchnorm): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+ " (embedding): torch::nn::Embedding(count=4, dimension=10)\n"
+ " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+ ")");
}
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  // Fixed: the adjacent string literals below previously concatenated without
  // separating spaces ("non-templatizedreturn", "nn::Sequentialinto").
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized "
      "return type into AnyModule (e.g. we cannot store nn::Sequential "
      "into AnyModule, because its forward() method's return type is templatized)");
}
template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
--- /dev/null
+#pragma once
+
+#include <torch/detail/static.h>
+#include <torch/nn/module.h>
+#include <torch/nn/modules/any.h>
+#include <torch/nn/pimpl.h>
+#include <torch/types.h>
+
+#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/utils/memory.h>
+#include <torch/csrc/utils/variadic.h>
+
+#include <ATen/Device.h>
+
+#include <initializer_list>
+#include <memory>
+#include <type_traits>
+#include <typeinfo>
+#include <utility>
+#include <vector>
+
+namespace torch {
+namespace nn {
+
+/// Stores a type erased `Module` with name.
+///
+/// The `NamedAnyModule` class and the `modules_ordered_dict(...)` function enables
+/// the following API for constructing `nn::Sequential` with named submodules:
+/// \rst
+/// .. code-block:: cpp
+///
+/// struct M : torch::nn::Module {
+/// explicit M(int value_) : value(value_) {}
+/// int value;
+/// int forward() {
+/// return value;
+/// }
+/// };
+///
+/// Sequential sequential(modules_ordered_dict({
+/// {"m1", std::make_shared<M>(1)}, // shared pointer to `Module` is supported
+/// {std::string("m2"), M(2)}, // `Module` is supported
+/// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported
+/// }));
+/// \endrst
+///
+/// Specifically, we design the signature of `modules_ordered_dict(...)` to be
+/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`, as
+/// a result of evaluating the following possible approaches:
+///
+/// Approach 1:
+/// `modules_ordered_dict(std::initializer_list<
+/// torch::OrderedDict<std::string, ModuleType>::Item> named_modules)`
+///
+/// Why it doesn't work:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, at the template argument
+/// deduction step the compiler is not able to deduce the type of `ModuleType` to
+/// the type of `M(1)` or `M(2)`, since the compiler doesn't actually look into the
+/// braced-init list `{"m1", M(1)}` and figure out what the types of its elements are.
+///
+/// Approach 2:
+/// `modules_ordered_dict(std::initializer_list<
+/// std::pair<std::string, AnyModule> named_modules)`
+///
+/// Why it doesn't work:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is not able to
+/// match `std::initializer_list<std::pair<std::string, AnyModule>>` to the nested
+/// braced-init list `{{"m1", M(1)}, {"m2", M(2)}}`, and results in a "could not
+/// convert" error.
+///
+/// Approach 3:
+/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`
+///
+/// Why it works:
+/// When we pass in a braced-init list such as
+/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is passing the
+/// braced-init lists {"m1", M(1)} and {"m2", M(2)} to the `NamedAnyModule`
+/// constructors, and the constructors are able to figure out the types of the
+/// braced-init lists' elements and match to the correct module type.
+
+class NamedAnyModule {
+ public:
+ /// Creates a `NamedAnyModule` from a (boxed) `Module`.
+ template <typename ModuleType>
+ NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
+ : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
+
+ /// Creates a `NamedAnyModule` from a `Module`, moving or copying it
+ /// into a `shared_ptr` internally.
+ // NOTE: We need to use `std::remove_reference<M>::type` to get rid of
+ // any reference components for make_unique.
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
+ NamedAnyModule(std::string name, M&& module)
+ : NamedAnyModule(
+ std::move(name),
+ std::make_shared<typename std::remove_reference<M>::type>(
+ std::forward<M>(module))) {}
+
+ /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
+ /// a `ModuleHolder`.
+ template <typename M>
+ NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
+ : NamedAnyModule(std::move(name), module_holder.ptr()) {}
+
+ /// Returns a reference to the name.
+ const std::string& name() const noexcept {
+ return name_;
+ }
+
+ /// Returns a reference to the module.
+ AnyModule& module() noexcept {
+ return module_;
+ }
+
+ private:
+ /// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
+ NamedAnyModule(std::string name, AnyModule any_module)
+ : name_(std::move(name)), module_(std::move(any_module)) {}
+
+ std::string name_;
+ AnyModule module_;
+};
+
+TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
+ std::initializer_list<NamedAnyModule> named_modules);
+
+} // namespace nn
+} // namespace torch
push_back(std::forward<Modules>(modules)...);
}
+ /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
+ /// Combining with `modules_ordered_dict()`, it enables the following use case:
+ /// `Sequential sequential(modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}}))`
+ explicit SequentialImpl(torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
+ modules_.reserve(ordered_dict.size());
+ for (auto& item : ordered_dict) {
+ push_back(std::move(item.key()), std::move(item.value()));
+ }
+ }
+
/// Special cloning function for `Sequential` because it does not use
/// `reset()`.
std::shared_ptr<Module> clone(
/// Adds a new (boxed) `Module` to the `Sequential` container.
template <typename ModuleType>
void push_back(std::shared_ptr<ModuleType> module_ptr) {
- // Nesting Sequential doesn't work because `forward()`'s return type is
- // templatized, so it'll give a nasty compiler error.
- static_assert(
- !std::is_same<SequentialImpl, ModuleType>::value,
- "Sequential is not nestable");
- static_assert(
- torch::detail::is_module<ModuleType>::value,
- "Can only add objects derived from nn::Module to Sequential");
- static_assert(
- torch::detail::has_forward<ModuleType>::value,
- "Can only add modules with a forward() method to Sequential");
- push_back(AnyModule(std::move(module_ptr)));
+ push_back(std::to_string(modules_.size()), std::move(module_ptr));
+ }
+
+ /// Adds a new named (boxed) `Module` to the `Sequential` container.
+ template <typename ModuleType>
+ void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
+ push_back(std::move(name), AnyModule(std::move(module_ptr)));
}
/// Adds a new `Module` to the `Sequential` container, moving or copying it
/// `Sequential(std::make_shared<Module>(3, 4))`.
template <typename M, typename = torch::detail::enable_if_module_t<M>>
void push_back(M&& module) {
- // Need to get rid of any reference components for make_unique.
+ push_back(std::to_string(modules_.size()), std::forward<M>(module));
+ }
+
+ /// Adds a new named `Module` to the `Sequential` container, moving or copying it
+ /// into a `shared_ptr` internally. This method allows passing value types,
+ /// and letting the container deal with the boxing.
+ template <typename M, typename = torch::detail::enable_if_module_t<M>>
+ void push_back(std::string name, M&& module) {
using Type = typename std::remove_reference<M>::type;
- // Here we move (or copy) the module into a new shared_ptr.
- push_back(std::make_shared<Type>(std::forward<M>(module)));
+ push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
}
/// Unwraps the contained module of a `ModuleHolder` and adds it to the
/// `Sequential`.
template <typename M>
void push_back(const ModuleHolder<M>& module_holder) {
- push_back(module_holder.ptr());
+ push_back(std::to_string(modules_.size()), module_holder);
+ }
+
+ /// Unwraps the contained named module of a `ModuleHolder` and adds it to the
+ /// `Sequential`.
+ template <typename M>
+ void push_back(std::string name, const ModuleHolder<M>& module_holder) {
+ push_back(std::move(name), module_holder.ptr());
}
/// Iterates over the container and calls `push_back()` on each value.
/// pack has only one type, in which case the template would be preferred,
/// even if the other `push_back` functions are better fits (e.g. `unique_ptr`
/// -> `shared_ptr` overload).
- template <typename First, typename Second, typename... Rest>
+ /// NOTE: We explicitly avoid matching this template with `push_back(std::string("name"), module)`
+ /// or `push_back("name", module)`, since they should be handled by their respective
+ /// `push_back` functions.
+ template <typename First, typename Second, typename... Rest,
+ typename = torch::disable_if_t<std::is_same<First, std::string>::value ||
+ std::is_same<typename std::decay<First>::type, std::decay<const char (&)[]>::type>::value>>
void push_back(First&& first, Second&& second, Rest&&... rest) {
push_back(std::forward<First>(first));
    // Recursively calls this method, until the parameter pack only has this
/// Adds a type-erased `AnyModule` to the `Sequential`.
void push_back(AnyModule any_module) {
+ push_back(std::to_string(modules_.size()), std::move(any_module));
+ }
+
+ void push_back(std::string name, AnyModule any_module) {
modules_.push_back(std::move(any_module));
const auto index = modules_.size() - 1;
- register_module(std::to_string(index), modules_[index].ptr());
+ register_module(std::move(name), modules_[index].ptr());
}
/// The base case, when the list of modules is empty.
--- /dev/null
+#include <torch/nn/modules/named_any.h>
+
+namespace torch {
+namespace nn {
+
+torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
+ std::initializer_list<NamedAnyModule> named_modules) {
+ torch::OrderedDict<std::string, AnyModule> dict;
+ for (auto named_module : named_modules) {
+ dict.insert(named_module.name(), std::move(named_module.module()));
+ }
+ return dict;
+}
+
+} // namespace nn
+} // namespace torch