[unittest] zoneout lstmcell unittest
authorhyeonseok lee <hs89.lee@samsung.com>
Thu, 2 Dec 2021 15:55:39 +0000 (00:55 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 6 Dec 2021 12:35:08 +0000 (21:35 +0900)
 - unittest for zoneout lstmcell layer

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelsRecurrent_v2.py
test/input_gen/recorder_v2.py
test/input_gen/transLayer_v2.py
test/input_gen/zoneout.py [new file with mode: 0644]
test/unittest/models/unittest_models_recurrent.cpp

index 36673d9..21f49e1 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index 18f62d7..adb98ec 100644 (file)
@@ -9,6 +9,7 @@
 # @author Jihoon lee <jhoon.it.lee@samsung.com>
 
 from recorder_v2 import record_v2, inspect_file
+from zoneout import Zoneout
 import torch
 
 class FCUnroll(torch.nn.Module):
@@ -88,6 +89,38 @@ class LSTMStacked(torch.nn.Module):
         loss = self.loss(ret, labels[0])
         return ret, loss
 
+class ZoneoutLSTMStacked(torch.nn.Module):
+    def __init__(self, batch_size=3, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1, cell_state_zoneout_rate=1):
+        super().__init__()
+        self.input_size = self.hidden_size = 2
+        self.cell_state_zoneout_rate = cell_state_zoneout_rate
+        self.zoneout_lstms = torch.nn.ModuleList(
+            [
+                Zoneout(batch_size, self.input_size, self.hidden_size, unroll_for, hidden_state_zoneout_rate, cell_state_zoneout_rate)
+                for _ in range(num_lstm)
+            ]
+        )
+        for zoneout_lstm in self.zoneout_lstms:
+            zoneout_lstm.bias_hh.data.fill_(0.0)
+            zoneout_lstm.bias_hh.requires_grad=False
+        self.unroll_for = unroll_for
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        hs = [torch.zeros_like(inputs[0]) for _ in self.zoneout_lstms]
+        cs = [torch.zeros_like(inputs[0]) for _ in self.zoneout_lstms]
+        out = inputs[0]
+        ret = []
+        for num_unroll in range(self.unroll_for):
+            for i, (zoneout_lstm, h, c) in enumerate(zip(self.zoneout_lstms, hs, cs)):
+                hs[i], cs[i] = zoneout_lstm(out, (h, c, num_unroll))
+                out = hs[i]
+            ret.append(out)
+
+        ret = torch.stack(ret, dim=1)
+        loss = self.loss(ret, labels[0])
+        return ret, loss
+
 class GRUCellStacked(torch.nn.Module):
     def __init__(self, unroll_for=2, num_gru=1):
         super().__init__()
@@ -177,6 +210,150 @@ if __name__ == "__main__":
     )
 
     record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_000_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_000_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_050_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_050_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_100_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_100_000",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_000_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_000_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_050_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_050_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_100_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=0.5),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_100_050",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_000_100",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.0, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_000_100",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_050_100",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=0.5, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_050_100",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_single_100_100",
+    )
+
+    record_v2(
+        ZoneoutLSTMStacked(batch_size=1, unroll_for=2, num_lstm=2, hidden_state_zoneout_rate=1.0, cell_state_zoneout_rate=1.0),
+        iteration=2,
+        input_dims=[(1, 2)],
+        label_dims=[(1, 2, 2)],
+        name="zoneout_lstm_stacked_100_100",
+    )
+
+    record_v2(
         GRUCellStacked(unroll_for=2, num_gru=1),
         iteration=2,
         input_dims=[(3, 2)],
index f659810..3656b83 100644 (file)
@@ -34,7 +34,6 @@ def _get_writer(file):
             items = [items]
 
         for item in items:
-            print(item.numel())
             np.array([item.numel()], dtype="int32").tofile(file)
             item.detach().cpu().numpy().tofile(file)
 
@@ -81,7 +80,6 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
         inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float)
         labels = _rand_like(label_dims, dtype=float)
         write_fn(inputs)
-        print(labels)
         write_fn(labels)
         write_fn(list(t for _, t in params_translated(model)))
         output, loss = model(inputs, labels)
index c16035f..bfefcc4 100644 (file)
@@ -10,6 +10,7 @@
 
 import torch
 from collections.abc import Iterable
+from zoneout import Zoneout
 
 __all__ = ["params_translated"]
 
@@ -56,6 +57,20 @@ def bn1d_translate(model):
     yield from [mu, var, gamma, beta]
 
 
+@register_for_((Zoneout))
+def zoneout_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    bias = ("bias", params[2][1] + params[3][1])
+    hidden_state = ("hidden_state", torch.stack(model.hidden_state_zoneout_mask, dim=0))
+    cell_state = ("cell_state", torch.stack(model.cell_state_zoneout_mask, dim=0))
+
+    # [hidden, input] -> [input, hidden]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+
+    new_params = [transpose_(params[0]), transpose_(params[1]), bias, hidden_state, cell_state]
+    yield from new_params
+
 @register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
 def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
diff --git a/test/input_gen/zoneout.py b/test/input_gen/zoneout.py
new file mode 100644 (file)
index 0000000..42185c2
--- /dev/null
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+##
+# Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
+#
+# @file zoneout.py
+# @date 02 December 2021
+# @brief Generate Zoneout LSTM cell using torch lstmcell
+# @author hyeonseok lee <hs89.lee@samsung.com>
+
+import torch
+
+# Note: Each iteration share the same zoneout mask
+class Zoneout(torch.nn.LSTMCell):
+    def __init__(self, batch_size, input_size, hidden_size, num_roll=2, hidden_state_zoneout_rate=1, cell_state_zoneout_rate=1):
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        super().__init__(self.input_size, self.hidden_size, bias=True)
+        self.hidden_state_zoneout_mask = [ torch.zeros([batch_size, self.hidden_size]).bernoulli_(1. - hidden_state_zoneout_rate) for _ in range(num_roll)]
+        self.cell_state_zoneout_mask = [ torch.zeros([batch_size, self.hidden_size]).bernoulli_(1. - cell_state_zoneout_rate) for _ in range(num_roll)]
+
+    def zoneout(self, prev_state, next_state, mask):
+        return prev_state * (1. - mask) + next_state * mask
+
+    def forward(self, out, states):
+        hidden_state, cell_state, num_unroll = states
+        next_hidden_state, next_cell_state = super().forward(out, (hidden_state, cell_state))
+        next_hidden_state = self.zoneout(hidden_state, next_hidden_state, self.hidden_state_zoneout_mask[num_unroll])
+        next_cell_state = self.zoneout(cell_state, next_cell_state, self.cell_state_zoneout_mask[num_unroll])
+
+        return (next_hidden_state, next_cell_state)
index 25deff6..b1aa7fd 100644 (file)
@@ -261,6 +261,75 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTMCell() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeSingleZoneoutLSTMCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=1"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here zoneout_lstm_cell is being inserted
+    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a1"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto zoneout_lstm = makeGraph({
+    {"zoneout_lstmcell",
+     {"name=a1", "unit=2", "hidden_state_zoneout_rate=1.0",
+      "cell_state_zoneout_rate=1.0", "test=true"}},
+  });
+
+  nn->addWithReferenceLayers(zoneout_lstm, "zoneout_lstm_scope", {"input"},
+                             {"a1"}, {"a1"},
+                             ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a1",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
+static std::unique_ptr<NeuralNetwork> makeStackedZoneoutLSTMCell() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=1"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:2"}},
+    /// here zoneout_lstm_cell is being inserted
+    {"mse", {"name=loss", "input_layers=zoneout_lstm_scope/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto zoneout_lstm = makeGraph({
+    {"zoneout_lstmcell",
+     {"name=a1", "unit=2", "hidden_state_zoneout_rate=1.0",
+      "cell_state_zoneout_rate=1.0", "test=true"}},
+    {"zoneout_lstmcell",
+     {"name=a2", "unit=2", "hidden_state_zoneout_rate=1.0",
+      "cell_state_zoneout_rate=1.0", "test=true", "input_layers=a1"}},
+  });
+
+  nn->addWithReferenceLayers(zoneout_lstm, "zoneout_lstm_scope", {"input"},
+                             {"a1"}, {"a2"},
+                             ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=true",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 static std::unique_ptr<NeuralNetwork> makeSingleRNNCell() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -401,6 +470,42 @@ INSTANTIATE_TEST_CASE_P(
     mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeStackedLSTMCell, "lstm_stacked__1",
                  ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_000",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_050",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_100",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_100",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_100",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_100",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_100",
+                 ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_100",
+                 ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleRNNCell, "rnncell_single__1",
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",