Make call operator on module holder call forward (#15831)
authorPeter Goldsborough <psag@fb.com>
Mon, 14 Jan 2019 22:32:32 +0000 (14:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 22:40:33 +0000 (14:40 -0800)
Summary:
In Python, you can use the call operator to invoke the `forward()` method of a module. In C++ this was currently not possible, because I couldn't figure out how to deduce the return type of a module's `forward()` method under the constraint that `forward()` may not exist at all (since the base module class in C++ does not mandate a `forward()` method). I now figured it out, so the call operator can be used.

ezyang ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15831

Differential Revision: D13652676

Pulled By: goldsborough

fbshipit-source-id: ccab45a15215dda56460e560f0038781b539135f

test/cpp/api/module.cpp
test/cpp/api/modules.cpp
test/cpp/api/static.cpp
torch/csrc/api/include/torch/nn/pimpl-inl.h
torch/csrc/api/include/torch/nn/pimpl.h

index 7721279..557b7de 100644 (file)
@@ -37,7 +37,7 @@ TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
 TEST_F(ModuleTest, ZeroGrad) {
   Linear module(3, 4);
   auto weight = torch::ones({8, 3}, torch::requires_grad());
-  auto loss = module->forward(weight).sum();
+  auto loss = module(weight).sum();
   loss.backward();
   for (auto& parameter : module->parameters()) {
     auto grad = parameter.grad();
@@ -852,3 +852,15 @@ TEST_F(ModuleTest, PrettyPrint) {
   ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
   ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
 }
+
+struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
+  int64_t forward(torch::Tensor x) {
+    return x.numel();
+  }
+};
+TORCH_MODULE(ModuleWithNonTensorForward);
+
+TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
+  ModuleWithNonTensorForward m;
+  ASSERT_EQ(m(torch::ones(123)), 123);
+}
index 6852d19..ca2ab8c 100644 (file)
@@ -42,7 +42,7 @@ struct ModulesTest : torch::test::SeedingFixture {};
 TEST_F(ModulesTest, Conv1d) {
   Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
   auto x = torch::randn({2, 3, 5}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -58,7 +58,7 @@ TEST_F(ModulesTest, Conv1d) {
 TEST_F(ModulesTest, Conv2dEven) {
   Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
   auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -74,7 +74,7 @@ TEST_F(ModulesTest, Conv2dEven) {
 TEST_F(ModulesTest, Conv2dUneven) {
   Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
   auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -90,7 +90,7 @@ TEST_F(ModulesTest, Conv2dUneven) {
 TEST_F(ModulesTest, Conv3d) {
   Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
   auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -106,7 +106,7 @@ TEST_F(ModulesTest, Conv3d) {
 TEST_F(ModulesTest, Linear) {
   Linear model(5, 2);
   auto x = torch::randn({10, 5}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -125,9 +125,9 @@ TEST_F(ModulesTest, SimpleContainer) {
   auto l3 = model->add(Linear(5, 100), "l3");
 
   auto x = torch::randn({1000, 10}, torch::requires_grad());
-  x = l1->forward(x).clamp_min(0);
-  x = l2->forward(x).clamp_min(0);
-  x = l3->forward(x).clamp_min(0);
+  x = l1(x).clamp_min(0);
+  x = l2(x).clamp_min(0);
+  x = l3(x).clamp_min(0);
 
   x.backward();
   ASSERT_EQ(x.ndimension(), 2);
@@ -147,7 +147,7 @@ TEST_F(ModulesTest, EmbeddingBasic) {
   // Cannot get gradients to change indices (input) - only for embedding
   // params
   auto x = torch::full({10}, dict_size - 1, torch::kInt64);
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -162,7 +162,7 @@ TEST_F(ModulesTest, EmbeddingBasic) {
 TEST_F(ModulesTest, EmbeddingList) {
   Embedding model(6, 4);
   auto x = torch::full({2, 3}, 5, torch::kInt64);
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -175,7 +175,7 @@ TEST_F(ModulesTest, EmbeddingList) {
 TEST_F(ModulesTest, Dropout) {
   Dropout dropout(0.5);
   torch::Tensor x = torch::ones(100, torch::requires_grad());
-  torch::Tensor y = dropout->forward(x);
+  torch::Tensor y = dropout(x);
 
   y.backward();
   ASSERT_EQ(y.ndimension(), 1);
@@ -184,7 +184,7 @@ TEST_F(ModulesTest, Dropout) {
   ASSERT_GT(y.sum().item<float>(), 70); // Probably
 
   dropout->eval();
-  y = dropout->forward(x);
+  y = dropout(x);
   ASSERT_EQ(y.sum().item<float>(), 100);
 }
 
@@ -214,7 +214,7 @@ TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
     was_called = true;
     return input;
   });
-  auto output = functional->forward(torch::ones(5, torch::requires_grad()));
+  auto output = functional(torch::ones(5, torch::requires_grad()));
   ASSERT_TRUE(was_called);
   ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
 
@@ -272,7 +272,7 @@ TEST_F(ModulesTest, BatchNormStateless) {
   ASSERT_FALSE(bn->bias.defined());
 
   ASSERT_THROWS_WITH(
-      bn->forward(torch::ones({2, 5})),
+      bn(torch::ones({2, 5})),
       "Calling BatchNorm::forward is only permitted "
       "when the 'stateful' option is true (was false). "
       "Use BatchNorm::pure_forward instead.");
@@ -297,7 +297,7 @@ TEST_F(ModulesTest, Linear_CUDA) {
   model->to(torch::kCUDA);
   auto x =
       torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
@@ -314,7 +314,7 @@ TEST_F(ModulesTest, Linear2_CUDA) {
   model->to(torch::kCUDA);
   model->to(torch::kCPU);
   auto x = torch::randn({10, 5}, torch::requires_grad());
-  auto y = model->forward(x);
+  auto y = model(x);
   torch::Tensor s = y.sum();
 
   s.backward();
index 475604e..a4cae7f 100644 (file)
@@ -49,6 +49,51 @@ TEST(TestStatic, EnableIfModule) {
   ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
 }
 
+struct A : torch::nn::Module {
+  int forward() {
+    return 5;
+  }
+};
+
+struct B : torch::nn::Module {
+  std::string forward(torch::Tensor tensor) {
+    return "";
+  }
+};
+
+struct C : torch::nn::Module {
+  float forward(torch::Tensor& tensor) {
+    return 5.0;
+  }
+};
+
+struct D : torch::nn::Module {
+  char forward(torch::Tensor&& tensor) {
+    return 'x';
+  }
+};
+
+struct E : torch::nn::Module {};
+
+// Put in a function because macros don't handle the comma between arguments to
+// is_same well ...
+template <typename Module, typename ExpectedType, typename... Args>
+void assert_has_expected_type() {
+  using ReturnType =
+      typename torch::detail::return_type_of_forward<Module, Args...>::type;
+  constexpr bool is_expected_type =
+      std::is_same<ReturnType, ExpectedType>::value;
+  ASSERT_TRUE(is_expected_type) << Module().name();
+}
+
+TEST(TestStatic, ReturnTypeOfForward) {
+  assert_has_expected_type<A, int>();
+  assert_has_expected_type<B, std::string, torch::Tensor>();
+  assert_has_expected_type<C, float, torch::Tensor&>();
+  assert_has_expected_type<D, char, torch::Tensor&&>();
+  assert_has_expected_type<E, void>();
+}
+
 TEST(TestStatic, Apply) {
   std::vector<int> v;
   torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
index f632ef8..b38e6cf 100644 (file)
@@ -45,3 +45,30 @@ struct is_module_holder_of : is_module_holder_of_impl<
                                  is_module_holder<T>::value,
                                  decay_t<T>,
                                  decay_t<C>> {};
+
+// A collection of templates that allow deducing the return type of the
+// `forward()` method, but only if a module actually has a `forward()` method,
+// and otherwise deduces to the type `void`.
+
+template <bool has_forward_value, typename C, typename... Args>
+struct return_type_of_forward_impl;
+
+template <typename C, typename... Args>
+struct return_type_of_forward_impl<true, C, Args...> {
+  using type = decltype(::std::declval<C>().forward(::std::declval<Args>()...));
+};
+
+template <typename C, typename... Args>
+struct return_type_of_forward_impl<false, C, Args...> {
+  using type = void;
+};
+
+template <typename C, typename... Args>
+using return_type_of_forward = return_type_of_forward_impl<
+    torch::detail::has_forward<C>::value,
+    C,
+    Args...>;
+
+template <typename C, typename... Args>
+using return_type_of_forward_t =
+    typename return_type_of_forward<C, Args...>::type;
index bf09bd6..e3cda20 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/arg.h>
+#include <torch/detail/static.h>
 #include <torch/serialize/archive.h>
 #include <torch/types.h>
 
@@ -113,13 +114,15 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
     return impl_.get();
   }
 
-  /// Forwards to the call operator of the contained module.
-  /// NOTE: std::forward is qualified to prevent VS2017 emitting
-  ///       error C2872: 'std': ambiguous symbol
+  /// Calls the `forward()` method of the contained module.
   template <typename... Args>
   auto operator()(Args&&... args)
-      -> decltype((*impl_)(::std::forward<Args>(args)...)) {
-    return (*impl_)(::std::forward<Args>(args)...);
+      -> torch::detail::return_type_of_forward_t<Contained, Args...> {
+    // This will not compile if the module does not have a `forward()` method
+    // (as expected).
+    // NOTE: `std::forward` is qualified to prevent VS2017 emitting
+    // error C2872: 'std': ambiguous symbol
+    return impl_->forward(::std::forward<Args>(args)...);
   }
 
   /// Forwards to the subscript operator of the contained module.