TEST_F(ModuleTest, ZeroGrad) {
Linear module(3, 4);
auto weight = torch::ones({8, 3}, torch::requires_grad());
- auto loss = module->forward(weight).sum();
+ auto loss = module(weight).sum();
loss.backward();
for (auto& parameter : module->parameters()) {
auto grad = parameter.grad();
ASSERT_TRUE(grad.defined());
ASSERT_NE(grad.sum().item<float>(), 0);
}
+
+struct ModuleWithNonTensorForwardImpl : torch::nn::Module {
+ int64_t forward(torch::Tensor x) {
+ return x.numel();
+ }
+};
+TORCH_MODULE(ModuleWithNonTensorForward);
+
+TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
+ ModuleWithNonTensorForward m;
+ ASSERT_EQ(m(torch::ones(123)), 123);
+}
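+
+// A sketch (not part of the original diff): the call through the holder should
+// also preserve the deduced return type at compile time. The test name and the
+// std::is_same/std::declval check are illustrative assumptions, and presume
+// <type_traits> and <utility> are included.
+TEST_F(ModuleTest, CallOperatorPreservesForwardReturnType) {
+  static_assert(
+      std::is_same<
+          decltype(std::declval<ModuleWithNonTensorForward&>()(
+              std::declval<torch::Tensor>())),
+          int64_t>::value,
+      "operator() should return exactly what forward() returns");
+}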
TEST_F(ModulesTest, Conv1d) {
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, Conv2dEven) {
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, Conv2dUneven) {
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, Conv3d) {
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, Linear) {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
auto l3 = model->add(Linear(5, 100), "l3");
auto x = torch::randn({1000, 10}, torch::requires_grad());
- x = l1->forward(x).clamp_min(0);
- x = l2->forward(x).clamp_min(0);
- x = l3->forward(x).clamp_min(0);
+ x = l1(x).clamp_min(0);
+ x = l2(x).clamp_min(0);
+ x = l3(x).clamp_min(0);
x.backward();
ASSERT_EQ(x.ndimension(), 2);
// Gradients cannot flow to the indices (the input); only the embedding
// parameters receive them.
auto x = torch::full({10}, dict_size - 1, torch::kInt64);
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, EmbeddingList) {
Embedding model(6, 4);
auto x = torch::full({2, 3}, 5, torch::kInt64);
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
TEST_F(ModulesTest, Dropout) {
Dropout dropout(0.5);
torch::Tensor x = torch::ones(100, torch::requires_grad());
- torch::Tensor y = dropout->forward(x);
+ torch::Tensor y = dropout(x);
y.backward();
ASSERT_EQ(y.ndimension(), 1);
ASSERT_GT(y.sum().item<float>(), 70); // Probably
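// (Note: torch::nn::Dropout uses inverted dropout, so in training mode each
// surviving element is scaled by 1/(1 - p) = 2; the sum stays near 100 in
// expectation but fluctuates, hence the loose bound above.)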
dropout->eval();
- y = dropout->forward(x);
+ y = dropout(x);
ASSERT_EQ(y.sum().item<float>(), 100);
}
was_called = true;
return input;
});
- auto output = functional->forward(torch::ones(5, torch::requires_grad()));
+ auto output = functional(torch::ones(5, torch::requires_grad()));
ASSERT_TRUE(was_called);
ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
ASSERT_FALSE(bn->bias.defined());
ASSERT_THROWS_WITH(
- bn->forward(torch::ones({2, 5})),
+ bn(torch::ones({2, 5})),
"Calling BatchNorm::forward is only permitted "
"when the 'stateful' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
model->to(torch::kCUDA);
auto x =
torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
model->to(torch::kCUDA);
model->to(torch::kCPU);
auto x = torch::randn({10, 5}, torch::requires_grad());
- auto y = model->forward(x);
+ auto y = model(x);
torch::Tensor s = y.sum();
s.backward();
ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
}
+struct A : torch::nn::Module {
+ int forward() {
+ return 5;
+ }
+};
+
+struct B : torch::nn::Module {
+ std::string forward(torch::Tensor tensor) {
+ return "";
+ }
+};
+
+struct C : torch::nn::Module {
+ float forward(torch::Tensor& tensor) {
+ return 5.0;
+ }
+};
+
+struct D : torch::nn::Module {
+ char forward(torch::Tensor&& tensor) {
+ return 'x';
+ }
+};
+
+struct E : torch::nn::Module {};
+
+// Put in a function because macros don't handle the comma between arguments to
+// is_same well ...
+template <typename Module, typename ExpectedType, typename... Args>
+void assert_has_expected_type() {
+ using ReturnType =
+ typename torch::detail::return_type_of_forward<Module, Args...>::type;
+ constexpr bool is_expected_type =
+ std::is_same<ReturnType, ExpectedType>::value;
+ ASSERT_TRUE(is_expected_type) << Module().name();
+}
+
+TEST(TestStatic, ReturnTypeOfForward) {
+ assert_has_expected_type<A, int>();
+ assert_has_expected_type<B, std::string, torch::Tensor>();
+ assert_has_expected_type<C, float, torch::Tensor&>();
+ assert_has_expected_type<D, char, torch::Tensor&&>();
+ assert_has_expected_type<E, void>();
+}
+
TEST(TestStatic, Apply) {
std::vector<int> v;
torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
is_module_holder<T>::value,
decay_t<T>,
decay_t<C>> {};
+
+// A collection of templates that deduce the return type of a module's
+// `forward()` method if the module has one, and `void` otherwise.
+
+template <bool has_forward_value, typename C, typename... Args>
+struct return_type_of_forward_impl;
+
+template <typename C, typename... Args>
+struct return_type_of_forward_impl<true, C, Args...> {
+ using type = decltype(::std::declval<C>().forward(::std::declval<Args>()...));
+};
+
+template <typename C, typename... Args>
+struct return_type_of_forward_impl<false, C, Args...> {
+ using type = void;
+};
+
+template <typename C, typename... Args>
+using return_type_of_forward = return_type_of_forward_impl<
+ torch::detail::has_forward<C>::value,
+ C,
+ Args...>;
+
+template <typename C, typename... Args>
+using return_type_of_forward_t =
+ typename return_type_of_forward<C, Args...>::type;
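+
+// Illustration only (a sketch, not part of the original header): a pimpl-style
+// holder can use `return_type_of_forward_t` to expose the contained module's
+// `forward()` through its call operator without spelling out the return type.
+// `ExampleHolder` and `impl_` are hypothetical names; assumes <memory> and
+// <utility> are available.
+template <typename Contained>
+struct ExampleHolder {
+  std::shared_ptr<Contained> impl_;
+  template <typename... Args>
+  return_type_of_forward_t<Contained, Args...> operator()(Args&&... args) {
+    return impl_->forward(::std::forward<Args>(args)...);
+  }
+};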