[test] add test case when a specific layer is non-trainable
[platform/core/ml/nntrainer.git] / test / unittest / models / unittest_models.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4  *
5  * @file unittest_models.cpp
6  * @date 25 Nov 2021
7  * @brief unittest models for v2 version
8  * @see https://github.com/nnstreamer/nntrainer
9  * @author Parichay Kapoor <pk.kapoor@samsung.com>
10  * @bug No known bugs except for NYI items
11  */
12
13 #include <gtest/gtest.h>
14
15 #include <memory>
16
17 #include <ini_wrapper.h>
18 #include <neuralnet.h>
19 #include <nntrainer_test_util.h>
20
21 #include <models_golden_test.h>
22
23 using namespace nntrainer;
24
25 static inline constexpr const int NOT_USED_ = 1;
26
27 static IniSection nn_base("model", "type = NeuralNetwork");
28 static std::string fc_base = "type = Fully_connected";
29 static std::string red_mean_base = "type = reduce_mean";
30 static IniSection sgd_base("optimizer", "Type = sgd");
31 static IniSection constant_loss("loss", "type = constant_derivative");
32 static IniSection act_base("activation", "Type = Activation");
33
34 IniWrapper reduce_mean_last("reduce_mean_last",
35                             {
36                               nn_base + "batch_size=3",
37                               sgd_base + "learning_rate=0.1",
38                               IniSection("fc_1") + fc_base +
39                                 "unit=7 | input_shape=1:1:2",
40                               IniSection("red_mean") + red_mean_base + "axis=3",
41                               constant_loss,
42                             });
43
44 IniWrapper fc_relu_decay(
45   "fc_relu_decay",
46   {nn_base + "Loss=mse | batch_size = 3", sgd_base + "learning_rate = 0.1",
47    IniSection("input") + "type=input" + "input_shape = 1:1:3",
48    IniSection("dense") + fc_base + "unit = 10" + "weight_decay=0.9",
49    IniSection("act") + act_base + "Activation = relu",
50    IniSection("dense_1") + fc_base + "unit = 2" + "bias_decay=0.9",
51    IniSection("act_1") + act_base + "Activation = sigmoid"});
52
53 /**
54  * @brief get function to make model with non-trainable fc layer
55  * @param[in] idx index of the fc layer to be non-trainable
56  * @retval function to make model with non-trainable fc layer
57  */
58 std::function<std::unique_ptr<NeuralNetwork>()>
59 getFuncToMakeNonTrainableFc(int idx) {
60
61   std::string fc1_trainable = (idx == 1) ? "trainable=false" : "trainable=true";
62   std::string fc2_trainable = (idx == 2) ? "trainable=false" : "trainable=true";
63   std::string fc3_trainable = (idx == 3) ? "trainable=false" : "trainable=true";
64
65   return [fc1_trainable, fc2_trainable, fc3_trainable]() {
66     std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
67
68     nn->setProperty({"batch_size=3"});
69
70     auto outer_graph = makeGraph({
71       {"input", {"name=in", "input_shape=1:1:3"}},
72       {"fully_connected",
73        {"name=fc1", "input_layers=in", "unit=10", "activation=relu",
74         fc1_trainable}},
75       {"fully_connected",
76        {"name=fc2", "input_layers=fc1", "unit=10", "activation=relu",
77         fc2_trainable}},
78       {"fully_connected",
79        {"name=fc3", "input_layers=fc2", "unit=2", "activation=sigmoid",
80         fc3_trainable}},
81       {"mse", {"name=loss", "input_layers=fc3"}},
82     });
83
84     for (auto &node : outer_graph) {
85       nn->addLayer(node);
86     }
87
88     nn->setOptimizer(
89       ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
90     nn->setProperty({"input_layers=in", "label_layers=loss"});
91
92     return nn;
93   };
94 }
95
96 static auto makeNonTrainableFcIdx1 = getFuncToMakeNonTrainableFc(1);
97 static auto makeNonTrainableFcIdx2 = getFuncToMakeNonTrainableFc(2);
98 static auto makeNonTrainableFcIdx3 = getFuncToMakeNonTrainableFc(3);
99
100 static std::unique_ptr<NeuralNetwork> makeMolAttention() {
101   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
102   nn->setProperty({"batch_size=3"});
103
104   auto outer_graph = makeGraph({
105     {"input", {"name=in3", "input_shape=1:1:5"}},
106     {"input", {"name=in2", "input_shape=1:4:6"}},
107     {"input", {"name=in1", "input_shape=1:1:6"}},
108     {"mol_attention",
109      {"name=mol", "input_layers=in1,in2,in3", "unit=8", "mol_k=5"}},
110     {"constant_derivative", {"name=loss1", "input_layers=mol(0)"}},
111     {"constant_derivative", {"name=loss2", "input_layers=mol(1)"}},
112   });
113
114   nn->setProperty({"label_layers=loss1,loss2"});
115   for (auto &node : outer_graph) {
116     nn->addLayer(node);
117   }
118
119   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
120   return nn;
121 }
122
123 static std::unique_ptr<NeuralNetwork> makeMolAttentionMasked() {
124   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
125   nn->setProperty({"batch_size=3"});
126
127   auto outer_graph = makeGraph({
128     {"input", {"name=in4", "input_shape=1:1:1"}},
129     {"input", {"name=in3", "input_shape=1:1:5"}},
130     {"input", {"name=in2", "input_shape=1:4:6"}},
131     {"input", {"name=in1", "input_shape=1:1:6"}},
132     {"mol_attention",
133      {"name=mol", "input_layers=in1,in2,in3,in4", "unit=8", "mol_k=5"}},
134     {"constant_derivative", {"name=loss1", "input_layers=mol(0)"}},
135     {"constant_derivative", {"name=loss2", "input_layers=mol(1)"}},
136   });
137
138   nn->setProperty({"label_layers=loss1,loss2"});
139   for (auto &node : outer_graph) {
140     nn->addLayer(node);
141   }
142
143   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
144   return nn;
145 }
146
147 static std::unique_ptr<NeuralNetwork>
148 makeMultiHeadAttention_disable_need_weights() {
149   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
150   nn->setProperty({"batch_size=3"});
151
152   auto outer_graph = makeGraph({
153     {"input", {"name=input_0", "input_shape=1:3:6"}},
154     {"input", {"name=input_1", "input_shape=1:2:6"}},
155     {"input", {"name=input_2", "input_shape=1:2:6"}},
156     {"multi_head_attention",
157      {"name=multi_head_attention", "input_layers=input_0, input_1, input_2",
158       "disable_bias=true", "num_heads=2"}},
159     {"mse", {"name=loss", "input_layers=multi_head_attention"}},
160   });
161
162   for (auto &node : outer_graph) {
163     nn->addLayer(node);
164   }
165
166   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
167   nn->setProperty({"input_layers=input_0, input_1, input_2"});
168
169   return nn;
170 }
171
172 static std::unique_ptr<NeuralNetwork> makeMultiHeadAttention() {
173   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
174   nn->setProperty({"batch_size=3"});
175
176   auto outer_graph = makeGraph({
177     {"input", {"name=input_0", "input_shape=1:3:6"}},
178     {"input", {"name=input_1", "input_shape=1:2:6"}},
179     {"input", {"name=input_2", "input_shape=1:2:6"}},
180     {"multi_head_attention",
181      {"name=multi_head_attention", "input_layers=input_0, input_1, input_2",
182       "num_heads=2", "return_attention_weight=after"}},
183     {"mse", {"name=loss1", "input_layers=multi_head_attention(0)"}},
184     {"mse", {"name=loss2", "input_layers=multi_head_attention(1)"}},
185   });
186
187   for (auto &node : outer_graph) {
188     nn->addLayer(node);
189   }
190
191   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
192   nn->setProperty(
193     {"input_layers=input_0, input_1, input_2", "label_layers=loss1, loss2"});
194
195   return nn;
196 }
197
198 static std::unique_ptr<NeuralNetwork> makeMultiHeadAttention_kdim_vdim() {
199   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
200   nn->setProperty({"batch_size=3"});
201
202   auto outer_graph = makeGraph({
203     {"input", {"name=input_0", "input_shape=1:3:6"}},
204     {"input", {"name=input_1", "input_shape=1:2:4"}},
205     {"input", {"name=input_2", "input_shape=1:2:5"}},
206     {"multi_head_attention",
207      {"name=multi_head_attention", "input_layers=input_0, input_1, input_2",
208       "num_heads=2", "return_attention_weight=after"}},
209     {"mse", {"name=loss1", "input_layers=multi_head_attention(0)"}},
210     {"mse", {"name=loss2", "input_layers=multi_head_attention(1)"}},
211   });
212
213   for (auto &node : outer_graph) {
214     nn->addLayer(node);
215   }
216
217   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
218   nn->setProperty(
219     {"input_layers=input_0, input_1, input_2", "label_layers=loss1, loss2"});
220
221   return nn;
222 }
223
224 static std::unique_ptr<NeuralNetwork> makeMultiHeadAttention_float_attn_mask() {
225   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
226   nn->setProperty({"batch_size=3"});
227
228   auto outer_graph = makeGraph({
229     {"input", {"name=input_0", "input_shape=1:3:6"}},
230     {"input", {"name=input_1", "input_shape=1:2:6"}},
231     {"input", {"name=input_2", "input_shape=1:2:6"}},
232     {"input", {"name=input_3", "input_shape=2:3:2"}},
233     {"multi_head_attention",
234      {"name=multi_head_attention",
235       "input_layers=input_0, input_1, input_2, input_3", "num_heads=2",
236       "return_attention_weight=after"}},
237     {"mse", {"name=loss1", "input_layers=multi_head_attention(0)"}},
238     {"mse", {"name=loss2", "input_layers=multi_head_attention(1)"}},
239   });
240
241   for (auto &node : outer_graph) {
242     nn->addLayer(node);
243   }
244
245   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
246   nn->setProperty({"input_layers=input_0, input_1, input_2, input_3",
247                    "label_layers=loss1, loss2"});
248
249   return nn;
250 }
251
252 static std::unique_ptr<NeuralNetwork> makeMultiHeadAttention_self_attention() {
253   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
254   nn->setProperty({"batch_size=3"});
255
256   auto outer_graph = makeGraph({
257     {"input", {"name=input_0", "input_shape=1:3:6"}},
258     {"multi_head_attention",
259      {"name=multi_head_attention", "input_layers=input_0, input_0, input_0",
260       "num_heads=2", "return_attention_weight=after"}},
261     {"mse", {"name=loss1", "input_layers=multi_head_attention(0)"}},
262     {"mse", {"name=loss2", "input_layers=multi_head_attention(1)"}},
263   });
264
265   for (auto &node : outer_graph) {
266     nn->addLayer(node);
267   }
268
269   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
270   nn->setProperty({"input_layers=input_0", "label_layers=loss1, loss2"});
271
272   return nn;
273 }
274
275 static std::unique_ptr<NeuralNetwork> makePositionalEncoding() {
276   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
277   nn->setProperty({"batch_size=3"});
278
279   auto outer_graph = makeGraph({
280     {"input", {"name=input", "input_shape=5:1:6"}},
281     {"reshape", {"name=reshape", "target_shape=1:5:6"}},
282     {"positional_encoding", {"name=positional_encoding", "max_timestep=7"}},
283     {"multi_head_attention",
284      {"name=multi_head_attention",
285       "input_layers=positional_encoding, positional_encoding, "
286       "positional_encoding",
287       "num_heads=2"}},
288     {"mse", {"name=loss", "input_layers=multi_head_attention(0)"}},
289   });
290
291   for (auto &node : outer_graph) {
292     nn->addLayer(node);
293   }
294
295   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
296   return nn;
297 }
298
299 static std::unique_ptr<NeuralNetwork> makeTransformerEncoderLayer() {
300   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
301   nn->setProperty({"batch_size=3"});
302
303   auto outer_graph = makeGraph({
304     {"input", {"name=input_0", "input_shape=1:5:6"}},
305     {"multi_head_attention",
306      {"name=multi_head_attention", "input_layers=input_0, input_0, input_0",
307       "num_heads=2"}},
308     {"addition", {"name=add1", "input_layers=input_0, multi_head_attention"}},
309     {"layer_normalization", {"name=ln1", "axis=3", "epsilon=1e-5"}},
310     {"fully_connected", {"name=fc1", "unit=7", "activation=relu"}},
311     {"fully_connected", {"name=fc2", "unit=6"}},
312     {"addition", {"name=add2", "input_layers=ln1, fc2"}},
313     {"layer_normalization", {"name=ln2", "axis=3", "epsilon=1e-5"}},
314     {"mse", {"name=loss", "input_layers=ln2"}},
315   });
316
317   for (auto &node : outer_graph) {
318     nn->addLayer(node);
319   }
320
321   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
322   nn->setProperty({"input_layers=input_0", "label_layers=loss"});
323
324   return nn;
325 }
326
327 static std::unique_ptr<NeuralNetwork>
328 makeTransformerEncoderLayer_float_attn_mask() {
329   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
330   nn->setProperty({"batch_size=3"});
331
332   auto outer_graph = makeGraph({
333     {"input", {"name=input_0", "input_shape=1:5:6"}},
334     {"input", {"name=input_1", "input_shape=2:5:5"}},
335     {"multi_head_attention",
336      {"name=multi_head_attention",
337       "input_layers=input_0, input_0, input_0, input_1", "num_heads=2"}},
338     {"addition", {"name=add1", "input_layers=input_0, multi_head_attention"}},
339     {"layer_normalization", {"name=ln1", "axis=3", "epsilon=1e-5"}},
340     {"fully_connected", {"name=fc1", "unit=7", "activation=relu"}},
341     {"fully_connected", {"name=fc2", "unit=6"}},
342     {"addition", {"name=add2", "input_layers=ln1, fc2"}},
343     {"layer_normalization", {"name=ln2", "axis=3", "epsilon=1e-5"}},
344     {"mse", {"name=loss", "input_layers=ln2"}},
345   });
346
347   for (auto &node : outer_graph) {
348     nn->addLayer(node);
349   }
350
351   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
352   nn->setProperty({"input_layers=input_0, input_1", "label_layers=loss"});
353
354   return nn;
355 }
356
357 static std::unique_ptr<NeuralNetwork> makeTransformerDecoderLayer() {
358   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
359   nn->setProperty({"batch_size=3"});
360
361   auto outer_graph = makeGraph({
362     {"input", {"name=input_0", "input_shape=1:5:6"}},
363     {"input", {"name=input_1", "input_shape=1:4:6"}},
364     {"multi_head_attention",
365      {"name=masked_multi_head_attention",
366       "input_layers=input_0, input_0, input_0", "num_heads=2"}},
367     {"addition",
368      {"name=add1", "input_layers=input_0, masked_multi_head_attention"}},
369     {"layer_normalization", {"name=ln1", "axis=3", "epsilon=1e-5"}},
370     {"multi_head_attention",
371      {"name=multi_head_attention", "input_layers=ln1, input_1, input_1",
372       "num_heads=2"}},
373     {"addition", {"name=add2", "input_layers=ln1, multi_head_attention"}},
374     {"layer_normalization", {"name=ln2", "axis=3", "epsilon=1e-5"}},
375     {"fully_connected", {"name=fc1", "unit=7", "activation=relu"}},
376     {"fully_connected", {"name=fc2", "unit=6"}},
377     {"addition", {"name=add3", "input_layers=ln2, fc2"}},
378     {"layer_normalization", {"name=ln3", "axis=3", "epsilon=1e-5"}},
379     {"mse", {"name=loss", "input_layers=ln3"}},
380   });
381
382   for (auto &node : outer_graph) {
383     nn->addLayer(node);
384   }
385
386   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
387   nn->setProperty({"input_layers=input_0, input_1", "label_layers=loss"});
388
389   return nn;
390 }
391
392 static std::unique_ptr<NeuralNetwork>
393 makeTransformerDecoderLayer_float_attn_mask() {
394   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
395   nn->setProperty({"batch_size=3"});
396
397   auto outer_graph = makeGraph({
398     {"input", {"name=input_0", "input_shape=1:5:6"}},
399     {"input", {"name=input_1", "input_shape=1:4:6"}},
400     {"input", {"name=input_2", "input_shape=2:5:5"}},
401     {"input", {"name=input_3", "input_shape=2:5:4"}},
402     {"multi_head_attention",
403      {"name=masked_multi_head_attention",
404       "input_layers=input_0, input_0, input_0, input_2", "num_heads=2"}},
405     {"addition",
406      {"name=add1", "input_layers=input_0, masked_multi_head_attention"}},
407     {"layer_normalization", {"name=ln1", "axis=3", "epsilon=1e-5"}},
408     {"multi_head_attention",
409      {"name=multi_head_attention",
410       "input_layers=ln1, input_1, input_1, input_3", "num_heads=2"}},
411     {"addition", {"name=add2", "input_layers=ln1, multi_head_attention"}},
412     {"layer_normalization", {"name=ln2", "axis=3", "epsilon=1e-5"}},
413     {"fully_connected", {"name=fc1", "unit=7", "activation=relu"}},
414     {"fully_connected", {"name=fc2", "unit=6"}},
415     {"addition", {"name=add3", "input_layers=ln2, fc2"}},
416     {"layer_normalization", {"name=ln3", "axis=3", "epsilon=1e-5"}},
417     {"mse", {"name=loss", "input_layers=ln3"}},
418   });
419
420   for (auto &node : outer_graph) {
421     nn->addLayer(node);
422   }
423
424   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
425   nn->setProperty(
426     {"input_layers=input_0, input_1, input_2, input_3", "label_layers=loss"});
427
428   return nn;
429 }
430
431 static std::unique_ptr<NeuralNetwork> makeTransformer_single_layer() {
432   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
433   nn->setProperty({"batch_size=3"});
434
435   auto decoder_input = makeGraph({
436     {"input", {"name=decoder_input", "input_shape=1:4:6"}},
437   });
438
439   for (auto &node : decoder_input) {
440     nn->addLayer(node);
441   }
442
443   auto decoder_layer = makeGraph({
444     {"multiout", {"name=decoder_layer1/multi_out1"}},
445     {"multi_head_attention",
446      {"name=decoder_layer1/masked_multi_head_attention",
447       "input_layers=decoder_layer1/multi_out1(0), "
448       "decoder_layer1/multi_out1(1), decoder_layer1/multi_out1(2)",
449       "num_heads=2"}},
450     {"addition",
451      {"name=decoder_layer1/add1",
452       "input_layers=decoder_layer1/multi_out1(3), "
453       "decoder_layer1/masked_multi_head_attention"}},
454     {"layer_normalization",
455      {"name=decoder_layer1/ln1", "axis=3", "epsilon=1e-5"}},
456     {"multiout", {"name=decoder_layer1/multi_out2"}},
457     {"multi_head_attention",
458      {"name=decoder_layer1/multi_head_attention",
459       "input_layers=decoder_layer1/multi_out2(0), encoder_output(0), "
460       "encoder_output(1)",
461       "num_heads=2"}},
462     {"addition",
463      {"name=decoder_layer1/add2", "input_layers=decoder_layer1/multi_out2(1), "
464                                   "decoder_layer1/multi_head_attention"}},
465     {"layer_normalization",
466      {"name=decoder_layer1/ln2", "axis=3", "epsilon=1e-5"}},
467     {"multiout", {"name=decoder_layer1/multi_out3"}},
468     {"fully_connected",
469      {"name=decoder_layer1/fc1", "input_layers=decoder_layer1/multi_out3(0)",
470       "unit=7", "activation=relu"}},
471     {"fully_connected", {"name=decoder_layer1/fc2", "unit=6"}},
472     {"addition",
473      {"name=add3",
474       "input_layers=decoder_layer1/multi_out3(1), decoder_layer1/fc2"}},
475     {"layer_normalization",
476      {"name=decoder_layer1/ln3", "axis=3", "epsilon=1e-5"}},
477   });
478
479   for (auto &node : decoder_layer) {
480     nn->addLayer(node);
481   }
482
483   auto decoder_output = makeGraph({
484     {"layer_normalization",
485      {"name=decoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
486     {"mse", {"name=loss"}},
487   });
488
489   for (auto &node : decoder_output) {
490     nn->addLayer(node);
491   }
492
493   auto encoder_input = makeGraph({
494     {"input", {"name=encoder_input", "input_shape=1:5:6"}},
495   });
496
497   for (auto &node : encoder_input) {
498     nn->addLayer(node);
499   }
500
501   auto encoder = makeGraph({
502     {"multiout", {"name=encoder_layer1/multi_out1"}},
503     {"multi_head_attention",
504      {"name=encoder_layer1/multi_head_attention",
505       "input_layers=encoder_layer1/multi_out1(0), "
506       "encoder_layer1/multi_out1(1), encoder_layer1/multi_out1(2)",
507       "num_heads=2"}},
508     {"addition",
509      {"name=encoder_layer1/add1", "input_layers=encoder_layer1/multi_out1(3), "
510                                   "encoder_layer1/multi_head_attention"}},
511     {"layer_normalization",
512      {"name=encoder_layer1/ln1", "axis=3", "epsilon=1e-5"}},
513     {"multiout", {"name=encoder_layer1/multi_out2"}},
514     {"fully_connected",
515      {"name=encoder_layer1/fc1", "input_layers=encoder_layer1/multi_out2(0)",
516       "unit=7", "activation=relu"}},
517     {"fully_connected", {"name=encoder_layer1/fc2", "unit=6"}},
518     {"addition",
519      {"name=add2",
520       "input_layers=encoder_layer1/multi_out2(1), encoder_layer1/fc2"}},
521     {"layer_normalization", {"name=ln2", "axis=3", "epsilon=1e-5"}},
522   });
523
524   for (auto &node : encoder) {
525     nn->addLayer(node);
526   }
527
528   auto encoder_output = makeGraph({
529     {"layer_normalization",
530      {"name=encoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
531     {"multiout", {"name=encoder_output"}},
532   });
533
534   for (auto &node : encoder_output) {
535     nn->addLayer(node);
536   }
537
538   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
539   nn->setProperty(
540     {"input_layers=encoder_input, decoder_input", "label_layers=loss"});
541
542   return nn;
543 }
544
545 static std::unique_ptr<NeuralNetwork> makeTransformer_stack_layer() {
546   const unsigned int num_encoder_layer = 2;
547   const unsigned int num_decoder_layer = 2;
548   const unsigned int batch_size = 3;
549   const unsigned int num_heads = 2;
550   const unsigned int encoder_timestep = 5;
551   const unsigned int decoder_timestep = 4;
552   const unsigned int model_dim = 6;
553   const unsigned int fc_unit = 7;
554
555   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
556   nn->setProperty({"batch_size=" + std::to_string(batch_size)});
557
558   auto decoder_input = makeGraph({
559     {"input",
560      {"name=decoder_input",
561       "input_shape=1:" + std::to_string(decoder_timestep) + ":" +
562         std::to_string(model_dim)}},
563   });
564
565   for (auto &node : decoder_input) {
566     nn->addLayer(node);
567   }
568
569   for (unsigned int i = 0; i < num_decoder_layer; ++i) {
570     auto decoder_layer = makeGraph({
571       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out1"}},
572       {"multi_head_attention",
573        {"name=decoder_layer" + std::to_string(i) +
574           "/masked_multi_head_attention",
575         "input_layers=decoder_layer" + std::to_string(i) +
576           "/multi_out1(0), decoder_layer" + std::to_string(i) +
577           "/multi_out1(1), decoder_layer" + std::to_string(i) +
578           "/multi_out1(2)",
579         "num_heads=" + std::to_string(num_heads)}},
580       {"addition",
581        {"name=decoder_layer" + std::to_string(i) + "/add1",
582         "input_layers=decoder_layer" + std::to_string(i) +
583           "/multi_out1(3), decoder_layer" + std::to_string(i) +
584           "/masked_multi_head_attention"}},
585       {"layer_normalization",
586        {"name=decoder_layer" + std::to_string(i) + "/ln1", "axis=3",
587         "epsilon=1e-5"}},
588       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out2"}},
589       {"multi_head_attention",
590        {"name=decoder_layer" + std::to_string(i) + "/multi_head_attention",
591         "input_layers=decoder_layer" + std::to_string(i) +
592           "/multi_out2(0), encoder_output(0), encoder_output(1)",
593         "num_heads=" + std::to_string(num_heads)}},
594       {"addition",
595        {"name=decoder_layer" + std::to_string(i) + "/add2",
596         "input_layers=decoder_layer" + std::to_string(i) +
597           "/multi_out2(1), decoder_layer" + std::to_string(i) +
598           "/multi_head_attention"}},
599       {"layer_normalization",
600        {"name=decoder_layer" + std::to_string(i) + "/ln2", "axis=3",
601         "epsilon=1e-5"}},
602       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out3"}},
603       {"fully_connected",
604        {"name=decoder_layer" + std::to_string(i) + "/fc1",
605         "input_layers=decoder_layer" + std::to_string(i) + "/multi_out3(0)",
606         "unit=" + std::to_string(fc_unit), "activation=relu"}},
607       {"fully_connected",
608        {"name=decoder_layer" + std::to_string(i) + "/fc2",
609         "unit=" + std::to_string(model_dim)}},
610       {"addition",
611        {"name=decoder_layer" + std::to_string(i) + "/add3",
612         "input_layers=decoder_layer" + std::to_string(i) +
613           "/multi_out3(1), decoder_layer" + std::to_string(i) + "/fc2"}},
614       {"layer_normalization",
615        {"name=decoder_layer" + std::to_string(i) + "/ln3", "axis=3",
616         "epsilon=1e-5"}},
617     });
618
619     for (auto &node : decoder_layer) {
620       nn->addLayer(node);
621     }
622   }
623
624   auto decoder_output = makeGraph({
625     {"layer_normalization",
626      {"name=decoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
627     {"mse", {"name=loss"}},
628   });
629
630   for (auto &node : decoder_output) {
631     nn->addLayer(node);
632   }
633
634   auto encoder_input = makeGraph({
635     {"input",
636      {"name=encoder_input",
637       "input_shape=1:" + std::to_string(encoder_timestep) + ":" +
638         std::to_string(model_dim)}},
639   });
640
641   for (auto &node : encoder_input) {
642     nn->addLayer(node);
643   }
644
645   for (unsigned int i = 0; i < num_encoder_layer; ++i) {
646     auto encoder_layer = makeGraph({
647       {"multiout", {"name=encoder_layer" + std::to_string(i) + "/multi_out1"}},
648       {"multi_head_attention",
649        {"name=encoder_layer" + std::to_string(i) + "/multi_head_attention",
650         "input_layers=encoder_layer" + std::to_string(i) +
651           "/multi_out1(0), encoder_layer" + std::to_string(i) +
652           "/multi_out1(1), encoder_layer" + std::to_string(i) +
653           "/multi_out1(2)",
654         "num_heads=" + std::to_string(num_heads)}},
655       {"addition",
656        {"name=encoder_layer" + std::to_string(i) + "/add1",
657         "input_layers=encoder_layer" + std::to_string(i) +
658           "/multi_out1(3), encoder_layer" + std::to_string(i) +
659           "/multi_head_attention"}},
660       {"layer_normalization",
661        {"name=encoder_layer" + std::to_string(i) + "/ln1", "axis=3",
662         "epsilon=1e-5"}},
663       {"multiout", {"name=encoder_layer" + std::to_string(i) + "/multi_out2"}},
664       {"fully_connected",
665        {"name=encoder_layer" + std::to_string(i) + "/fc1",
666         "input_layers=encoder_layer" + std::to_string(i) + "/multi_out2(0)",
667         "unit=" + std::to_string(fc_unit), "activation=relu"}},
668       {"fully_connected",
669        {"name=encoder_layer" + std::to_string(i) + "/fc2",
670         "unit=" + std::to_string(model_dim)}},
671       {"addition",
672        {"name=encoder_layer" + std::to_string(i) + "/add2",
673         "input_layers=encoder_layer" + std::to_string(i) +
674           "/multi_out2(1), encoder_layer" + std::to_string(i) + "/fc2"}},
675       {"layer_normalization",
676        {"name=encoder_layer" + std::to_string(i) + "/ln2", "axis=3",
677         "epsilon=1e-5"}},
678     });
679
680     for (auto &node : encoder_layer) {
681       nn->addLayer(node);
682     }
683   }
684
685   auto encoder_output = makeGraph({
686     {"layer_normalization",
687      {"name=encoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
688     {"multiout", {"name=encoder_output"}},
689   });
690
691   for (auto &node : encoder_output) {
692     nn->addLayer(node);
693   }
694
695   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
696   nn->setProperty(
697     {"input_layers=encoder_input, decoder_input", "label_layers=loss"});
698
699   return nn;
700 }
701
702 static std::unique_ptr<NeuralNetwork> makeTransformer_float_attn_mask() {
703   const unsigned int num_encoder_layer = 2;
704   const unsigned int num_decoder_layer = 2;
705   const unsigned int batch_size = 3;
706   const unsigned int num_heads = 2;
707   const unsigned int encoder_timestep = 5;
708   const unsigned int decoder_timestep = 4;
709   const unsigned int model_dim = 6;
710   const unsigned int fc_unit = 7;
711
712   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
713   nn->setProperty({"batch_size=" + std::to_string(batch_size)});
714
715   auto mask_input = makeGraph({
716     {"input",
717      {"name=memory_mask", "input_shape=" + std::to_string(num_heads) + ":" +
718                             std::to_string(decoder_timestep) + ":" +
719                             std::to_string(encoder_timestep)}},
720     {"input",
721      {"name=tgt_mask", "input_shape=" + std::to_string(num_heads) + ":" +
722                          std::to_string(decoder_timestep) + ":" +
723                          std::to_string(decoder_timestep)}},
724     {"input",
725      {"name=src_mask", "input_shape=" + std::to_string(num_heads) + ":" +
726                          std::to_string(encoder_timestep) + ":" +
727                          std::to_string(encoder_timestep)}},
728   });
729
730   for (auto &node : mask_input) {
731     nn->addLayer(node);
732   }
733
734   auto decoder_input = makeGraph({
735     {"input",
736      {"name=decoder_input",
737       "input_shape=1:" + std::to_string(decoder_timestep) + ":" +
738         std::to_string(model_dim)}},
739   });
740
741   for (auto &node : decoder_input) {
742     nn->addLayer(node);
743   }
744
745   for (unsigned int i = 0; i < num_decoder_layer; ++i) {
746     auto decoder_layer = makeGraph({
747       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out1"}},
748       {"multi_head_attention",
749        {"name=decoder_layer" + std::to_string(i) +
750           "/masked_multi_head_attention",
751         "input_layers=decoder_layer" + std::to_string(i) +
752           "/multi_out1(0), decoder_layer" + std::to_string(i) +
753           "/multi_out1(1), decoder_layer" + std::to_string(i) +
754           "/multi_out1(2), tgt_mask",
755         "num_heads=" + std::to_string(num_heads)}},
756       {"addition",
757        {"name=decoder_layer" + std::to_string(i) + "/add1",
758         "input_layers=decoder_layer" + std::to_string(i) +
759           "/multi_out1(3), decoder_layer" + std::to_string(i) +
760           "/masked_multi_head_attention"}},
761       {"layer_normalization",
762        {"name=decoder_layer" + std::to_string(i) + "/ln1", "axis=3",
763         "epsilon=1e-5"}},
764       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out2"}},
765       {"multi_head_attention",
766        {"name=decoder_layer" + std::to_string(i) + "/multi_head_attention",
767         "input_layers=decoder_layer" + std::to_string(i) +
768           "/multi_out2(0), encoder_output(0), encoder_output(1), memory_mask",
769         "num_heads=" + std::to_string(num_heads)}},
770       {"addition",
771        {"name=decoder_layer" + std::to_string(i) + "/add2",
772         "input_layers=decoder_layer" + std::to_string(i) +
773           "/multi_out2(1), decoder_layer" + std::to_string(i) +
774           "/multi_head_attention"}},
775       {"layer_normalization",
776        {"name=decoder_layer" + std::to_string(i) + "/ln2", "axis=3",
777         "epsilon=1e-5"}},
778       {"multiout", {"name=decoder_layer" + std::to_string(i) + "/multi_out3"}},
779       {"fully_connected",
780        {"name=decoder_layer" + std::to_string(i) + "/fc1",
781         "input_layers=decoder_layer" + std::to_string(i) + "/multi_out3(0)",
782         "unit=" + std::to_string(fc_unit), "activation=relu"}},
783       {"fully_connected",
784        {"name=decoder_layer" + std::to_string(i) + "/fc2",
785         "unit=" + std::to_string(model_dim)}},
786       {"addition",
787        {"name=decoder_layer" + std::to_string(i) + "/add3",
788         "input_layers=decoder_layer" + std::to_string(i) +
789           "/multi_out3(1), decoder_layer" + std::to_string(i) + "/fc2"}},
790       {"layer_normalization",
791        {"name=decoder_layer" + std::to_string(i) + "/ln3", "axis=3",
792         "epsilon=1e-5"}},
793     });
794
795     for (auto &node : decoder_layer) {
796       nn->addLayer(node);
797     }
798   }
799
800   auto decoder_output = makeGraph({
801     {"layer_normalization",
802      {"name=decoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
803     {"mse", {"name=loss"}},
804   });
805
806   for (auto &node : decoder_output) {
807     nn->addLayer(node);
808   }
809
810   auto encoder_input = makeGraph({
811     {"input", {"name=encoder_input", "input_shape=1:5:6"}},
812   });
813
814   for (auto &node : encoder_input) {
815     nn->addLayer(node);
816   }
817
818   for (unsigned int i = 0; i < num_encoder_layer; ++i) {
819     auto encoder_layer = makeGraph({
820       {"multiout", {"name=encoder_layer" + std::to_string(i) + "/multi_out1"}},
821       {"multi_head_attention",
822        {"name=encoder_layer" + std::to_string(i) + "/multi_head_attention",
823         "input_layers=encoder_layer" + std::to_string(i) +
824           "/multi_out1(0), encoder_layer" + std::to_string(i) +
825           "/multi_out1(1), encoder_layer" + std::to_string(i) +
826           "/multi_out1(2), src_mask",
827         "num_heads=" + std::to_string(num_heads)}},
828       {"addition",
829        {"name=encoder_layer" + std::to_string(i) + "/add1",
830         "input_layers=encoder_layer" + std::to_string(i) +
831           "/multi_out1(3), encoder_layer" + std::to_string(i) +
832           "/multi_head_attention"}},
833       {"layer_normalization",
834        {"name=encoder_layer" + std::to_string(i) + "/ln1", "axis=3",
835         "epsilon=1e-5"}},
836       {"multiout", {"name=encoder_layer" + std::to_string(i) + "/multi_out2"}},
837       {"fully_connected",
838        {"name=encoder_layer" + std::to_string(i) + "/fc1",
839         "input_layers=encoder_layer" + std::to_string(i) + "/multi_out2(0)",
840         "unit==" + std::to_string(fc_unit), "activation=relu"}},
841       {"fully_connected",
842        {"name=encoder_layer" + std::to_string(i) + "/fc2",
843         "unit=" + std::to_string(model_dim)}},
844       {"addition",
845        {"name=encoder_layer" + std::to_string(i) + "/add2",
846         "input_layers=encoder_layer" + std::to_string(i) +
847           "/multi_out2(1), encoder_layer" + std::to_string(i) + "/fc2"}},
848       {"layer_normalization",
849        {"name=encoder_layer" + std::to_string(i) + "/ln2", "axis=3",
850         "epsilon=1e-5"}},
851     });
852
853     for (auto &node : encoder_layer) {
854       nn->addLayer(node);
855     }
856   }
857
858   auto encoder_output = makeGraph({
859     {"layer_normalization",
860      {"name=encoder_layer_normalization", "axis=3", "epsilon=1e-5"}},
861     {"multiout", {"name=encoder_output"}},
862   });
863
864   for (auto &node : encoder_output) {
865     nn->addLayer(node);
866   }
867
868   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
869   nn->setProperty({"input_layers=encoder_input, decoder_input, src_mask, "
870                    "tgt_mask, memory_mask",
871                    "label_layers=loss"});
872
873   return nn;
874 }
875
876 GTEST_PARAMETER_TEST(
877   model, nntrainerModelTest,
878   ::testing::ValuesIn({
879     mkModelIniTc(reduce_mean_last, DIM_UNUSED, NOT_USED_,
880                  ModelTestOption::COMPARE_V2),
881     mkModelTc_V2(makeMolAttention, "mol_attention",
882                  ModelTestOption::COMPARE_V2),
883     mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked",
884                  ModelTestOption::COMPARE_RUN_V2),
885     mkModelTc_V2(makeMultiHeadAttention_disable_need_weights,
886                  "multi_head_attention_disable_need_weights",
887                  ModelTestOption::ALL_V2),
888     mkModelTc_V2(makeMultiHeadAttention, "multi_head_attention",
889                  ModelTestOption::ALL_V2),
890     mkModelTc_V2(makeMultiHeadAttention_kdim_vdim,
891                  "multi_head_attention_kdim_vdim", ModelTestOption::ALL_V2),
892     mkModelTc_V2(makeMultiHeadAttention_float_attn_mask,
893                  "multi_head_attention_float_attn_mask",
894                  ModelTestOption::ALL_V2),
895     /** @todo:change model if bool type tensor is supported */
896     mkModelTc_V2(makeMultiHeadAttention_float_attn_mask,
897                  "multi_head_attention_pseudo_bool_attn_mask",
898                  ModelTestOption::ALL_V2),
899     mkModelTc_V2(makeMultiHeadAttention_self_attention,
900                  "multi_head_attention_self_attention",
901                  ModelTestOption::ALL_V2),
902     mkModelTc_V2(makePositionalEncoding, "positional_encoding",
903                  ModelTestOption::ALL_V2),
904     mkModelTc_V2(makeTransformerEncoderLayer, "transformer_encoder_layer",
905                  ModelTestOption::ALL_V2),
906     mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask,
907                  "transformer_encoder_layer_float_attn_mask",
908                  ModelTestOption::ALL_V2),
909     /** @todo:change model if bool type tensor is supported */
910     mkModelTc_V2(makeTransformerEncoderLayer_float_attn_mask,
911                  "transformer_encoder_layer_pseudo_bool_attn_mask",
912                  ModelTestOption::ALL_V2),
913     mkModelTc_V2(makeTransformerDecoderLayer, "transformer_decoder_layer",
914                  ModelTestOption::ALL_V2),
915     mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask,
916                  "transformer_decoder_layer_float_attn_mask",
917                  ModelTestOption::ALL_V2),
918     /** @todo:change model if bool type tensor is supported */
919     mkModelTc_V2(makeTransformerDecoderLayer_float_attn_mask,
920                  "transformer_decoder_layer_pseudo_bool_attn_mask",
921                  ModelTestOption::ALL_V2),
922     mkModelTc_V2(makeTransformer_single_layer, "transformer_single",
923                  ModelTestOption::ALL_V2),
924     mkModelTc_V2(makeTransformer_stack_layer, "transformer_stack",
925                  ModelTestOption::ALL_V2),
926     mkModelTc_V2(makeTransformer_float_attn_mask, "transformer_float_attn_mask",
927                  ModelTestOption::ALL_V2),
928     mkModelTc_V2(makeTransformer_float_attn_mask,
929                  "transformer_pseudo_bool_attn_mask", ModelTestOption::ALL_V2),
930     mkModelIniTc(fc_relu_decay, DIM_UNUSED, NOT_USED_,
931                  ModelTestOption::COMPARE_V2),
932     mkModelTc_V2(makeNonTrainableFcIdx1, "non_trainable_fc_idx1",
933                  ModelTestOption::ALL_V2),
934     mkModelTc_V2(makeNonTrainableFcIdx2, "non_trainable_fc_idx2",
935                  ModelTestOption::ALL_V2),
936     mkModelTc_V2(makeNonTrainableFcIdx3, "non_trainable_fc_idx3",
937                  ModelTestOption::ALL_V2),
938   }),
939   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
940     return std::get<1>(info.param);
941   });