ASSERT_NO_THROW(module->modules());
}
}
+
+struct EmptyModule : torch::nn::Module {};
+
+TEST_F(ModuleTest, PrettyPrint) {
+ struct TestModule : torch::nn::Module {
+ TestModule(int x, float y) : x_(x), y_(y) {}
+
+ void pretty_print(std::ostream& stream) const {
+ stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
+ }
+
+ int x_;
+ float y_;
+ };
+
+ using namespace torch::nn;
+
+ ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
+ ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
+}
ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}
+
+TEST_F(ModulesTest, PrettyPrintLinear) {
+ ASSERT_EQ(
+ c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintConv) {
+ ASSERT_EQ(
+ c10::str(Conv1d(3, 4, 5)),
+ "torch::nn::Conv1d(input_channels=3, output_channels=4, kernel_size=5, stride=1)");
+ ASSERT_EQ(
+ c10::str(Conv2d(3, 4, 5)),
+ "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[1, 1])");
+ ASSERT_EQ(
+ c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
+ "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");
+
+ const auto options = Conv2dOptions(3, 4, torch::IntList{5, 6}).stride({1, 2});
+ ASSERT_EQ(
+ c10::str(Conv2d(options)),
+ "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
+}
+
+TEST_F(ModulesTest, PrettyPrintDropout) {
+ ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
+ ASSERT_EQ(
+ c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
+}
+
+TEST_F(ModulesTest, PrettyPrintFunctional) {
+ ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
+}
+
+TEST_F(ModulesTest, PrettyPrintBatchNorm) {
+ ASSERT_EQ(
+ c10::str(BatchNorm(
+ BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
+ true))),
+ "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintEmbedding) {
+ ASSERT_EQ(
+ c10::str(Embedding(10, 2)),
+ "torch::nn::Embedding(count=10, dimension=2)");
+}
+
+TEST_F(ModulesTest, PrettyPrintNestedModel) {
+ struct InnerTestModule : torch::nn::Module {
+ InnerTestModule()
+ : torch::nn::Module("InnerTestModule"),
+ fc(register_module("fc", torch::nn::Linear(3, 4))),
+ table(register_module("table", torch::nn::Embedding(10, 2))) {}
+
+ torch::nn::Linear fc;
+ torch::nn::Embedding table;
+ };
+
+ struct TestModule : torch::nn::Module {
+ TestModule()
+ : torch::nn::Module("TestModule"),
+ fc(register_module("fc", torch::nn::Linear(4, 5))),
+ table(register_module("table", torch::nn::Embedding(10, 2))),
+ inner(register_module("inner", std::make_shared<InnerTestModule>())) {
+ }
+
+ torch::nn::Linear fc;
+ torch::nn::Embedding table;
+ std::shared_ptr<InnerTestModule> inner;
+ };
+
+ ASSERT_EQ(
+ c10::str(TestModule{}),
+ "TestModule(\n"
+ " (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
+ " (table): torch::nn::Embedding(count=10, dimension=2)\n"
+ " (inner): InnerTestModule(\n"
+ " (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
+ " (table): torch::nn::Embedding(count=10, dimension=2)\n"
+ " )\n"
+ ")");
+}
ASSERT_TRUE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
}
+
+TEST_F(RNNTest, PrettyPrintRNNs) {
+ ASSERT_EQ(
+ c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
+ "torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
+ ASSERT_EQ(
+ c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
+ "torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
+ ASSERT_EQ(
+ c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
+ "torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
+}
ASSERT_EQ(b.device(), device);
}
}
+
+TEST_F(SequentialTest, PrettyPrintSequential) {
+ Sequential sequential(
+ Linear(10, 3),
+ Conv2d(1, 2, 3),
+ Dropout(0.5),
+ BatchNorm(5),
+ Embedding(4, 10),
+ LSTM(4, 5));
+ ASSERT_EQ(
+ c10::str(sequential),
+ "torch::nn::Sequential(\n"
+ " (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
+ " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
+ " (2): torch::nn::Dropout(rate=0.5)\n"
+ " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+ " (4): torch::nn::Embedding(count=4, dimension=10)\n"
+ " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+ ")");
+}
/// the length is checked against the `ExpandingArray`'s extent parameter `D`
/// at runtime.
/*implicit*/ ExpandingArray(std::initializer_list<T> list)
- : ExpandingArray(std::vector<T>(list)) {}
+ : ExpandingArray(at::ArrayRef<T>(list)) {}
- /// Constructs an `ExpandingArray` from a `vector`. The extent of the
- /// length is checked against the `ExpandingArray`'s extent parameter `D` at
- /// runtime.
- /*implicit*/ ExpandingArray(const std::vector<T>& values) {
+ /// Constructs an `ExpandingArray` from an `at::ArrayRef`. The length of the
+ /// `ArrayRef` is checked against the `ExpandingArray`'s extent parameter `D`
+ /// at runtime.
+ /*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
+ // clang-format off
AT_CHECK(
values.size() == D,
- "Expected ",
- D,
- " values, but instead got ",
- values.size());
+ "Expected ", D, " values, but instead got ", values.size());
+ // clang-format on
std::copy(values.begin(), values.end(), values_.begin());
}
std::array<T, D> values_;
};
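+/// Prints an `ExpandingArray` to a `stream`. An array of extent 1 prints as a
+/// bare value (e.g. `5`), while larger extents print as a list (e.g. `[5, 5]`);
+/// this is the format seen in the `kernel_size` and `stride` fields of the
+/// `Conv{1,2,3}d` pretty printers below.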
+template <size_t D, typename T>
+std::ostream& operator<<(
+ std::ostream& stream,
+ const ExpandingArray<D, T>& expanding_array) {
+ if (expanding_array.size() == 1) {
+ return stream << expanding_array->at(0);
+ }
+ return stream << static_cast<at::ArrayRef<T>>(expanding_array);
+}
} // namespace torch
#include <ATen/ATen.h>
#include <functional>
+#include <iosfwd>
#include <map>
#include <memory>
#include <string>
/// Deserializes the `Module` from the given `InputArchive`.
virtual void load(serialize::InputArchive& archive);
+ /// Streams a pretty representation of the `Module` into the given `stream`.
+ /// By default, this representation will be the name of the module (taken from
+ /// `name()`), followed by a recursive pretty print of all of the `Module`'s
+ /// submodules.
+ ///
+ /// Override this method to customize the pretty printed representation of
+ /// your module.
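+ ///
+ /// A minimal sketch of an override (`MyModule` and `hidden_size_` are only
+ /// illustrative):
+ ///
+ ///   void pretty_print(std::ostream& stream) const override {
+ ///     stream << "MyModule(hidden_size=" << hidden_size_ << ")";
+ ///   }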
+ virtual void pretty_print(std::ostream& stream) const;
+
protected:
/// Registers a parameter with this `Module`.
///
template <typename Derived>
friend class Cloneable;
+ /// Pretty prints the given `Module` into the `ostream`.
+ TORCH_API friend std::ostream& operator<<(
+ std::ostream& stream,
+ const nn::Module& module);
+
// Private methods.
/// Used in the implementation of `Cloneable`.
template <typename... Ts>
void to_impl(Ts&&... ts);
+ /// Implements pretty printing the module hierarchy.
+ void pretty_print_recursive(
+ std::ostream& stream,
+ const std::string& indentation) const;
+
/// Applies the `function` to every submodule recursively, starting at this
/// `Module`'s children (thus not including the module itself).
void apply_to_submodules(
void reset() override;
+ /// Pretty prints the `BatchNorm` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Applies batch normalization on the `input` using the stored mean and
/// variance.
///
void reset() override;
+ /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// The options with which this `Module` was constructed.
ConvOptions<D> options;
class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
public:
using detail::DropoutImplBase<DropoutImpl>::DropoutImplBase;
+
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(const Tensor& input);
+
+ /// Pretty prints the `Dropout` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
};
/// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
/// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
/// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
/// and 3-D features.
-class TORCH_API FeatureDropoutImpl : public detail::DropoutImplBase<FeatureDropoutImpl> {
+class TORCH_API FeatureDropoutImpl
+ : public detail::DropoutImplBase<FeatureDropoutImpl> {
public:
using detail::DropoutImplBase<FeatureDropoutImpl>::DropoutImplBase;
+
/// During training, applies a noise mask to the input tensor.
/// During evaluation, applies an identity function.
Tensor forward(const Tensor& input);
+
+ /// Pretty prints the `FeatureDropout` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
};
/// A `ModuleHolder` subclass for `DropoutImpl`.
void reset() override;
+ /// Pretty prints the `Embedding` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Performs a lookup on the embedding table stored in `weight` using the
/// `indices` supplied and returns the result.
Tensor forward(const Tensor& indices);
void reset() override;
+ /// Pretty prints the `Functional` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Forwards the `input` tensor to the underlying (bound) function object.
Tensor forward(Tensor input);
void reset() override;
+ /// Pretty prints the `Linear` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Transforms the `input` tensor by multiplying with the `weight` and
/// optionally adding the `bias`, if `with_bias` is true in the options.
Tensor forward(const Tensor& input);
void to(torch::Dtype dtype, bool non_blocking = false) override;
void to(torch::Device device, bool non_blocking = false) override;
+ /// Pretty prints the `RNN`, `LSTM` or `GRU` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Modifies the internal storage of weights for optimization purposes.
///
/// On CPU, this method should be called if any of the weight or bias vectors
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-enum class RNNActivation : uint32_t TORCH_API { ReLU, Tanh };
+enum class RNNActivation : uint32_t TORCH_API{ReLU, Tanh};
/// Options for RNN modules.
struct TORCH_API RNNOptions {
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
explicit RNNImpl(const RNNOptions& options);
+ /// Pretty prints the `RNN` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+
/// Applies the `RNN` module to an input sequence and input state.
/// The `input` should follow a `(sequence, batch, features)` layout unless
/// `batch_first` is true, in which case the layout should be `(batch,
#include <cstdint>
#include <memory>
+#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
/// its own.
void reset() override {}
+ /// Pretty prints the `Sequential` module into the given `stream`.
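+ /// Overridden because the default implementation would print the demangled
+ /// `SequentialImpl` class name; the indexed list of submodules is appended
+ /// by `nn::Module`'s recursive printer.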
+ void pretty_print(std::ostream& stream) const override {
+ stream << "torch::nn::Sequential";
+ }
+
/// Feeds `inputs` to the first module and then chains outputs to inputs,
/// returning the last output.
///
}
};
+/// Pretty prints the given `Module` into the `ostream`.
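+///
+/// This overload is what lets a module holder be streamed directly, e.g.
+/// `std::cout << Linear(3, 4)`: it dereferences the holder and forwards to
+/// the `nn::Module` printer.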
+template <typename ModuleType>
+std::ostream& operator<<(
+ std::ostream& stream,
+ const nn::ModuleHolder<ModuleType>& module) {
+ return stream << *module;
+}
+
/// Serializes a `ModuleHolder` into an `OutputArchive`.
template <typename ModuleType>
serialize::OutputArchive& operator<<(
#include <algorithm>
#include <functional>
#include <map>
+#include <ostream>
#include <string>
#include <typeinfo>
return buffers_.insert(std::move(name), std::move(tensor));
}
+void Module::pretty_print(std::ostream& stream) const {
+ stream << name();
+}
+
+void Module::pretty_print_recursive(
+ std::ostream& stream,
+ const std::string& indentation) const {
+ pretty_print(stream);
+ if (!children_.is_empty()) {
+ stream << "(\n";
+ const std::string next_indentation = indentation + " ";
+ for (const auto& child : children_) {
+ stream << next_indentation << "(" << child.key() << "): ";
+ child.value()->pretty_print_recursive(stream, next_indentation);
+ stream << '\n';
+ }
+ stream << indentation << ")";
+ }
+}
+
void Module::clone_(Module& other, const optional<Device>& device) {}
void Module::apply_to_submodules(
return std::const_pointer_cast<Module>(ptr);
}
+std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
+ module.pretty_print_recursive(stream, "");
+ return stream;
+}
+
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const std::shared_ptr<nn::Module>& module) {
#include <c10/util/Exception.h>
#include <cstddef>
+#include <ostream>
#include <utility>
#include <vector>
}
}
+void BatchNormImpl::pretty_print(std::ostream& stream) const {
+ stream << std::boolalpha
+ << "torch::nn::BatchNorm(features=" << options.features_
+ << ", eps=" << options.eps_ << ", momentum=" << options.momentum_
+ << ", affine=" << options.affine_ << ", stateful=" << options.stateful_
+ << ")";
+}
+
Tensor BatchNormImpl::forward(const Tensor& input) {
AT_CHECK(
options.stateful_,
}
}
+template <size_t D, typename Derived>
+void ConvImpl<D, Derived>::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::Conv" << D << "d"
+ << "(input_channels=" << options.input_channels_
+ << ", output_channels=" << options.output_channels_
+ << ", kernel_size=" << options.kernel_size_
+ << ", stride=" << options.stride_ << ")";
+}
+
Tensor Conv1dImpl::forward(const Tensor& input) {
if (options.transposed_) {
return torch::conv_transpose1d(
#include <c10/util/Exception.h>
#include <cstddef>
+#include <ostream>
#include <vector>
namespace torch {
return torch::dropout(input, options.rate_, this->is_training());
}
+void DropoutImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::Dropout(rate=" << options.rate_ << ")";
+}
+
Tensor FeatureDropoutImpl::forward(const Tensor& input) {
return torch::feature_dropout(input, options.rate_, this->is_training());
}
+
+void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::FeatureDropout(rate=" << options.rate_ << ")";
+}
} // namespace nn
} // namespace torch
#include <torch/utils.h>
#include <cstddef>
+#include <ostream>
#include <utility>
#include <vector>
EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
: count_(count), dimension_(dimension) {}
-EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options)
- : options(options) {
+EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) {
reset();
}
weight.normal_(0, 1);
}
+void EmbeddingImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::Embedding(count=" << options.count_
+ << ", dimension=" << options.dimension_ << ")";
+}
+
Tensor EmbeddingImpl::forward(const Tensor& input) {
return torch::embedding(weight, /*indices=*/input);
}
void FunctionalImpl::reset() {}
+void FunctionalImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::Functional()";
+}
+
Tensor FunctionalImpl::forward(Tensor input) {
return function_(std::move(input));
}
}
}
+void LinearImpl::pretty_print(std::ostream& stream) const {
+ stream << std::boolalpha << "torch::nn::Linear(in=" << options.in_
+ << ", out=" << options.out_ << ", with_bias=" << options.with_bias_
+ << ")";
+}
+
Tensor LinearImpl::forward(const Tensor& input) {
AT_ASSERT(!options.with_bias_ || bias.defined());
return torch::linear(input, weight, bias);
}
template <typename Derived>
+void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
+ const std::string name = this->name();
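+ // `name()` is e.g. `torch::nn::LSTMImpl`; strip the trailing "Impl"
+ // (4 characters) to recover the user-facing `torch::nn::LSTM` name.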
+ const std::string name_without_impl = name.substr(0, name.size() - 4);
+ stream << name_without_impl << "(input_size=" << options.input_size_
+ << ", hidden_size=" << options.hidden_size_
+ << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
+ << ")";
+}
+
+template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
// Cache the flattened weight and bias vector.
flat_weights_ = flat_weights();
static_cast<CuDNNMode>(options.activation_)),
options(options) {}
+void RNNImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::RNN(input_size=" << options.input_size_
+ << ", hidden_size=" << options.hidden_size_
+ << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
+ << ", activation="
+ << (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu")
+ << ")";
+}
+
RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
switch (options.activation_) {
case RNNActivation::ReLU: