Kaiming Initialization (#14718)
author Josh Varty <joshvarty@gmail.com>
Fri, 15 Feb 2019 22:51:56 +0000 (14:51 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 22:58:22 +0000 (14:58 -0800)
Summary:
/cc goldsborough

Working on #14582

The corresponding python implementations are at: [pytorch/torch/nn/init.py](https://github.com/pytorch/pytorch/blob/6302e4001ab54b3ddeca2b608d337fe7077e801c/torch/nn/init.py#L261-L327)

Here is my initial implementation of Kaiming Initialization. I have not been able to figure out how to successfully run tests locally, so I haven't added any yet.

A couple questions:
- Are the enums defined in the right place? I copied their names from Python, but do you prefer different naming conventions for C++?
- To run tests locally do I use `python setup.py test`? Can I run just a subset of the tests somehow?
- Should I add my tests at [test/cpp/api/misc.cpp](https://github.com/pytorch/pytorch/blob/master/test/cpp/api/misc.cpp#L47-L54)?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14718

Differential Revision: D14049159

Pulled By: goldsborough

fbshipit-source-id: 966ac5126875936e69b185b5041f16476ed4cf70

test/cpp/api/CMakeLists.txt
test/cpp/api/init.cpp [new file with mode: 0644]
test/cpp/api/init_baseline.h [new file with mode: 0644]
test/cpp/api/init_baseline.py [new file with mode: 0644]
test/cpp/api/misc.cpp
torch/csrc/api/include/torch/nn/init.h
torch/csrc/api/src/nn/init.cpp

index aa9f01a..daf50e5 100644 (file)
@@ -5,6 +5,7 @@ set(TORCH_API_TEST_SOURCES
   ${TORCH_API_TEST_DIR}/dataloader.cpp
   ${TORCH_API_TEST_DIR}/expanding-array.cpp
   ${TORCH_API_TEST_DIR}/integration.cpp
+  ${TORCH_API_TEST_DIR}/init.cpp
   ${TORCH_API_TEST_DIR}/jit.cpp
   ${TORCH_API_TEST_DIR}/memory.cpp
   ${TORCH_API_TEST_DIR}/misc.cpp
diff --git a/test/cpp/api/init.cpp b/test/cpp/api/init.cpp
new file mode 100644 (file)
index 0000000..c4b2f97
--- /dev/null
@@ -0,0 +1,126 @@
+#include <gtest/gtest.h>
+
+#include <torch/nn/init.h>
+#include <torch/nn/modules/linear.h>
+
+#include <test/cpp/api/init_baseline.h>
+#include <test/cpp/api/support.h>
+
+#include <functional>
+#include <vector>
+
+void check_exact_values(
+    const std::vector<torch::Tensor>& parameters,
+    const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
+  ASSERT_EQ(parameters.size(), expected_parameters.size());
+
+  for (size_t i = 0; i < parameters.size(); i++) {
+    auto layerParameters = parameters[i];
+    auto expectedLayerParameters = expected_parameters[i];
+
+    if (layerParameters.size(0) != expectedLayerParameters.size()) {
+      std::cout << "layer #" << i
+                << " layerParameters size: " << layerParameters.size(0)
+                << " != "
+                << " expectedLayerParameters size: "
+                << expectedLayerParameters.size() << std::endl;
+      ASSERT_TRUE(false);
+    }
+
+    for (size_t p = 0; p < layerParameters.size(0); p++) {
+      auto tensor = layerParameters[p];
+      auto expectedTensor = expectedLayerParameters[p];
+
+      if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
+        std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
+                  << " (parameter " << p << ")" << std::endl;
+        ASSERT_TRUE(false);
+      }
+    }
+  }
+}
+
+void check_initializer_against_baseline(
+    std::function<void(torch::Tensor)> initializer,
+    std::vector<std::vector<torch::Tensor>> expected) {
+  torch::manual_seed(0);
+
+  auto layer1 = torch::nn::Linear(7, 15);
+  initializer(layer1->weight);
+  layer1->to(torch::kFloat64);
+
+  auto layer2 = torch::nn::Linear(15, 15);
+  initializer(layer2->weight);
+  layer2->to(torch::kFloat64);
+
+  auto layer3 = torch::nn::Linear(15, 2);
+  initializer(layer3->weight);
+  layer3->to(torch::kFloat64);
+
+  auto parameters = std::vector<torch::Tensor>{
+      layer1->weight,
+      layer2->weight,
+      layer3->weight,
+  };
+
+  check_exact_values(parameters, expected);
+}
+
+TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
+  auto expected = expected_parameters::Xavier_Uniform();
+  auto initializer = [](torch::Tensor tensor) {
+    torch::nn::init::xavier_uniform_(tensor);
+  };
+  check_initializer_against_baseline(initializer, expected);
+}
+
+TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
+  auto expected = expected_parameters::Xavier_Normal();
+  auto initializer = [](torch::Tensor tensor) {
+    torch::nn::init::xavier_normal_(tensor);
+  };
+  check_initializer_against_baseline(initializer, expected);
+}
+
+TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
+  auto expected = expected_parameters::Kaiming_Normal();
+  auto initializer = [](torch::Tensor tensor) {
+    torch::nn::init::kaiming_normal_(tensor);
+  };
+  check_initializer_against_baseline(initializer, expected);
+}
+
+TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
+  auto expected = expected_parameters::Kaiming_Uniform();
+  auto initializer = [](torch::Tensor tensor) {
+    torch::nn::init::kaiming_uniform_(tensor);
+  };
+  check_initializer_against_baseline(initializer, expected);
+}
+
+TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
+  auto tensor = torch::empty({3, 4}, torch::requires_grad());
+  ASSERT_THROWS_WITH(
+      tensor.fill_(1),
+      "a leaf Variable that requires grad "
+      "has been used in an in-place operation");
+  ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
+}
+
+TEST(InitTest, CalculateGainWithTanh) {
+  double gain =
+      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
+  ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
+}
+
+TEST(InitTest, CalculateGainWithRelu) {
+  double gain =
+      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
+  ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
+}
+
+TEST(InitTest, CalculateGainWithLeakyRelu) {
+  double gain =
+      torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
+  ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
+}
\ No newline at end of file
diff --git a/test/cpp/api/init_baseline.h b/test/cpp/api/init_baseline.h
new file mode 100644 (file)
index 0000000..7565d07
--- /dev/null
@@ -0,0 +1,181 @@
+// @generated from /home/josh/git/pytorch/test/cpp/api/init_baseline.py
+
+#include <torch/types.h>
+
+#include <vector>
+
+namespace expected_parameters {
+
+inline std::vector<std::vector<torch::Tensor>> Xavier_Uniform() {
+  return {
+    {
+      torch::tensor({0.20493895, 0.03128183, -0.25481248, 0.2471149, -0.5009514, -0.30953097, -0.13073051}),
+      torch::tensor({-0.25438663, -0.18269452, -0.42803344, -0.111086845, 0.11163068, -0.34021688, -0.026800662}),
+      torch::tensor({0.37384087, -0.053685665, 0.014514029, -0.04505247, 0.10569024, 0.33205634, 0.4946832}),
+      torch::tensor({0.3316471, 0.49581498, -0.03776875, -0.46913308, -0.24757874, 0.35559112, -0.003385365}),
+      torch::tensor({-0.259574, -0.40019324, -0.48873278, -0.4407689, -0.10592803, 0.28639567, 0.2823406}),
+      torch::tensor({-0.5036581, 0.32575953, -0.40865222, -0.110405415, -0.21175116, -0.10059005, -0.10253662}),
+      torch::tensor({-0.46862572, -0.45091572, -0.08171874, 0.0067536235, -0.2372373, 0.19672471, -0.47004014}),
+      torch::tensor({-0.035244048, 0.45926183, -0.21301463, 0.47157794, 0.18912864, -0.47129482, 0.33041543}),
+      torch::tensor({-0.0602628, -0.23312837, 0.41760528, -0.42201394, 0.05603814, -0.10933924, 0.37293315}),
+      torch::tensor({0.14577848, 0.25093573, 0.18443125, -0.1255838, -0.10982844, -0.43036246, 0.28296888}),
+      torch::tensor({0.4146415, 0.35732478, -0.36837178, 0.023291528, -0.3681398, -0.28748098, -0.304308}),
+      torch::tensor({0.17847049, -0.3112055, -0.011393666, 0.021969378, 0.3366434, -0.39476636, -0.35851932}),
+      torch::tensor({-0.3032406, 0.3655283, -0.18772447, 0.44049758, 0.18884337, 0.066128254, -0.0038875341}),
+      torch::tensor({-0.10323581, 0.06552267, -0.119249105, -0.0036694407, 0.066633284, -0.40849328, -0.2737187}),
+      torch::tensor({0.4216993, -0.4238164, -0.037499547, 0.51661307, 0.1886499, 0.014786005, -0.45257226}),
+    },
+    {
+      torch::tensor({0.2400673, 0.06230098, -0.29922318, -0.3467335, -0.13797283, 0.19630808, 0.44112986, 0.25716078, -0.05036038, 0.15680045, -0.43874466, -0.3819657, 0.20867342, -0.25330856, 0.2151525}),
+      torch::tensor({-0.31570244, -0.22150977, -0.3683649, 0.23337424, -0.045568883, 0.34417605, 0.27676803, 0.24746233, 0.014380664, -0.13826859, -0.09723839, 0.05943495, 0.22168803, -0.3133133, 0.37533647}),
+      torch::tensor({-0.04862556, -0.37474066, -0.24196841, 0.39570248, 0.408989, -0.41424486, 0.31541896, 0.2241252, 0.264714, 0.37857938, -0.24102591, 0.1412192, 0.18301463, -0.13214865, 0.14966142}),
+      torch::tensor({-0.12866935, 0.27649486, -0.12408146, -0.16671932, 0.112585604, 0.15862381, -0.21849588, 0.03953293, 0.25917625, -0.044496298, 0.13610226, -0.107862085, 0.15674818, -0.32395893, -0.26297447}),
+      torch::tensor({-0.22700138, 0.41099417, -0.12033805, -0.0012210608, -0.21667299, 0.44644886, 0.43678015, -0.3372825, -0.3625426, -0.33898476, -0.002156794, -0.113996506, -0.29272172, -0.16040304, 0.084492445}),
+      torch::tensor({-0.23366496, 0.09909475, -0.10255319, -0.21670331, 0.061440647, 0.36772507, -0.30235183, 0.02076611, -0.16491795, 0.4388535, -0.4242998, -0.42872632, 0.44067752, -0.28294754, 0.08574128}),
+      torch::tensor({-0.038597256, -0.09420869, -0.09988415, 0.28417766, 0.021375448, -0.43541366, -0.2640171, -0.15245524, 0.2250452, -0.28940699, 0.42168647, -0.09960759, -0.08030257, 0.3504178, 0.22477299}),
+      torch::tensor({0.37929094, 0.25868815, -0.13566399, -0.2967139, -0.033274055, 0.37013078, -0.15009373, -0.41473246, 0.18332559, 0.43534964, -0.12731421, -0.3703034, -0.40564942, 0.112071455, -0.03386289}),
+      torch::tensor({-0.22583716, 0.090396106, 0.1698333, 0.35567743, 0.34720868, -0.066940606, -0.39433825, -0.40411252, 0.41755867, 0.19769311, 0.19494373, -0.3869386, 0.4141268, 0.4236647, 0.4037714}),
+      torch::tensor({-0.37726268, -0.16874415, -0.3075773, 0.4234959, -0.19215873, -0.2041774, 0.23430097, -0.20687759, -0.22026259, -0.03911844, -0.042985946, -0.34836975, 0.3728277, -0.19727562, 0.15863329}),
+      torch::tensor({0.38897908, 0.22553718, 0.063316464, 0.3805148, 0.060117245, -0.20690633, 0.4230638, 0.10584676, -0.43633774, -0.12731794, -0.30462736, 0.39209586, -0.07385549,-0.40764633, -0.028113335}),
+      torch::tensor({0.2808559, 0.11618626, 0.14141095, 0.041534156, 0.16672957, -0.10896447, -0.17790166, -0.41801435, -0.3369025, 0.19382352, -0.26480114, 0.06416017, 0.14274675, 0.03166446, -0.28995082}),
+      torch::tensor({0.42768306, -0.26005447, 0.36783344, -0.35576212, -0.10757655, 0.24327022, -0.18272284, 0.3756786, -0.30775294, -0.37555724, -0.20165718, 0.07229227, 0.41177452, -0.21350017, 0.15993619}),
+      torch::tensor({-0.112119585, -0.09698379, 0.3288377, -0.34658423, 0.047500044, 0.42056376, -0.061452597, 0.34723365, -0.13772246, 0.35999, -0.43260327, -0.06445408, -0.07855147, 0.14493519, 0.17545414}),
+      torch::tensor({0.34337813, -0.066628546, -0.01773429, 0.3062569, -0.121003985, 0.39204246, -0.29776025, -0.04839292, -0.024020255, 0.1995511, 0.30574924, -0.07088503, -0.37050778, 0.22159088, 0.13377577}),
+    },
+    {
+      torch::tensor({-0.095182985, -0.35154456, -0.33943355, 0.14092201, 0.5576401, -0.4759267, 0.35954505, -0.30801514, -0.11571318, 0.47157025, -0.1343652, 0.05409521, -0.41528726, 0.5057125, -0.076797724}),
+      torch::tensor({-0.4345107, 0.17395526, -0.42240727, -0.4714136, 0.036191404, 0.47101927, -0.16811755, 0.2796178, 0.51051295, 0.39403576, -0.3116357, -0.065123916, -0.18695068,-0.47772023, 0.00024545193}),
+    },
+  };
+}
+
+inline std::vector<std::vector<torch::Tensor>> Xavier_Normal() {
+  return {
+    {
+      torch::tensor({-0.21151732, 0.31257284, -0.18201339, -0.3855622, 0.028025549, -0.20083663, 0.18333313}),
+      torch::tensor({-0.22010927, 0.41458952, 0.19888626, 0.14368738, -0.30642822, 0.05438269, 0.032663286}),
+      torch::tensor({-0.22758777, 0.07366481, 0.34382868, -0.027100464, 0.22004184, -0.5563846, -0.0075437957}),
+      torch::tensor({0.4128839, 0.80112267, 0.29702467, 0.11372463, 0.3320348, -0.34456047, 0.011332423}),
+      torch::tensor({0.81295794, 0.37259677, 0.16366935, 0.15845336, -0.2500635, -0.42430383, 0.49051273}),
+      torch::tensor({0.051942326, -0.48588628, -0.14455895, -0.04322197, -0.09566793, 0.17296337, 0.3008876}),
+      torch::tensor({0.16390441, 0.023760417, 0.26016212, -0.005876346, 0.29881194, -0.23449583, -0.090267815}),
+      torch::tensor({-0.05661722, 0.57766485, 0.20810172, -0.70001936, -0.36073124, 0.059406225, -0.35497883}),
+      torch::tensor({0.034237247, 0.33308733, -0.4206602, 0.14325368, -0.24534757, 0.27866775, -0.07457621}),
+      torch::tensor({-0.42675552, 0.29772332, -0.44859084, 0.17689292, 0.04772785, 0.03324038, -0.24688111}),
+      torch::tensor({0.19078615, -0.57796097, 0.39555034, -0.063268825, 0.23570086, 0.29840353, 0.12504078}),
+      torch::tensor({-0.45496583, 0.61388826, 0.039676014, -0.15409455, -0.5167084, -0.15379032, -0.14318566}),
+      torch::tensor({-0.19097842, -0.44253433, -0.26487318, -0.6266602, -0.33180708, 0.4737542, 0.05777406}),
+      torch::tensor({0.11455678, -0.04364589, 0.19224901, -0.084812194, -0.40097243, -0.197128, -0.161051}),
+      torch::tensor({-0.15780471, 0.25975332, -0.26742947, 0.2529001, 0.34761095, -0.5309911, -0.3337848}),
+    },
+    {
+      torch::tensor({-0.11110223, -0.08603837, -0.39927804, -0.0037998888, 0.3163225, 0.41147578, -0.4212508, 0.27163807, 0.16257726, 0.07001552, -0.17712006, -0.28190005, 0.43368816, -0.22742769, 0.14976384}),
+      torch::tensor({0.15162551, 0.022756116, -0.337443, -0.1823836, -0.04240044, -0.25083202, -0.26616275, 0.16712677, -0.3765372, 0.18503848, -0.51644665, -0.622171, 0.05665474, -0.4386439, 0.33809182}),
+      torch::tensor({-0.42894447, 0.44781545, -0.16271128, -0.16386595, 0.25165728, 0.054181106, 0.0077175284, 0.44132262, -0.18739006, -0.37522528, 0.15394184, -0.32285675, 0.2957909, 0.19089724, 0.32357433}),
+      torch::tensor({-0.11477537, 0.21132338, 0.0032186622, 0.09303321, -0.41689086, -0.6386192, 0.009335384, -0.088360295, -0.09855254, -0.014685968, -0.22642778, 0.17631726, 0.87642914, -0.4308766, 0.13190313}),
+      torch::tensor({-0.07384103, 0.086509995, 0.30259287, -0.10428588, 0.23001948, -0.12647822, -0.30277884, -0.17739438, -0.60285777, 0.024281243, -0.052177392, 0.807063, 0.5194553, -0.08695924, 0.084167756}),
+      torch::tensor({0.13818856, 0.50950116, -0.053579096, -0.007894313, 0.067052916, 0.0014321166, 0.20509668, 0.105126604, -0.093184546, 0.33830634, -0.24917552, 0.22737288, -0.2688545, -0.17480324, -0.106032155}),
+      torch::tensor({-0.4179092, 0.1311413, 0.5997842, 0.059327736, -0.13675568, 0.47349188, 0.0011002938, -0.32478496, -0.28000802, 0.19441846, 0.08356549, -0.07100728, 0.33710754,0.27447432, 0.070220366}),
+      torch::tensor({-0.23930988, -0.7056576, -0.14566903, -0.0707464, 0.03609119, 0.13131015, 0.085740715, -0.25335756, 0.22949918, 0.40511283, -0.021134255, -0.090214714, 0.052266303, -0.07446028, -0.0024477358}),
+      torch::tensor({0.6245137, 0.34285438, -0.068128414, 0.09410469, 0.6568622, -0.6944174, 0.63067895, 0.10417306, -0.25729227, 0.2526477, -0.1139792, -0.067401096, 0.20603673, -0.28586346, 0.601752}),
+      torch::tensor({-0.26997554, -0.12263605, -0.12787469, -0.05121646, 0.57187945, -0.03529285, -0.26288798, 0.046064075, 0.33631605, -0.1457349, -0.23712163, -0.19353649, -0.024511583, 0.28424242, 0.3383595}),
+      torch::tensor({-0.07561234, 0.10794748, -0.0437702, -0.56154686, 0.18596707, 0.07370188, 0.059135318, 0.33133864, -0.35610923, 0.36255547, -0.24472283, 0.052193344, -0.09056281, 0.14071976, 0.3979389}),
+      torch::tensor({0.15498109, -0.08727514, -0.1047872, 0.23060486, -0.37544468, 0.30660948, -0.07733662, 0.59290683, 0.08534651, 0.56152266, -0.058316797, 0.64462274, -0.1825164,0.29704455, -0.14194927}),
+      torch::tensor({0.21056584, 0.5165385, -0.11343903, -0.2790672, -0.12210965, 0.08408014, -0.25138792, 0.20601495, -0.14802553, 0.053201754, 0.055186458, 0.17179312, 0.18927598,-0.37573355, -0.5282227}),
+      torch::tensor({-0.045720562, 0.16112846, -0.43607467, 0.21109799, -0.008630521, 0.32987764, 0.13289018, 0.0899868, -0.2679775, -0.24753189, 0.23707163, 0.12620693, -0.6093915,0.25720116, 0.33193505}),
+      torch::tensor({-0.4173836, -0.19854523, -0.3151219, 0.14798103, 0.18053184, 0.07821397, 0.051448464, -0.024434408, 0.41581196, -0.031413753, -0.35996363, -0.23361506, -0.08951727, -0.052789558, 0.054483347}),
+    },
+    {
+      torch::tensor({0.41098386, 0.45162097, -0.11947367, 0.4141703, -0.2222035, 0.1805965, 0.3349034, -0.3735036, 0.13126858, -0.14054178, -0.25333884, 0.56775486, 0.17848605, -0.079786636, 0.24059108}),
+      torch::tensor({-0.19057135, -0.13091972, -0.015175606, 0.17821532, 0.26070553, -0.35303566, -0.42888534, -0.038996607, -0.21060736, 0.2925792, 0.053198263, 0.3793282, 0.3753942, 0.19560203, 0.2662913}),
+    },
+  };
+}
+
+inline std::vector<std::vector<torch::Tensor>> Kaiming_Normal() {
+  return {
+    {
+      torch::tensor({-0.37498012, 0.5541324, -0.32267526, -0.6835287, 0.049683988, -0.35604528, 0.3250149}),
+      torch::tensor({-0.39021203, 0.7349887, 0.35258764, 0.2547305, -0.5432392, 0.0964102, 0.057905816}),
+      torch::tensor({-0.40347, 0.13059375, 0.6095431, -0.048043985, 0.3900925, -0.9863645, -0.01337372}),
+      torch::tensor({0.7319649, 1.4202386, 0.5265685, 0.2016122, 0.5886347, -0.61084044, 0.02009024}),
+      torch::tensor({1.4412204, 0.66054344, 0.29015473, 0.28090778, -0.44331524, -0.75221026, 0.8695861}),
+      torch::tensor({0.0920839, -0.8613843, -0.25627562, -0.076624356, -0.16960111, 0.30663127, 0.5334167}),
+      torch::tensor({0.29057145, 0.042122718, 0.46121812, -0.010417648, 0.52973694, -0.41571668, -0.16002773}),
+      torch::tensor({-0.1003716, 1.0240903, 0.36892492, -1.2410016, -0.6395081, 0.105315976, -0.6293102}),
+      torch::tensor({0.06069615, 0.5905007, -0.7457508, 0.25396162, -0.43495473, 0.4940251, -0.13220948}),
+      torch::tensor({-0.75655663, 0.52780706, -0.79526657, 0.31359762, 0.08461243, 0.058928896, -0.43767342}),
+      torch::tensor({0.3382277, -1.0246153, 0.7012358, -0.11216363, 0.41785294, 0.5290129, 0.22167361}),
+      torch::tensor({-0.8065682, 1.0883075, 0.070338055, -0.27318043, -0.91602606, -0.2726411, -0.25384104}),
+      torch::tensor({-0.33856857, -0.78452945, -0.46956995, -1.1109498, -0.5882311, 0.83987635, 0.10242246}),
+      torch::tensor({0.20308746, -0.07737589, 0.34082106, -0.15035595, -0.71084815, -0.3494706, -0.28551292}),
+      torch::tensor({-0.27975786, 0.4604934, -0.47410175, 0.4483439, 0.6162483, -0.9413465, -0.59173715}),
+    },
+    {
+      torch::tensor({-0.15712228, -0.12167663, -0.5646644, -0.005373854, 0.44734758, 0.58191466, -0.59573853, 0.38415426, 0.22991896, 0.0990169, -0.2504856, -0.3986669, 0.6133277, -0.3216313, 0.21179804}),
+      torch::tensor({0.21443085, 0.032182008, -0.47721645, -0.25792935, -0.059963275, -0.35473004, -0.376411, 0.23635295, -0.532504, 0.26168394, -0.7303659, -0.8798827, 0.0801219, -0.6203361, 0.47813404}),
+      torch::tensor({-0.60661906, 0.6333067, -0.2301085, -0.23174146, 0.35589716, 0.076623656, 0.0109142335, 0.6241244, -0.26500958, -0.53064865, 0.21770664, -0.45658842, 0.4183115,0.26996946, 0.45760322}),
+      torch::tensor({-0.16231689, 0.29885638, 0.004551876, 0.13156882, -0.5895727, -0.9031439, 0.013202227, -0.124960326, -0.13937433, -0.020769095, -0.32021725, 0.24935026, 1.239458, -0.6093515, 0.18653919}),
+      torch::tensor({-0.10442699, 0.12234361, 0.42793095, -0.14748251, 0.32529667, -0.17886722, -0.42819393, -0.25087354, -0.85256964, 0.03433886, -0.07378998, 1.1413594, 0.73462075, -0.12297893, 0.11903118}),
+      torch::tensor({0.19542812, 0.72054344, -0.075772285, -0.011164244, 0.09482714, 0.0020253188, 0.2900505, 0.14867148, -0.13178284, 0.4784374, -0.3523874, 0.32155383, -0.3802177,-0.24720912, -0.14995211}),
+      torch::tensor({-0.5910129, 0.18546182, 0.8482229, 0.08390209, -0.19340174, 0.6696186, 0.0015560504, -0.4593153, -0.39599115, 0.27494922, 0.11817944, -0.10041946, 0.47674203, 0.38816532, 0.0993066}),
+      torch::tensor({-0.33843526, -0.9979506, -0.20600711, -0.10005052, 0.05104065, 0.1857006, 0.12125569, -0.3583017, 0.32456085, 0.57291603, -0.02988835, -0.12758288, 0.07391571, -0.105302736, -0.003461621}),
+      torch::tensor({0.88319576, 0.4848693, -0.09634813, 0.13308413, 0.9289434, -0.98205453, 0.8919147, 0.14732295, -0.3638662, 0.35729778, -0.16119093, -0.09531955, 0.29137993, -0.404272, 0.8510058}),
+      torch::tensor({-0.38180307, -0.17343357, -0.1808421, -0.07243101, 0.8087596, -0.049911626, -0.37177977, 0.06514444, 0.4756227, -0.20610029, -0.33534062, -0.27370194, -0.034664612, 0.4019795, 0.4785126}),
+      torch::tensor({-0.106931984, 0.15266079, -0.06190041, -0.7941472, 0.26299715, 0.1042302, 0.083629966, 0.46858358, -0.5036145, 0.5127309, -0.34609035, 0.07381254, -0.12807515, 0.19900778, 0.5627706}),
+      torch::tensor({0.21917637, -0.123425685, -0.14819148, 0.32612452, -0.53095895, 0.4336113, -0.10937049, 0.83849686, 0.1206982, 0.794113, -0.08247241, 0.91163427, -0.25811717, 0.42008442, -0.20074657}),
+      torch::tensor({0.29778504, 0.73049575, -0.16042702, -0.3946606, -0.17268912, 0.11890727, -0.35551623, 0.2913491, -0.20933971, 0.07523864, 0.078045435, 0.24295215, 0.26767665, -0.5313675, -0.74701965}),
+      torch::tensor({-0.06465864, 0.22787006, -0.61670274, 0.2985376, -0.0122054, 0.46651745, 0.1879351, 0.12726055, -0.37897742, -0.35006294, 0.33526993, 0.17848356, -0.86180973, 0.3637374, 0.46942705}),
+      torch::tensor({-0.59026957, -0.28078535, -0.44564965, 0.2092768, 0.25531057, 0.110611245, 0.072759114, -0.034555472, 0.5880469, -0.044425756, -0.50906545, -0.33038157, -0.12659654, -0.07465571, 0.07705109}),
+    },
+    {
+      torch::tensor({0.43752572, 0.48078725, -0.12718944, 0.44091794, -0.23655368, 0.19225967, 0.35653186, -0.39762494, 0.13974607, -0.14961815, -0.26969978, 0.6044212, 0.19001292, -0.08493936, 0.25612876}),
+      torch::tensor({-0.20287868, -0.13937469, -0.016155666, 0.1897247, 0.27754223, -0.37583515, -0.45658332, -0.041515056, -0.22420865, 0.31147435, 0.056633875, 0.4038257, 0.39963764, 0.20823427, 0.28348872}),
+    },
+  };
+}
+
+inline std::vector<std::vector<torch::Tensor>> Kaiming_Uniform() {
+  return {
+    {
+      torch::tensor({0.36331797, 0.055456758, -0.45173424, 0.43808794, -0.8880919, -0.5487398, -0.23176038}),
+      torch::tensor({-0.45097935, -0.32388276, -0.7588222, -0.19693595, 0.19790006, -0.6031401, -0.04751247}),
+      torch::tensor({0.66274905, -0.09517455, 0.02573061, -0.07986951, 0.18736875, 0.588673, 0.8769796}),
+      torch::tensor({0.5879475, 0.8789861, -0.06695682, -0.8316841, -0.43891022, 0.63039577, -0.0060015917}),
+      torch::tensor({-0.4601755, -0.7094668, -0.86643064, -0.7813998, -0.18779033, 0.50772536, 0.5005363}),
+      torch::tensor({-0.89289045, 0.57751, -0.724463, -0.19572788, -0.3753947, -0.17832708, -0.18177801}),
+      torch::tensor({-0.8307846, -0.7993882, -0.14487183, 0.011972845, -0.4205768, 0.34875572, -0.8332921}),
+      torch::tensor({-0.062481046, 0.8141842, -0.37763458, 0.83601844, 0.33528924, -0.83551645, 0.58576393}),
+      torch::tensor({-0.10683453, -0.4132924, 0.7403351, -0.7481508, 0.09934509, -0.19383776, 0.66113985}),
+      torch::tensor({0.25843763, 0.44486153, 0.32696164, -0.22263634, -0.19470501, -0.76295114, 0.5016502}),
+      torch::tensor({0.73508084, 0.6334691, -0.6530534, 0.041291475, -0.6526422, -0.50964934, -0.53948045}),
+      torch::tensor({0.31639445, -0.5517084, -0.020198822, 0.038947523, 0.596805, -0.69984597, -0.63558686}),
+      torch::tensor({-0.5375881, 0.6480124, -0.3327999, 0.78091884, 0.33478355, 0.11723292, -0.0068918467}),
+      torch::tensor({-0.18301755, 0.11615932, -0.21140611, -0.0065051913, 0.11812818, -0.72418123, -0.4852514}),
+      torch::tensor({0.74759305, -0.75134623, -0.06647962, 0.9158571, 0.33444047, 0.026212811, -0.8023249}),
+    },
+    {
+      torch::tensor({0.33950645, 0.08810693, -0.42316547, -0.49035522, -0.19512305, 0.2776215, 0.62385184, 0.3636803, -0.07122034, 0.2217493, -0.62047863, -0.5401811, 0.29510874, -0.3582324, 0.30427164}),
+      torch::tensor({-0.44647068, -0.3132621, -0.5209466, 0.33004105, -0.064444125, 0.4867385, 0.3914091, 0.34996456, 0.020337284, -0.19554132, -0.13751587, 0.084053695, 0.31351423,-0.44309196, 0.5308059}),
+      torch::tensor({-0.06876695, -0.5299633, -0.342195, 0.5596078, 0.5783978, -0.5858307, 0.44606978, 0.31696087, 0.37436205, 0.5353921, -0.34086213, 0.19971412, 0.2588218, -0.1868864, 0.21165323}),
+      torch::tensor({-0.18196595, 0.39102274, -0.17547771, -0.2357767, 0.1592201, 0.22432798, -0.30899984, 0.055908024, 0.3665306, -0.062927246, 0.1924777, -0.15254003, 0.22167546, -0.4581471, -0.37190205}),
+      torch::tensor({-0.32102844, 0.58123356, -0.17018372, -0.0017268658, -0.30642188, 0.63137406, 0.6177004, -0.4769895, -0.51271266, -0.47939685, -0.0030501485, -0.1612154, -0.413971, -0.22684419, 0.119490385}),
+      torch::tensor({-0.33045214, 0.14014107, -0.14503211, -0.30646476, 0.08689022, 0.52004176, -0.42759007, 0.029367685, -0.23322919, 0.6206326, -0.60005057, -0.60631055, 0.62321216, -0.40014827, 0.12125647}),
+      torch::tensor({-0.05458474, -0.1332312, -0.14125755, 0.40188795, 0.03022945, -0.6157679, -0.37337655, -0.21560428, 0.31826198, -0.40928328, 0.59635466, -0.1408664, -0.11356497, 0.4955656, 0.317877}),
+      torch::tensor({0.53639835, 0.36584032, -0.19185784, -0.4196168, -0.047056615, 0.523444, -0.2122646, -0.58652025, 0.2592615, 0.6156774, -0.18004948, -0.5236881, -0.5736749, 0.15849298, -0.04788935}),
+      torch::tensor({-0.31938198, 0.12783945, 0.24018055, 0.5030039, 0.49102718, -0.09466827, -0.5576785, -0.57150143, 0.5905171, 0.2795803, 0.27569205, -0.5472138, 0.58566374, 0.5991524, 0.571019}),
+      torch::tensor({-0.53353, -0.23864025, -0.43498003, 0.5989136, -0.2717535, -0.28875044, 0.33135164, -0.2925691, -0.31149834, -0.055321813, -0.060791314, -0.49266922, 0.527258, -0.27898985, 0.22434139}),
+      torch::tensor({0.55009943, 0.31895775, 0.089542985, 0.53812927, 0.085018635, -0.29260972, 0.59830266, 0.14968991, -0.6170747, -0.18005475, -0.43080813, 0.5545073, -0.104447424, -0.576499, -0.039758265}),
+      torch::tensor({0.39719027, 0.16431224, 0.19998527, 0.058738172, 0.23579127, -0.15409905, -0.25159094, -0.59116155, -0.4764521, 0.2741078, -0.37448537, 0.09073615, 0.20187438, 0.044780314, -0.4100524}),
+      torch::tensor({0.6048352, -0.36777255, 0.52019507, -0.5031236, -0.15213624, 0.34403604, -0.25840908, 0.53128976, -0.43522838, -0.53111815, -0.28518632, 0.10223669, 0.5823371, -0.30193484, 0.22618395}),
+      torch::tensor({-0.15856105, -0.13715577, 0.4650467, -0.49014413, 0.06717521, 0.59476703, -0.08690709, 0.49106258, -0.194769, 0.50910276, -0.6117934, -0.09115183, -0.111088574,0.20496935, 0.24812967}),
+      torch::tensor({0.48561007, -0.094227016, -0.025080085, 0.43311268, -0.17112547, 0.55443174, -0.42109656, -0.068437934, -0.03396976, 0.2822079, 0.4323948, -0.10024661, -0.52397716, 0.31337678, 0.18918753}),
+    },
+    {
+      torch::tensor({-0.10132998, -0.37424773, -0.3613546, 0.15002292, 0.59365314, -0.5066626, 0.38276488, -0.32790715, -0.12318605, 0.5020248, -0.14304265, 0.057588696, -0.442107, 0.538372, -0.081757426}),
+      torch::tensor({-0.46257195, 0.18518949, -0.44968688, -0.5018581, 0.03852868, 0.5014383, -0.1789748, 0.29767585, 0.5434825, 0.419483, -0.33176154, -0.06932968, -0.19902417, -0.508572, 0.00026130676}),
+    },
+  };
+}
+
+} // namespace expected_parameters
\ No newline at end of file
diff --git a/test/cpp/api/init_baseline.py b/test/cpp/api/init_baseline.py
new file mode 100644 (file)
index 0000000..018a0aa
--- /dev/null
@@ -0,0 +1,72 @@
+"""Script to generate baseline values from PyTorch initialization algorithms"""
+
+import sys
+import torch
+
+HEADER = """
+#include <torch/types.h>
+
+#include <vector>
+
+namespace expected_parameters {
+"""
+
+FOOTER = "} // namespace expected_parameters"
+
+PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{"
+
+INITIALIZERS = {
+    "Xavier_Uniform": lambda w: torch.nn.init.xavier_uniform(w),
+    "Xavier_Normal": lambda w: torch.nn.init.xavier_normal(w),
+    "Kaiming_Normal": lambda w: torch.nn.init.kaiming_normal(w),
+    "Kaiming_Uniform": lambda w: torch.nn.init.kaiming_uniform(w)
+}
+
+
+def emit(initializer_parameter_map):
+    # Don't emit the word "generated" with a literal @ prefix here, or this
+    # script's own source would be flagged as a generated file.
+    print(HEADER)
+    for initializer_name, weights in initializer_parameter_map.items():
+        print(PARAMETERS.format(initializer_name))
+        print("  return {")
+        for sample in weights:
+            print("    {")
+            for parameter in sample:
+                parameter_values = "{{{}}}".format(", ".join(map(str, parameter)))
+                print("      torch::tensor({}),".format(parameter_values))
+            print("    },")
+        print("  };")
+        print("}\n")
+    print(FOOTER)
+
+
+def run(initializer):
+    torch.manual_seed(0)
+
+    layer1 = torch.nn.Linear(7, 15)
+    INITIALIZERS[initializer](layer1.weight)
+
+    layer2 = torch.nn.Linear(15, 15)
+    INITIALIZERS[initializer](layer2.weight)
+
+    layer3 = torch.nn.Linear(15, 2)
+    INITIALIZERS[initializer](layer3.weight)
+
+    weight1 = layer1.weight.data.numpy()
+    weight2 = layer2.weight.data.numpy()
+    weight3 = layer3.weight.data.numpy()
+
+    return [weight1, weight2, weight3]
+
+
+def main():
+    initializer_parameter_map = {}
+    for initializer in INITIALIZERS.keys():
+        sys.stderr.write('Evaluating {} ...\n'.format(initializer))
+        initializer_parameter_map[initializer] = run(initializer)
+
+    emit(initializer_parameter_map)
+
+if __name__ == "__main__":
+    main()
index 69e7a8c..42ee0fb 100644 (file)
@@ -41,13 +41,4 @@ TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
 TEST_F(AutogradTest, CanPassCustomGradientInputs) {
   z.sum().backward(torch::ones({}) * 2);
   ASSERT_TRUE(x.grad().allclose(y * 2));
-}
-
-TEST(NNInitTest, CanInitializeTensorThatRequiresGrad) {
-  auto tensor = torch::empty({3, 4}, torch::requires_grad());
-  ASSERT_THROWS_WITH(
-      tensor.fill_(1),
-      "a leaf Variable that requires grad "
-      "has been used in an in-place operation");
-  ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
-}
+}
\ No newline at end of file
index 8d543c4..4e43462 100644 (file)
@@ -1,12 +1,31 @@
 #pragma once
 
-#include <torch/types.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
 
 namespace torch {
 namespace nn {
 namespace init {
 
+enum class Nonlinearity {
+  Linear,
+  Conv1D,
+  Conv2D,
+  Conv3D,
+  ConvTranspose1D,
+  ConvTranspose2D,
+  ConvTranspose3D,
+  Sigmoid,
+  Tanh,
+  ReLU,
+  LeakyReLU
+};
+
+enum class FanMode { FanIn, FanOut };
+
+/// Return the recommended gain value for the given nonlinearity function.
+TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);
+
 /// Fills the given `tensor` with the provided `value` in-place, and returns it.
 /// No gradient will be recorded for this operation.
 TORCH_API Tensor constant_(Tensor tensor, Scalar value);
@@ -51,6 +70,28 @@ TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);
 TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
 
 /// Fills the input `Tensor` with values according to the method
+/// described in "Delving deep into rectifiers: Surpassing human-level
+/// performance on ImageNet classification" - He, K. et al. (2015), using a
+/// normal distribution. Also known as He initialization.
+/// No gradient will be recorded for this operation.
+TORCH_API Tensor kaiming_normal_(
+    Tensor tensor,
+    double a = 0,
+    FanMode mode = FanMode::FanIn,
+    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+
+/// Fills the input `Tensor` with values according to the method
+/// described in "Delving deep into rectifiers: Surpassing human-level
+/// performance on ImageNet classification" - He, K. et al. (2015), using a
+/// uniform distribution. Also known as He initialization.
+/// No gradient will be recorded for this operation.
+TORCH_API Tensor kaiming_uniform_(
+    Tensor tensor,
+    double a = 0,
+    FanMode mode = FanMode::FanIn,
+    Nonlinearity nonlinearity = Nonlinearity::LeakyReLU);
+
+/// Fills the input `Tensor` with values according to the method
 /// described in "Understanding the difficulty of training deep feedforward
 /// neural networks" - Glorot, X. & Bengio, Y. (2010). Values are scaled by the
 /// `gain` parameter. No gradient will be recorded for this operation.
index a2f49bb..187a252 100644 (file)
@@ -34,8 +34,37 @@ struct Fan {
   int64_t in;
   int64_t out;
 };
+
+double calculate_kaiming_std(
+    Tensor tensor,
+    double a,
+    FanMode mode,
+    Nonlinearity nonlinearity) {
+  NoGradGuard guard;
+  Fan fan(tensor);
+  const auto gain = calculate_gain(nonlinearity, a);
+  double std = 0.0;
+  if (mode == FanMode::FanIn) {
+    std = gain / std::sqrt(fan.in);
+  } else {
+    std = gain / std::sqrt(fan.out);
+  }
+  return std;
+}
 } // namespace
 
+double calculate_gain(Nonlinearity nonlinearity, double param) {
+  if (nonlinearity == Nonlinearity::Tanh) {
+    return 5.0 / 3.0;
+  } else if (nonlinearity == Nonlinearity::ReLU) {
+    return std::sqrt(2.0);
+  } else if (nonlinearity == Nonlinearity::LeakyReLU) {
+    return std::sqrt(2.0 / (1 + pow(param, 2)));
+  }
+
+  return 1.0;
+}
+
 Tensor constant_(Tensor tensor, Scalar value) {
   NoGradGuard guard;
   return tensor.fill_(value);
@@ -146,6 +175,29 @@ Tensor uniform_(Tensor tensor, double low, double high) {
   return tensor.uniform_(low, high);
 }
 
+Tensor kaiming_uniform_(
+    Tensor tensor,
+    double a,
+    FanMode mode,
+    Nonlinearity nonlinearity) {
+  NoGradGuard guard;
+  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
+  // Calculate uniform bounds from standard deviation
+  const auto bound = std::sqrt(3.0) * std;
+  return tensor.uniform_(-bound, bound);
+}
+
+Tensor kaiming_normal_(
+    Tensor tensor,
+    double a,
+    FanMode mode,
+    Nonlinearity nonlinearity) {
+  NoGradGuard guard;
+
+  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
+  return tensor.normal_(0, std);
+}
+
 Tensor xavier_normal_(Tensor tensor, double gain) {
   NoGradGuard guard;