Pretty printing of C++ modules (#15326)
authorPeter Goldsborough <psag@fb.com>
Thu, 20 Dec 2018 05:38:00 +0000 (21:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 05:55:49 +0000 (21:55 -0800)
Summary:
A long outstanding nicety: pretty printing of C++ modules. E.g.
```
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
std::cout << sequential;
```
prints
```
torch::nn::Sequential(
  (0): torch::nn::Linear(in=10, out=3, with_bias=true)
  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])
  (2): torch::nn::Dropout(rate=0.5)
  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)
  (4): torch::nn::Embedding(count=4, dimension=10)
  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)
)
```

apaszke ebetica ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15326

Differential Revision: D13518986

Pulled By: goldsborough

fbshipit-source-id: 63bf753672f0e348951de3645208f263581de5fb

23 files changed:
test/cpp/api/module.cpp
test/cpp/api/modules.cpp
test/cpp/api/rnn.cpp
test/cpp/api/sequential.cpp
torch/csrc/api/include/torch/expanding_array.h
torch/csrc/api/include/torch/nn/module.h
torch/csrc/api/include/torch/nn/modules/batchnorm.h
torch/csrc/api/include/torch/nn/modules/conv.h
torch/csrc/api/include/torch/nn/modules/dropout.h
torch/csrc/api/include/torch/nn/modules/embedding.h
torch/csrc/api/include/torch/nn/modules/functional.h
torch/csrc/api/include/torch/nn/modules/linear.h
torch/csrc/api/include/torch/nn/modules/rnn.h
torch/csrc/api/include/torch/nn/modules/sequential.h
torch/csrc/api/include/torch/nn/pimpl.h
torch/csrc/api/src/nn/module.cpp
torch/csrc/api/src/nn/modules/batchnorm.cpp
torch/csrc/api/src/nn/modules/conv.cpp
torch/csrc/api/src/nn/modules/dropout.cpp
torch/csrc/api/src/nn/modules/embedding.cpp
torch/csrc/api/src/nn/modules/functional.cpp
torch/csrc/api/src/nn/modules/linear.cpp
torch/csrc/api/src/nn/modules/rnn.cpp

index 2e77a8d..7721279 100644 (file)
@@ -832,3 +832,23 @@ TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
     ASSERT_NO_THROW(module->modules());
   }
 }
+
+struct EmptyModule : torch::nn::Module {};
+
+TEST_F(ModuleTest, PrettyPrint) {
+  struct TestModule : torch::nn::Module {
+    TestModule(int x, float y) : x_(x), y_(y) {}
+
+    void pretty_print(std::ostream& stream) const {
+      stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
+    }
+
+    int x_;
+    float y_;
+  };
+
+  using namespace torch::nn;
+
+  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
+  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
+}
index 6d240c4..6852d19 100644 (file)
@@ -325,3 +325,85 @@ TEST_F(ModulesTest, Linear2_CUDA) {
 
   ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
 }
+
+TEST_F(ModulesTest, PrettyPrintLinear) {
+  ASSERT_EQ(
+      c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintConv) {
+  ASSERT_EQ(
+      c10::str(Conv1d(3, 4, 5)),
+      "torch::nn::Conv1d(input_channels=3, output_channels=4, kernel_size=5, stride=1)");
+  ASSERT_EQ(
+      c10::str(Conv2d(3, 4, 5)),
+      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[1, 1])");
+  ASSERT_EQ(
+      c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
+      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");
+
+  const auto options = Conv2dOptions(3, 4, torch::IntList{5, 6}).stride({1, 2});
+  ASSERT_EQ(
+      c10::str(Conv2d(options)),
+      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
+}
+
+TEST_F(ModulesTest, PrettyPrintDropout) {
+  ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
+  ASSERT_EQ(
+      c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
+}
+
+TEST_F(ModulesTest, PrettyPrintFunctional) {
+  ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
+}
+
+TEST_F(ModulesTest, PrettyPrintBatchNorm) {
+  ASSERT_EQ(
+      c10::str(BatchNorm(
+          BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
+              true))),
+      "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintEmbedding) {
+  ASSERT_EQ(
+      c10::str(Embedding(10, 2)),
+      "torch::nn::Embedding(count=10, dimension=2)");
+}
+
+TEST_F(ModulesTest, PrettyPrintNestedModel) {
+  struct InnerTestModule : torch::nn::Module {
+    InnerTestModule()
+        : torch::nn::Module("InnerTestModule"),
+          fc(register_module("fc", torch::nn::Linear(3, 4))),
+          table(register_module("table", torch::nn::Embedding(10, 2))) {}
+
+    torch::nn::Linear fc;
+    torch::nn::Embedding table;
+  };
+
+  struct TestModule : torch::nn::Module {
+    TestModule()
+        : torch::nn::Module("TestModule"),
+          fc(register_module("fc", torch::nn::Linear(4, 5))),
+          table(register_module("table", torch::nn::Embedding(10, 2))),
+          inner(register_module("inner", std::make_shared<InnerTestModule>())) {
+    }
+
+    torch::nn::Linear fc;
+    torch::nn::Embedding table;
+    std::shared_ptr<InnerTestModule> inner;
+  };
+
+  ASSERT_EQ(
+      c10::str(TestModule{}),
+      "TestModule(\n"
+      "  (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
+      "  (table): torch::nn::Embedding(count=10, dimension=2)\n"
+      "  (inner): InnerTestModule(\n"
+      "    (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
+      "    (table): torch::nn::Embedding(count=10, dimension=2)\n"
+      "  )\n"
+      ")");
+}
index 5c95851..86b05df 100644 (file)
@@ -227,3 +227,15 @@ TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
       [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
 }
+
+TEST_F(RNNTest, PrettyPrintRNNs) {
+  ASSERT_EQ(
+      c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
+      "torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
+  ASSERT_EQ(
+      c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
+      "torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
+  ASSERT_EQ(
+      c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
+      "torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
+}
index f89d034..1f7cf55 100644 (file)
@@ -323,3 +323,23 @@ TEST_F(SequentialTest, CloneToDevice_CUDA) {
     ASSERT_EQ(b.device(), device);
   }
 }
+
+TEST_F(SequentialTest, PrettyPrintSequential) {
+  Sequential sequential(
+      Linear(10, 3),
+      Conv2d(1, 2, 3),
+      Dropout(0.5),
+      BatchNorm(5),
+      Embedding(4, 10),
+      LSTM(4, 5));
+  ASSERT_EQ(
+      c10::str(sequential),
+      "torch::nn::Sequential(\n"
+      "  (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
+      "  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
+      "  (2): torch::nn::Dropout(rate=0.5)\n"
+      "  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
+      "  (4): torch::nn::Embedding(count=4, dimension=10)\n"
+      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      ")");
+}
index 9105900..a840a88 100644 (file)
@@ -24,18 +24,17 @@ class ExpandingArray {
   /// the length is checked against the `ExpandingArray`'s extent parameter `D`
   /// at runtime.
   /*implicit*/ ExpandingArray(std::initializer_list<T> list)
-      : ExpandingArray(std::vector<T>(list)) {}
+      : ExpandingArray(at::ArrayRef<T>(list)) {}
 
-  /// Constructs an `ExpandingArray` from a `vector`. The extent of the
-  /// length is checked against the `ExpandingArray`'s extent parameter `D` at
-  /// runtime.
-  /*implicit*/ ExpandingArray(const std::vector<T>& values) {
+  /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of
+  /// the length is checked against the `ExpandingArray`'s extent parameter `D`
+  /// at runtime.
+  /*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
+    // clang-format off
     AT_CHECK(
         values.size() == D,
-        "Expected ",
-        D,
-        " values, but instead got ",
-        values.size());
+        "Expected ", D, " values, but instead got ", values.size());
+    // clang-format on
     std::copy(values.begin(), values.end(), values_.begin());
   }
 
@@ -84,4 +83,13 @@ class ExpandingArray {
   std::array<T, D> values_;
 };
 
+template <size_t D, typename T>
+std::ostream& operator<<(
+    std::ostream& stream,
+    const ExpandingArray<D, T>& expanding_array) {
+  if (expanding_array.size() == 1) {
+    return stream << expanding_array->at(0);
+  }
+  return stream << static_cast<at::ArrayRef<T>>(expanding_array);
+}
 } // namespace torch
index bb3fa9e..91b77fb 100644 (file)
@@ -8,6 +8,7 @@
 #include <ATen/ATen.h>
 
 #include <functional>
+#include <iosfwd>
 #include <map>
 #include <memory>
 #include <string>
@@ -386,6 +387,15 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
   /// Deserializes the `Module` from the given `InputArchive`.
   virtual void load(serialize::InputArchive& archive);
 
+  /// Streams a pretty representation of the `Module` into the given `stream`.
+  /// By default, this representation will be the name of the module (taken from
+  /// `name()`), followed by a recursive pretty print of all of the `Module`'s
+  /// submodules.
+  ///
+  /// Override this method to change the pretty print. The input
+  /// `stream` should be returned from the method, to allow easy chaining.
+  virtual void pretty_print(std::ostream& stream) const;
+
  protected:
   /// Registers a parameter with this `Module`.
   ///
@@ -462,6 +472,11 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
   template <typename Derived>
   friend class Cloneable;
 
+  /// Pretty prints the given `Module` into the `ostream`.
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& stream,
+      const nn::Module& module);
+
   // Private methods.
 
   /// Used in the implementation of `Cloneable`.
@@ -471,6 +486,11 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
   template <typename... Ts>
   void to_impl(Ts&&... ts);
 
+  /// Implements pretty printing the module hierarchy.
+  void pretty_print_recursive(
+      std::ostream& stream,
+      const std::string& indentation) const;
+
   /// Applies the `function` to every submodule recursively, starting at this
   /// `Module`'s children (thus not including the module itself).
   void apply_to_submodules(
index 33aabfd..2f7dd3b 100644 (file)
@@ -53,6 +53,9 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
 
   void reset() override;
 
+  /// Pretty prints the `BatchNorm` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Applies batch normalization on the `input` using the stored mean and
   /// variance.
   ///
index a242af1..38cf9b6 100644 (file)
@@ -88,6 +88,9 @@ class ConvImpl : public torch::nn::Cloneable<Derived> {
 
   void reset() override;
 
+  /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// The options with which this `Module` was constructed.
   ConvOptions<D> options;
 
index 46fb0e3..edd7c62 100644 (file)
@@ -39,9 +39,13 @@ class DropoutImplBase : public torch::nn::Cloneable<Derived> {
 class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
  public:
   using detail::DropoutImplBase<DropoutImpl>::DropoutImplBase;
+
   /// During training, applies a noise mask to the input tensor.
   /// During evaluation, applies an identity function.
   Tensor forward(const Tensor& input);
+
+  /// Pretty prints the `Dropout` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
 };
 
 /// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with
@@ -53,12 +57,17 @@ class TORCH_API DropoutImpl : public detail::DropoutImplBase<DropoutImpl> {
 /// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for
 /// 3-D features. This `FeatureDropout` module can instead deal with both 2-D
 /// and 3-D features.
-class TORCH_API FeatureDropoutImpl : public detail::DropoutImplBase<FeatureDropoutImpl> {
+class TORCH_API FeatureDropoutImpl
+    : public detail::DropoutImplBase<FeatureDropoutImpl> {
  public:
   using detail::DropoutImplBase<FeatureDropoutImpl>::DropoutImplBase;
+
   /// During training, applies a noise mask to the input tensor.
   /// During evaluation, applies an identity function.
   Tensor forward(const Tensor& input);
+
+  /// Pretty prints the `FeatureDropout` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
 };
 
 /// A `ModuleHolder` subclass for `DropoutImpl`.
index 1c94884..4c8e60d 100644 (file)
@@ -28,6 +28,9 @@ class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
 
   void reset() override;
 
+  /// Pretty prints the `Embedding` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Performs a lookup on the embedding table stored in `weight` using the
   /// `indices` supplied and returns the result.
   Tensor forward(const Tensor& indices);
index 553faf6..96378b9 100644 (file)
@@ -79,6 +79,9 @@ class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
 
   void reset() override;
 
+  /// Pretty prints the `Functional` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Forwards the `input` tensor to the underlying (bound) function object.
   Tensor forward(Tensor input);
 
index 3a11543..d11cd93 100644 (file)
@@ -29,6 +29,9 @@ class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
 
   void reset() override;
 
+  /// Pretty prints the `Linear` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Transforms the `input` tensor by multiplying with the `weight` and
   /// optionally adding the `bias`, if `with_bias` is true in the options.
   Tensor forward(const Tensor& input);
index c418ed8..62602dd 100644 (file)
@@ -74,6 +74,9 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
   void to(torch::Dtype dtype, bool non_blocking = false) override;
   void to(torch::Device device, bool non_blocking = false) override;
 
+  /// Pretty prints the RNN module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Modifies the internal storage of weights for optimization purposes.
   ///
   /// On CPU, this method should be called if any of the weight or bias vectors
@@ -136,7 +139,7 @@ class RNNImplBase : public torch::nn::Cloneable<Derived> {
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-enum class RNNActivation : uint32_t TORCH_API { ReLU, Tanh };
+enum class RNNActivation : uint32_t TORCH_API{ReLU, Tanh};
 
 /// Options for RNN modules.
 struct TORCH_API RNNOptions {
@@ -177,6 +180,9 @@ class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
       : RNNImpl(RNNOptions(input_size, hidden_size)) {}
   explicit RNNImpl(const RNNOptions& options);
 
+  /// Pretty prints the `RNN` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+
   /// Applies the `RNN` module to an input sequence and input state.
   /// The `input` should follow a `(sequence, batch, features)` layout unless
   /// `batch_first` is true, in which case the layout should be `(batch,
index 32cd913..976e0e3 100644 (file)
@@ -11,6 +11,7 @@
 
 #include <cstdint>
 #include <memory>
+#include <ostream>
 #include <string>
 #include <type_traits>
 #include <utility>
@@ -116,6 +117,11 @@ class SequentialImpl : public Cloneable<SequentialImpl> {
   /// its own.
   void reset() override {}
 
+  /// Pretty prints the `Sequential` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override {
+    stream << "torch::nn::Sequential";
+  }
+
   /// Feeds `inputs` to the first module and then chains outputs to inputs,
   /// returning the last output.
   ///
index 790a63d..c7cef93 100644 (file)
@@ -156,6 +156,14 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
   }
 };
 
+/// Pretty prints the given `Module` into the `ostream`.
+template <typename ModuleType>
+std::ostream& operator<<(
+    std::ostream& stream,
+    const nn::ModuleHolder<ModuleType>& module) {
+  return stream << *module;
+}
+
 /// Serializes a `ModuleHolder` into an `OutputArchive`.
 template <typename ModuleType>
 serialize::OutputArchive& operator<<(
index de526b2..5c0629d 100644 (file)
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <functional>
 #include <map>
+#include <ostream>
 #include <string>
 #include <typeinfo>
 
@@ -321,6 +322,26 @@ Tensor& Module::register_buffer(std::string name, Tensor tensor) {
   return buffers_.insert(std::move(name), std::move(tensor));
 }
 
+void Module::pretty_print(std::ostream& stream) const {
+  stream << name();
+}
+
+void Module::pretty_print_recursive(
+    std::ostream& stream,
+    const std::string& indentation) const {
+  pretty_print(stream);
+  if (!children_.is_empty()) {
+    stream << "(\n";
+    const std::string next_indentation = indentation + "  ";
+    for (const auto& child : children_) {
+      stream << next_indentation << "(" << child.key() << "): ";
+      child.value()->pretty_print_recursive(stream, next_indentation);
+      stream << '\n';
+    }
+    stream << indentation << ")";
+  }
+}
+
 void Module::clone_(Module& other, const optional<Device>& device) {}
 
 void Module::apply_to_submodules(
@@ -351,6 +372,11 @@ std::shared_ptr<Module> Module::shared_from_this_checked() const {
   return std::const_pointer_cast<Module>(ptr);
 }
 
+std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
+  module.pretty_print_recursive(stream, "");
+  return stream;
+}
+
 serialize::OutputArchive& operator<<(
     serialize::OutputArchive& archive,
     const std::shared_ptr<nn::Module>& module) {
index cb1e91d..fa2789c 100644 (file)
@@ -6,6 +6,7 @@
 #include <c10/util/Exception.h>
 
 #include <cstddef>
+#include <ostream>
 #include <utility>
 #include <vector>
 
@@ -32,6 +33,14 @@ void BatchNormImpl::reset() {
   }
 }
 
+void BatchNormImpl::pretty_print(std::ostream& stream) const {
+  stream << std::boolalpha
+         << "torch::nn::BatchNorm(features=" << options.features_
+         << ", eps=" << options.eps_ << ", momentum=" << options.momentum_
+         << ", affine=" << options.affine_ << ", stateful=" << options.stateful_
+         << ")";
+}
+
 Tensor BatchNormImpl::forward(const Tensor& input) {
   AT_CHECK(
       options.stateful_,
index 581d38d..4214743 100644 (file)
@@ -59,6 +59,15 @@ void ConvImpl<D, Derived>::reset() {
   }
 }
 
+template <size_t D, typename Derived>
+void ConvImpl<D, Derived>::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::Conv" << D << "d"
+         << "(input_channels=" << options.input_channels_
+         << ", output_channels=" << options.output_channels_
+         << ", kernel_size=" << options.kernel_size_
+         << ", stride=" << options.stride_ << ")";
+}
+
 Tensor Conv1dImpl::forward(const Tensor& input) {
   if (options.transposed_) {
     return torch::conv_transpose1d(
index 2ec5400..c068f70 100644 (file)
@@ -5,6 +5,7 @@
 #include <c10/util/Exception.h>
 
 #include <cstddef>
+#include <ostream>
 #include <vector>
 
 namespace torch {
@@ -30,8 +31,16 @@ Tensor DropoutImpl::forward(const Tensor& input) {
   return torch::dropout(input, options.rate_, this->is_training());
 }
 
+void DropoutImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::Dropout(rate=" << options.rate_ << ")";
+}
+
 Tensor FeatureDropoutImpl::forward(const Tensor& input) {
   return torch::feature_dropout(input, options.rate_, this->is_training());
 }
+
+void FeatureDropoutImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::FeatureDropout(rate=" << options.rate_ << ")";
+}
 } // namespace nn
 } // namespace torch
index be972e5..786b427 100644 (file)
@@ -4,6 +4,7 @@
 #include <torch/utils.h>
 
 #include <cstddef>
+#include <ostream>
 #include <utility>
 #include <vector>
 
@@ -13,8 +14,7 @@ namespace nn {
 EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
     : count_(count), dimension_(dimension) {}
 
-EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options)
-    : options(options) {
+EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) {
   reset();
 }
 
@@ -25,6 +25,11 @@ void EmbeddingImpl::reset() {
   weight.normal_(0, 1);
 }
 
+void EmbeddingImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::Embedding(count=" << options.count_
+         << ", dimension=" << options.dimension_ << ")";
+}
+
 Tensor EmbeddingImpl::forward(const Tensor& input) {
   return torch::embedding(weight, /*indices=*/input);
 }
index e35a2e6..0da0b74 100644 (file)
@@ -12,6 +12,10 @@ FunctionalImpl::FunctionalImpl(Function function)
 
 void FunctionalImpl::reset() {}
 
+void FunctionalImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::Functional()";
+}
+
 Tensor FunctionalImpl::forward(Tensor input) {
   return function_(std::move(input));
 }
index f9ec0d1..8cd6842 100644 (file)
@@ -28,6 +28,12 @@ void LinearImpl::reset() {
   }
 }
 
+void LinearImpl::pretty_print(std::ostream& stream) const {
+  stream << std::boolalpha << "torch::nn::Linear(in=" << options.in_
+         << ", out=" << options.out_ << ", with_bias=" << options.with_bias_
+         << ")";
+}
+
 Tensor LinearImpl::forward(const Tensor& input) {
   AT_ASSERT(!options.with_bias_ || bias.defined());
   return torch::linear(input, weight, bias);
index 632e4d4..2707536 100644 (file)
@@ -98,6 +98,16 @@ void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
 }
 
 template <typename Derived>
+void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
+  const std::string name = this->name();
+  const std::string name_without_impl = name.substr(0, name.size() - 4);
+  stream << name_without_impl << "(input_size=" << options.input_size_
+         << ", hidden_size=" << options.hidden_size_
+         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
+         << ")";
+}
+
+template <typename Derived>
 void RNNImplBase<Derived>::flatten_parameters() {
   // Cache the flattened weight and bias vector.
   flat_weights_ = flat_weights();
@@ -203,6 +213,15 @@ RNNImpl::RNNImpl(const RNNOptions& options)
           static_cast<CuDNNMode>(options.activation_)),
       options(options) {}
 
+void RNNImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::RNN(input_size=" << options.input_size_
+         << ", hidden_size=" << options.hidden_size_
+         << ", layers=" << options.layers_ << ", dropout=" << options.dropout_
+         << ", activation="
+         << (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu")
+         << ")";
+}
+
 RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
   switch (options.activation_) {
     case RNNActivation::ReLU: