[layer] Support multiple outputs for mol attention layer
authorParichay Kapoor <pk.kapoor@samsung.com>
Fri, 3 Dec 2021 07:00:40 +0000 (16:00 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Dec 2021 12:51:57 +0000 (21:51 +0900)
Support multiple outputs for mol attention layer along with the
unitests.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/mol_attention_layer.cpp
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelTests_v2.py
test/input_gen/recorder_v2.py
test/unittest/layers/meson.build
test/unittest/models/unittest_models.cpp
test/unittest/unittest_nntrainer_models.cpp

index f0f92fd..a857a38 100644 (file)
@@ -138,7 +138,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
     context.requestTensor(prob_dim, "u_pos_div", Tensor::Initializer::NONE,
                           false, TensorLifespan::ITERATION_LIFESPAN);
 
-  context.setOutputDimensions({query_dim});
+  context.setOutputDimensions({query_dim, state_dim});
 }
 
 void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
@@ -146,7 +146,8 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
   Tensor &state = context.getInput(wt_idx[AttentionParams::state]);
 
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  Tensor &output = context.getOutput(0);
+  Tensor &updated_state = context.getOutput(1);
   Tensor &fc_w = context.getWeight(wt_idx[AttentionParams::fc_w]);
   Tensor &fc_bias = context.getWeight(wt_idx[AttentionParams::fc_bias]);
   Tensor &fc_proj_w = context.getWeight(wt_idx[AttentionParams::fc_proj_w]);
@@ -197,7 +198,7 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
     .copy_with_stride(alpha);
 
-  Tensor m = state.add(kappa);
+  state.add(kappa, updated_state);
 
   /** @todo cache u_base, u_pos, u_neg */
   Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
@@ -214,11 +215,11 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
 
   Tensor beta_eps = beta.add(1e-8f);
 
-  Tensor u_pos_m = u_pos.subtract(m);
+  Tensor u_pos_m = u_pos.subtract(updated_state);
   u_pos_m.divide(beta_eps, u_pos_div);
   sigmoid.run_fn(u_pos_div, prob_left);
 
-  Tensor u_neg_m = u_neg.subtract(m);
+  Tensor u_neg_m = u_neg.subtract(updated_state);
   u_neg_m.divide(beta_eps, u_neg_div);
   sigmoid.run_fn(u_neg_div, prob_right);
 
@@ -243,7 +244,8 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
   Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
 
-  Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
+  Tensor &derivative = context.getIncomingDerivative(0);
+  Tensor &derivative_state = context.getIncomingDerivative(1);
 
   Tensor &fc_proj_out = context.getTensor(wt_idx[AttentionParams::fc_proj_out]);
   Tensor &dfc_proj_out =
@@ -300,6 +302,7 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
 
   Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
   dm_neg.add(dm_pos, dstate);
+  dstate.add_i(derivative_state);
   Tensor dkappa = dstate;
   Tensor dbeta = dbeta_eps;
 
index c00c0d5..36673d9 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index da52bf6..1f050ec 100644 (file)
@@ -68,9 +68,9 @@ class MolAttention(torch.nn.Module):
 
         output = torch.matmul(scores.unsqueeze(1), values).squeeze(dim=1)
 
-        loss = self.loss(torch.sum(output))
+        loss = self.loss(torch.sum(output)) + self.loss(torch.sum(kappa))
 
-        return output, loss
+        return (output, kappa), loss
 
 if __name__ == "__main__":
     record_v2(
@@ -86,7 +86,7 @@ if __name__ == "__main__":
         iteration=2,
         input_dims=[(3,6), (3,4,6), (3,1,5), (3)],
         input_dtype=[float, float, float, int],
-        label_dims=[(3,1,6)],
+        label_dims=[(3,1,6), (3,1,5)],
         name="mol_attention_masked",
     )
 
@@ -95,7 +95,7 @@ if __name__ == "__main__":
         iteration=2,
         input_dims=[(3,6), (3,4,6), (3,1,5)],
         input_dtype=[float, float, float],
-        label_dims=[(3,1,6)],
+        label_dims=[(3,1,6), (3,1,5)],
         name="mol_attention",
     )
 
index 3656b83..f659810 100644 (file)
@@ -34,6 +34,7 @@ def _get_writer(file):
             items = [items]
 
         for item in items:
+            print(item.numel())
             np.array([item.numel()], dtype="int32").tofile(file)
             item.detach().cpu().numpy().tofile(file)
 
@@ -80,6 +81,7 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
         inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float)
         labels = _rand_like(label_dims, dtype=float)
         write_fn(inputs)
+        print(labels)
         write_fn(labels)
         write_fn(list(t for _, t in params_translated(model)))
         output, loss = model(inputs, labels)
index 33d3762..ccb3bf7 100644 (file)
@@ -57,7 +57,7 @@ test_target = [
   'unittest_layers_attention.cpp',
   'unittest_layers_dropout.cpp',
   'unittest_layers_reshape.cpp',
-  'unittest_layers_mol_attention.cpp',
+  'unittest_layers_mol_attention.cpp',
 ]
 
 if get_option('enable-tflite-backbone')
index c3ca770..5512208 100644 (file)
@@ -50,9 +50,11 @@ static std::unique_ptr<NeuralNetwork> makeMolAttention() {
     {"input", {"name=in1", "input_shape=1:1:6"}},
     {"mol_attention",
      {"name=mol", "input_layers=in1,in2,in3", "unit=8", "mol_k=5"}},
-    {"constant_derivative", {"name=loss", "input_layers=mol"}},
+    {"constant_derivative", {"name=loss1", "input_layers=mol(0)"}},
+    {"constant_derivative", {"name=loss2", "input_layers=mol(1)"}},
   });
 
+  nn->setProperty({"label_layers=loss1,loss2"});
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
@@ -72,9 +74,11 @@ static std::unique_ptr<NeuralNetwork> makeMolAttentionMasked() {
     {"input", {"name=in1", "input_shape=1:1:6"}},
     {"mol_attention",
      {"name=mol", "input_layers=in1,in2,in3,in4", "unit=8", "mol_k=5"}},
-    {"constant_derivative", {"name=loss", "input_layers=mol"}},
+    {"constant_derivative", {"name=loss1", "input_layers=mol(0)"}},
+    {"constant_derivative", {"name=loss2", "input_layers=mol(1)"}},
   });
 
+  nn->setProperty({"label_layers=loss1,loss2"});
   for (auto &node : outer_graph) {
     nn->addLayer(node);
   }
index c79ed54..12febe4 100644 (file)
@@ -767,15 +767,6 @@ INI multiout_model(
   }
 );
 
-INI mol_attention(
-  "mol_attention",
-  {nn_base + "batch_size = 3",
-   sgd_base + "learning_rate = 1",
-   I("in0") + input_base + "input_shape = 1:1:256",
-   I("in1") + input_base + "input_shape = 1:84:256",
-   I("in2") + input_base + "input_shape = 1:1:5", // this shape should match mol_k
-   I("mol") + "type=mol_attention" + "unit = 128" + "mol_k=5" + "input_layers=in0,in1,in2"});
-
 /**
  * @brief helper function to make model testcase
  *
@@ -911,7 +902,6 @@ INSTANTIATE_TEST_CASE_P(
       mkModelIniTc(preprocess_translate, "3:1:1:10", 10, ModelTestOption::NO_THROW_RUN),
   #endif
       mkModelIniTc(preprocess_flip_validate, "3:1:1:10", 10, ModelTestOption::NO_THROW_RUN),
-      mkModelIniTc(mol_attention, "3:1:1:128", 2, ModelTestOption::NO_THROW_RUN),
 
       /**< Addition test */
       mkModelIniTc(addition_resnet_like, "3:1:1:10", 10, ModelTestOption::COMPARE), // Todo: Enable option to ALL
@@ -973,7 +963,8 @@ TEST(nntrainerModels, read_save_01_n) {
 
 TEST(nntrainerModels, loadFromLayersBackbone_p) {
   std::vector<std::shared_ptr<ml::train::Layer>> reference;
-  reference.emplace_back(ml::train::layer::FullyConnected({"name=fc1", "input_shape=3:1:2"}));
+  reference.emplace_back(
+    ml::train::layer::FullyConnected({"name=fc1", "input_shape=3:1:2"}));
   reference.emplace_back(
     ml::train::layer::FullyConnected({"name=fc2", "input_layers=fc1"}));