[layer] Support filter masking in mol attention
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 2 Dec 2021 04:06:45 +0000 (13:06 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 3 Dec 2021 12:51:57 +0000 (21:51 +0900)
Add support for filter based masking in mol attention.
Add corresponding unittest.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/layers/mol_attention_layer.cpp
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h
packaging/unittest_models_v2.tar.gz
test/input_gen/genModelTests_v2.py
test/input_gen/recorder_v2.py
test/unittest/models/models_golden_test.cpp
test/unittest/models/models_golden_test.h
test/unittest/models/unittest_models.cpp

index c5e35a1..f0f92fd 100644 (file)
@@ -31,6 +31,7 @@ enum AttentionParams {
   query = 0,
   value = 1,
   state = 2,
+  mask_len = 3,
   fc_w,
   fc_bias,
   fc_proj_w,
@@ -46,8 +47,8 @@ enum AttentionParams {
 };
 
 void MoLAttentionLayer::finalize(InitLayerContext &context) {
-  if (context.getNumInputs() != 3)
-    throw std::runtime_error("MoL Attention layer needs 3 inputs.");
+  if (context.getNumInputs() < 3 || context.getNumInputs() > 4)
+    throw std::runtime_error("MoL Attention layer needs 3-4 inputs.");
 
   auto const &all_dims = context.getInputDimensions();
   auto const &query_dim = all_dims[AttentionParams::query];
@@ -57,6 +58,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) {
   wt_idx[AttentionParams::query] = AttentionParams::query;
   wt_idx[AttentionParams::value] = AttentionParams::value;
   wt_idx[AttentionParams::state] = AttentionParams::state;
+  wt_idx[AttentionParams::mask_len] = AttentionParams::mask_len;
 
   softmax.setActiFunc(ActivationType::ACT_SOFTMAX);
   tanh.setActiFunc(ActivationType::ACT_TANH);
@@ -225,6 +227,13 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor prob_scaled = prob.multiply(alpha);
   prob_scaled.sum(3, scores);
 
+  if (context.getNumInputs() == 4) {
+    Tensor mask = Tensor(scores.getDim());
+    mask.filter_mask(context.getInput(wt_idx[AttentionParams::mask_len]),
+                     false);
+    scores.multiply_i(mask);
+  }
+
   scores.dotBatched(value, output);
 }
 
@@ -261,6 +270,11 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
   Tensor dscores = Tensor(TensorDim({value.batch(), 1, 1, value.height()}));
   dscores.dot_batched_deriv_wrt_1(value, derivative);
   dscores.reshape(TensorDim({scores.batch(), 1, scores.width(), 1}));
+  if (context.getNumInputs() == 4) {
+    Tensor mask = Tensor(dscores.getDim());
+    mask.filter_mask(context.getInput(wt_idx[AttentionParams::mask_len]));
+    dscores.multiply_i(mask);
+  }
 
   Tensor dprob_scaled = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
   dprob_scaled.setZero();
index eee5081..d0cf702 100644 (file)
@@ -1228,6 +1228,26 @@ void Tensor::dropout_mask(float dropout) {
   }
 }
 
+void Tensor::filter_mask(const Tensor &mask_len, bool reverse) {
+  float fill_mask_val = 0.0;
+  float en_mask_val = 1.0 - fill_mask_val;
+
+  if (reverse) {
+    fill_mask_val = 1.0;
+    en_mask_val = 1.0 - fill_mask_val;
+  }
+
+  setValue(fill_mask_val);
+  if (mask_len.batch() != batch())
+    throw std::invalid_argument("Number of filter masks mismatched");
+
+  for (unsigned int b = 0; b < batch(); b++) {
+    float *addr = getAddress(b, 0, 0, 0);
+    const uint *mask_len_val = mask_len.getAddress<uint>(b, 0, 0, 0);
+    std::fill(addr, addr + (*mask_len_val), en_mask_val);
+  }
+}
+
 int Tensor::apply_i(std::function<float(float)> f) {
   Tensor result = *this;
   apply(f, result);
index 53e5064..499ef6b 100644 (file)
@@ -698,6 +698,13 @@ public:
   void dropout_mask(float dropout);
 
   /**
+   * @brief Calculate filter mask
+   * @param mask_len length of each mask along the last axis
+   * @param invert invert the mask
+   */
+  void filter_mask(const Tensor &mask_len, bool reverse = false);
+
+  /**
    * @brief     sum all the Tensor elements according to the batch
    * @retval    Calculated Tensor(batch, 1, 1, 1)
    */
index 4873478..c00c0d5 100644 (file)
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
index 89ef073..da52bf6 100644 (file)
@@ -34,7 +34,11 @@ class MolAttention(torch.nn.Module):
         self.loss = torch.nn.Identity()
 
     def forward(self, inputs, labels):
-        query, values, attention_state = inputs
+        if len(inputs) == 4:
+            query, values, attention_state, mask_len = inputs
+        else:
+            query, values, attention_state = inputs
+            mask_len = None
         batch_size, timesteps, _ = values.size()
 
         dense1_out = torch.tanh(self.dense1(query.unsqueeze(1)))
@@ -54,6 +58,14 @@ class MolAttention(torch.nn.Module):
         integrals = alpha * (integrals_left - integrals_right)
         scores = torch.sum(integrals, dim=2)
 
+        if mask_len is not None:
+            max_len = max(int(mask_len.max()), scores.shape[1])
+            mask = torch.arange(0, max_len)\
+                    .type_as(mask_len)\
+                    .unsqueeze(0).expand(mask_len.numel(), max_len)\
+                    .lt(mask_len.unsqueeze(1))
+            scores.masked_fill_(torch.logical_not(mask), 0.)
+
         output = torch.matmul(scores.unsqueeze(1), values).squeeze(dim=1)
 
         loss = self.loss(torch.sum(output))
@@ -72,9 +84,19 @@ if __name__ == "__main__":
     record_v2(
         MolAttention(query_size=6),
         iteration=2,
+        input_dims=[(3,6), (3,4,6), (3,1,5), (3)],
+        input_dtype=[float, float, float, int],
+        label_dims=[(3,1,6)],
+        name="mol_attention_masked",
+    )
+
+    record_v2(
+        MolAttention(query_size=6),
+        iteration=2,
         input_dims=[(3,6), (3,4,6), (3,1,5)],
+        input_dtype=[float, float, float],
         label_dims=[(3,1,6)],
         name="mol_attention",
     )
 
-    # inspect_file("mol_attention.nnmodelgolden")
+    # inspect_file("mol_attention_masked.nnmodelgolden")
index 168091d..3656b83 100644 (file)
@@ -42,15 +42,17 @@ def _get_writer(file):
     return write_fn
 
 
-def _rand_like(*shapes, scale=1, rand="int"):
-    shape_to_np = (
-        lambda shape: np.random.randint(0, 10, shape).astype(dtype=np.float32)
-        if rand == "int"
-        else np.random.rand(*shape).astype(dtype=np.float32)
-    )
+def _rand_like(shapes, scale=1, dtype=None):
+    def shape_to_np(shape, dtype=int):
+        if dtype == int:
+            return np.random.randint(0, 4, shape).astype(dtype=np.int32)
+        else:
+            return np.random.rand(*shape).astype(dtype=np.float32)
 
-    np_array = map(shape_to_np, shapes)
-    return [torch.tensor(t * scale) for t in np_array]
+    if not isinstance(dtype, list):
+        dtype = [dtype] * len(shapes)
+    np_array = list([shape_to_np(s,t) for s,t in zip(shapes, dtype)])
+    return list([torch.tensor(t * scale) for t in np_array])
 
 
 ##
@@ -59,7 +61,8 @@ def _rand_like(*shapes, scale=1, rand="int"):
 # @param input_dims dimensions to record including batch (list of tuple)
 # @param label_dims dimensions to record including batch (list of tuple)
 # @param name golden name
-def record_v2(model, iteration, input_dims, label_dims, name, clip=False):
+def record_v2(model, iteration, input_dims, label_dims, name, clip=False,
+              input_dtype=None):
     ## file format is as below
     # [<number of iteration(int)> <Iteration> <Iteration>...<Iteration>]
     # Each iteration contains
@@ -74,8 +77,8 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False):
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
     def record_iteration(write_fn):
-        inputs = _rand_like(*input_dims, rand="float")
-        labels = _rand_like(*label_dims, rand="float")
+        inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float)
+        labels = _rand_like(label_dims, dtype=float)
         write_fn(inputs)
         write_fn(labels)
         write_fn(list(t for _, t in params_translated(model)))
