Fix torch::nn::init::orthogonal_ with CNNs (#18915)
authorOmegastick <omegastick@hotmail.co.uk>
Tue, 9 Apr 2019 17:36:13 +0000 (10:36 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 17:39:15 +0000 (10:39 -0700)
Summary:
Fixes #18518

I changed the C++ API torch::nn::init::orthogonal_ implementation to match the Python implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18915

Differential Revision: D14851833

Pulled By: ezyang

fbshipit-source-id: 45b5e9741582777c203e9ebed564ab3ac1f94baf

test/cpp/api/init.cpp
torch/csrc/api/src/nn/init.cpp
torch/nn/init.py

index c4b2f97..5527d72 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <torch/nn/init.h>
 #include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/conv.h>
 
 #include <test/cpp/api/init_baseline.h>
 #include <test/cpp/api/support.h>
@@ -123,4 +124,9 @@ TEST(InitTest, CalculateGainWithLeakyRelu) {
   double gain =
       torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
   ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
+}
+
+TEST(InitTest, CanInitializeCnnWithOrthogonal) {
+  torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
+  torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
 }
\ No newline at end of file
index 187a252..7d64b9f 100644 (file)
@@ -123,7 +123,7 @@ Tensor orthogonal_(Tensor tensor, double gain) {
       "Only tensors with 2 or more dimensions are supported");
 
   const auto rows = tensor.size(0);
-  const auto columns = tensor.size(1);
+  const auto columns = tensor.numel() / rows;
   auto flattened = torch::randn({rows, columns});
 
   if (rows < columns) {
index 731cd72..583053e 100644 (file)
@@ -345,7 +345,7 @@ def orthogonal_(tensor, gain=1):
         raise ValueError("Only tensors with 2 or more dimensions are supported")
 
     rows = tensor.size(0)
-    cols = tensor[0].numel()
+    cols = tensor.numel() // rows
     flattened = tensor.new(rows, cols).normal_(0, 1)
 
     if rows < cols: