[layer] Add MoL layer unittest
[platform/core/ml/nntrainer.git] / test / unittest / models / unittest_models.cpp
index a2384da..be49067 100644 (file)
@@ -40,11 +40,34 @@ IniWrapper reduce_mean_last("reduce_mean_last",
                               constant_loss,
                             });
 
+static std::unique_ptr<NeuralNetwork> makeMolAttention() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=3"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=in3", "input_shape=1:1:5"}},
+    {"input", {"name=in2", "input_shape=1:4:6"}},
+    {"input", {"name=in1", "input_shape=1:1:6"}},
+    {"mol_attention",
+     {"name=mol", "input_layers=in1,in2,in3", "unit=8", "mol_k=5"}},
+    {"constant_derivative", {"name=loss", "input_layers=mol"}},
+  });
+
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 INSTANTIATE_TEST_CASE_P(
   model, nntrainerModelTest,
   ::testing::ValuesIn({
     mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_,
                  ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeMolAttention, "mol_attention",
+                 ModelTestOption::COMPARE_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
     return std::get<1>(info.param);