From 31ff0ecd2b400b4863741bcbc41748f2ad01745c Mon Sep 17 00:00:00 2001 From: Omegastick Date: Tue, 9 Apr 2019 10:36:13 -0700 Subject: [PATCH] Fix torch::nn::init::orthogonal_ with CNNs (#18915) Summary: Fixes #18518 I changed the C++ API torch::nn::init::orthogonal_ implementation to match the Python implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18915 Differential Revision: D14851833 Pulled By: ezyang fbshipit-source-id: 45b5e9741582777c203e9ebed564ab3ac1f94baf --- test/cpp/api/init.cpp | 6 ++++++ torch/csrc/api/src/nn/init.cpp | 2 +- torch/nn/init.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/cpp/api/init.cpp b/test/cpp/api/init.cpp index c4b2f97..5527d72 100644 --- a/test/cpp/api/init.cpp +++ b/test/cpp/api/init.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -123,4 +124,9 @@ TEST(InitTest, CalculateGainWithLeakyRelu) { double gain = torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU); ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2)))); +} + +TEST(InitTest, CanInitializeCnnWithOrthogonal) { + torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2)); + torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]); } \ No newline at end of file diff --git a/torch/csrc/api/src/nn/init.cpp b/torch/csrc/api/src/nn/init.cpp index 187a252..7d64b9f 100644 --- a/torch/csrc/api/src/nn/init.cpp +++ b/torch/csrc/api/src/nn/init.cpp @@ -123,7 +123,7 @@ Tensor orthogonal_(Tensor tensor, double gain) { "Only tensors with 2 or more dimensions are supported"); const auto rows = tensor.size(0); - const auto columns = tensor.size(1); + const auto columns = tensor.numel() / rows; auto flattened = torch::randn({rows, columns}); if (rows < columns) { diff --git a/torch/nn/init.py b/torch/nn/init.py index 731cd72..583053e 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -345,7 +345,7 @@ def orthogonal_(tensor, gain=1): raise ValueError("Only tensors with 2 or more dimensions are supported") rows = tensor.size(0) - cols = tensor[0].numel() + cols = tensor.numel() // rows flattened = tensor.new(rows, cols).normal_(0, 1) if rows < cols: -- 2.7.4