index 4514727..dd26d89 100644 (file)
@@ -78,6 +78,11 @@ TEST_P(nntrainerModelTest, model_test_optimized) {
  * @brief check given ini is failing/suceeding at validation
  */
 TEST_P(nntrainerModelTest, model_test_validate) {
+  if (!shouldValidate()) {
+    std::cout << "[ SKIPPED  ] option not enabled \n";
+    return;
+  }
+
   validate(true);
   /// add stub test for tcm
   EXPECT_TRUE(true);
index 33870c9..daa43d1 100644 (file)
@@ -31,14 +31,17 @@ class NeuralNetwork;
  *
  */
 typedef enum {
-  NO_THROW_RUN = 0, /**< no comparison, only validate execution without throw */
-  COMPARE = 1 << 0, /**< Set this to compare the numbers */
-  SAVE_AND_LOAD_INI = 1 << 1, /**< Set this to check if saving and constructing
+  NO_THROW_RUN =
+    1 << 0, /**< no comparison, only validate execution without throw */
+  COMPARE_RUN = 1 << 1,       /**< Set this to compare the numbers */
+  SAVE_AND_LOAD_INI = 1 << 2, /**< Set this to check if saving and constructing
                                  a new model works okay (without weights) */
-  USE_V2 = 1 << 2,            /**< use v2 model format */
+  USE_V2 = 1 << 3,            /**< use v2 model format */
+  COMPARE = COMPARE_RUN | NO_THROW_RUN, /**< Set this to comp are the numbers */
 
-  COMPARE_V2 = COMPARE | USE_V2,                 /**< compare v2 */
+  COMPARE_RUN_V2 = COMPARE_RUN | USE_V2,         /**< compare run v2 */
   NO_THROW_RUN_V2 = NO_THROW_RUN | USE_V2,       /**< no throw run with v2 */
+  COMPARE_V2 = COMPARE | USE_V2,                 /**< compare v2 */
   SAVE_AND_LOAD_V2 = SAVE_AND_LOAD_INI | USE_V2, /**< save and load with v2 */
 
   ALL = COMPARE | SAVE_AND_LOAD_INI, /**< Set every option */
@@ -135,7 +138,13 @@ protected:
    *
    * @return bool true if test should be done
    */
-  bool shouldCompare() { return options & (ModelTestOption::COMPARE); }
+  bool shouldCompare() { return options & (ModelTestOption::COMPARE_RUN); }
+  /**
+   * @brief query if compare test should be conducted
+   *
+   * @return bool true if test should be done
+   */
+  bool shouldValidate() { return options & (ModelTestOption::NO_THROW_RUN); }
 
   /**
    * @brief query if saveload ini test should be done
index be49067..c3ca770 100644 (file)
@@ -61,6 +61,28 @@ static std::unique_ptr<NeuralNetwork> makeMolAttention() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makeMolAttentionMasked() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=in4", "input_shape=1:1:1"}},
+    {"input", {"name=in3", "input_shape=1:1:5"}},
+    {"input", {"name=in2", "input_shape=1:4:6"}},
+    {"input", {"name=in1", "input_shape=1:1:6"}},
+    {"mol_attention",
+     {"name=mol", "input_layers=in1,in2,in3,in4", "unit=8", "mol_k=5"}},
+    {"constant_derivative", {"name=loss", "input_layers=mol"}},
+  });
+
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   model, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -68,6 +90,8 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeMolAttention, "mol_attention",
                  ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked",
+                 ModelTestOption::COMPARE_RUN_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);