#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
##
# Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
#
# @file genModelTests_v2.py
# @date 25 November 2021
# @brief Generate model test cases for nntrainer
# @author Parichay Kapoor <pk.kapoor@samsung.com>

import math

import torch

from recorder_v2 import record_v2, inspect_file, _rand_like

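# class for test reduce_mean: a fully connected layer followed by mean reduction along the last dimension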
class ReduceMeanLast(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(2, 7)
        self.loss = torch.nn.Identity()

    def forward(self, inputs, labels):
        out = self.fc(inputs[0])
        out = torch.mean(out, dim=-1)
        loss = self.loss(torch.sum(out))
        return out, loss

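# class for test mixture of logistics (MoL) attention; returns the attention output together with the updated kappa state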
class MolAttention(torch.nn.Module):
    def __init__(self, query_size):
        super(MolAttention, self).__init__()
        self.query_size = query_size
        self.units = 8
        self.K = 5  # number of mixtures
        self.dense1 = torch.nn.Linear(self.query_size, self.units)
        self.dense2 = torch.nn.Linear(self.units, 3 * self.K, bias=False)
        self.loss = torch.nn.Identity()

    def forward(self, inputs, labels):
        if len(inputs) == 4:
            query, values, attention_state, mask_len = inputs
        else:
            query, values, attention_state = inputs
            mask_len = None
        batch_size, timesteps, _ = values.size()

        dense1_out = torch.tanh(self.dense1(query.unsqueeze(1)))
        mlp_proj_out = self.dense2(dense1_out)
        kappa, beta, alpha = mlp_proj_out.chunk(chunks=3, dim=2)

        kappa = torch.exp(kappa)
        beta = torch.exp(beta)
        alpha = torch.softmax(alpha, dim=2)
        kappa = kappa + attention_state

        # Timesteps const array
        j = torch.arange(start=1, end=timesteps + 1).view(1, -1, 1).expand(batch_size, -1, self.K)

        integrals_left = torch.sigmoid(torch.div(j + 0.5 - kappa, beta + 1e-8))
        integrals_right = torch.sigmoid(torch.div(j - 0.5 - kappa, beta + 1e-8))
        integrals = alpha * (integrals_left - integrals_right)
        scores = torch.sum(integrals, dim=2)

        if mask_len is not None:
            max_len = max(int(mask_len.max()), scores.shape[1])
            mask = torch.arange(0, max_len)\
                    .type_as(mask_len)\
                    .unsqueeze(0).expand(mask_len.numel(), max_len)\
                    .lt(mask_len.unsqueeze(1))
            scores.masked_fill_(torch.logical_not(mask), 0.)

        output = torch.matmul(scores.unsqueeze(1), values).squeeze(dim=1)

        loss = self.loss(torch.sum(output)) + self.loss(torch.sum(kappa))

        return (output, kappa), loss

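# class for test multi-head attention; wraps torch.nn.MultiheadAttention with an optional attention mask input and optional attention weight output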
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, need_weights=True, provide_attention_mask=False):
        super(MultiHeadAttention, self).__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first=True)
        self.loss = torch.nn.MSELoss()
        self.need_weights = need_weights
        self.provide_attention_mask = provide_attention_mask

    def forward(self, inputs, labels):
        inputs, attn_mask = (inputs[:-1], inputs[-1]) if self.provide_attention_mask else (inputs, None)
        query, *left = inputs
        if len(left) == 0:
            key = value = query
        else:
            key, value = left

        output, attention_weight = self.multi_head_attention(query, key, value, need_weights=self.need_weights, attn_mask=attn_mask)
        loss = self.loss(output, labels[0])
        if attention_weight is not None:
            output = [output, attention_weight]

        return output, loss

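    # generate random query/key/value inputs, an optional attention mask, and random labels for recording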
    @staticmethod
    def input_label_reader(input_dims, label_dims, input_dtype):
        query_dim, key_dim, value_dim, *left_dim = input_dims
        query_dtype, key_dtype, value_dtype, *left_dtype = input_dtype
        assert query_dtype == key_dtype == value_dtype
        if left_dim != []:
            mask_dim = left_dim[0]
            mask_dtype = left_dtype[0]
            if mask_dtype == bool:
                # Since nntrainer does not support bool type tensor yet, convert mask to float type
                # todo: return bool type mask tensor
                mask = torch.randn(mask_dim) > 0.5
                new_attn_mask = torch.zeros_like(mask, dtype=torch.float32)
                new_attn_mask.masked_fill_(mask, float("-inf"))
                mask = [new_attn_mask]
            elif mask_dtype == int:
                mask = [torch.randint(0, 1, mask_dim, dtype=torch.int32)]
            else:
                mask = _rand_like([mask_dim], -1e9, mask_dtype)
        else:
            mask = []
        inputs = _rand_like([query_dim, key_dim, value_dim], dtype=input_dtype if input_dtype is not None else float) + mask
        labels = _rand_like(label_dims, dtype=float)
        return inputs, labels

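# class for test positional encoding added to the input, followed by multi-head self attention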
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model: int, max_len):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.multi_head_attention = torch.nn.MultiheadAttention(d_model, 2, batch_first=True)
        self.loss = torch.nn.MSELoss()

    def forward(self, inputs, labels):
        output = inputs[0]
        output += self.pe[:, :output.size(1), :]
        output = self.multi_head_attention(output, output, output)
        loss = self.loss(output[0], labels[0])
        return output, loss

# class for test transformer encoder layer
class TransformerEncoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, provide_attention_mask=False):
        super(TransformerEncoderLayer, self).__init__()
        self.encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True)
        self.loss = torch.nn.MSELoss()
        # indicates whether an attention mask will be given
        self.provide_attention_mask = provide_attention_mask

    def forward(self, inputs, labels):
        inputs, attn_mask = (inputs[0], inputs[-1]) if self.provide_attention_mask else (inputs[0], None)
        output = self.encoder_layer(inputs, attn_mask)

        loss = self.loss(output, labels[0])

        return output, loss

    @staticmethod
    def input_label_reader(input_dims, label_dims, input_dtypes):
        input_dim, *left_dim = input_dims
        input_dtype, *left_dtype = input_dtypes
        if left_dim != []:
            mask_dim = left_dim[0]
            mask_dtype = left_dtype[0]
            if mask_dtype == bool:
                # Since nntrainer does not support bool type tensor yet, convert mask to float type
                # todo: return bool type mask tensor
                mask = torch.randn(mask_dim) > 0.5
                new_attn_mask = torch.zeros_like(mask, dtype=torch.float32)
                new_attn_mask.masked_fill_(mask, float("-inf"))
                mask = [new_attn_mask]
            elif mask_dtype == int:
                mask = [torch.randint(0, 1, mask_dim, dtype=torch.int32)]
            else:
                mask = _rand_like([mask_dim], -1e9, mask_dtype)
        else:
            mask = []
        inputs = _rand_like([input_dim], dtype=input_dtype if input_dtype is not None else float) + mask
        labels = _rand_like(label_dims, dtype=float)
        return inputs, labels

# class for test transformer decoder layer
class TransformerDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, provide_attention_mask=False):
        super(TransformerDecoderLayer, self).__init__()
        self.decoder_layer = torch.nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout=0.0, batch_first=True)
        self.loss = torch.nn.MSELoss()
        # indicates whether attention masks will be given
        self.provide_attention_mask = provide_attention_mask

    def forward(self, inputs, labels):
        tgt, memory, tgt_mask, memory_mask = (inputs[0], inputs[1], inputs[-2], inputs[-1]) if self.provide_attention_mask else (inputs[0], inputs[1], None, None)
        output = self.decoder_layer(tgt, memory, tgt_mask, memory_mask)

        loss = self.loss(output, labels[0])

        return output, loss

    @staticmethod
    def input_label_reader(input_dims, label_dims, input_dtypes):
        tgt_dim, memory_dim, *mask_dims = input_dims
        tgt_dtype, memory_dtype, *mask_dtypes = input_dtypes
        if mask_dims != []:
            if mask_dtypes[0] == bool:
                # Since nntrainer does not support bool type tensor yet, convert mask to float type
                # todo: return bool type mask tensor
                masks = [torch.randn(dim) > 0.5 for dim in mask_dims]
                new_attn_masks = [torch.zeros_like(mask, dtype=torch.float32) for mask in masks]
                for mask, new_attn_mask in zip(masks, new_attn_masks):
                    new_attn_mask.masked_fill_(mask, float("-inf"))
                masks = new_attn_masks
            elif mask_dtypes[0] == int:
                masks = [torch.randint(0, 1, mask_dim, dtype=torch.int32) for mask_dim in mask_dims]
            else:
                masks = _rand_like(mask_dims, -1e9, mask_dtypes)
        else:
            masks = []
        inputs = _rand_like([tgt_dim, memory_dim], dtype=[tgt_dtype, memory_dtype] if tgt_dtype is not None and memory_dtype is not None else float) + masks
        labels = _rand_like(label_dims, dtype=float)
        return inputs, labels

# class for test transformer
# the transformer in this class consists of a transformer encoder and a transformer decoder
class Transformer(torch.nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, provide_attention_mask=False):
        super(Transformer, self).__init__()
        self.transformer = torch.nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout=0.0, batch_first=True)
        self.loss = torch.nn.MSELoss()
        # indicates whether attention masks will be given
        self.provide_attention_mask = provide_attention_mask

    def forward(self, inputs, labels):
        src, tgt, src_mask, tgt_mask, memory_mask = (inputs[0], inputs[1], inputs[-3], inputs[-2], inputs[-1]) if self.provide_attention_mask else (inputs[0], inputs[1], None, None, None)
        output = self.transformer(src, tgt, src_mask, tgt_mask, memory_mask)

        loss = self.loss(output, labels[0])

        return output, loss

    @staticmethod
    def input_label_reader(input_dims, label_dims, input_dtypes):
        src_dim, tgt_dim, *mask_dims = input_dims
        src_dtype, tgt_dtype, *mask_dtypes = input_dtypes
        if mask_dims != []:
            if mask_dtypes[0] == bool:
                # Since nntrainer does not support bool type tensor yet, convert mask to float type
                # todo: return bool type mask tensor
                masks = [torch.randn(dim) > 0.5 for dim in mask_dims]
                new_attn_masks = [torch.zeros_like(mask, dtype=torch.float32) for mask in masks]
                for mask, new_attn_mask in zip(masks, new_attn_masks):
                    new_attn_mask.masked_fill_(mask, float("-inf"))
                masks = new_attn_masks
            elif mask_dtypes[0] == int:
                masks = [torch.randint(0, 1, mask_dim, dtype=torch.int32) for mask_dim in mask_dims]
            else:
                masks = _rand_like(mask_dims, -1e9, mask_dtypes)
        else:
            masks = []
        inputs = _rand_like([src_dim, tgt_dim], dtype=[src_dtype, tgt_dtype] if src_dtype is not None and tgt_dtype is not None else float) + masks
        labels = _rand_like(label_dims, dtype=float)
        return inputs, labels

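# class for test fc layers with relu/sigmoid activations and optional weight decay on selected parameters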
class FCRelu(torch.nn.Module):
    def __init__(self, decay=False):
        super().__init__()
        self.fc = torch.nn.Linear(3, 10)
        self.fc1 = torch.nn.Linear(10, 2)
        self.loss = torch.nn.MSELoss()
        self.decay = decay

    def forward(self, inputs, labels):
        out = torch.relu(self.fc(inputs[0]))
        out = torch.sigmoid(self.fc1(out))
        loss = self.loss(out, labels[0])
        return out, loss

    def getOptimizer(self):
        if not self.decay:
            return torch.optim.SGD(self.parameters(), lr=0.1)
        else:
            decay_params = []
            non_decay_params = []
            for name, params in self.named_parameters():
                if name == 'fc.weight' or name == 'fc1.bias':
                    decay_params.append(params)
                else:
                    non_decay_params.append(params)
            return torch.optim.SGD([
                {'params': non_decay_params},
                {'params': decay_params, 'weight_decay': 0.9}], lr=0.1)

# class for test non-trainable fc layer
class NonTrainableFC(torch.nn.Module):
    def __init__(self, idx):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 10)
        self.fc2 = torch.nn.Linear(10, 10)
        self.fc3 = torch.nn.Linear(10, 2)
        self.loss = torch.nn.MSELoss()
        # determine which layer to set to non-trainable
        if idx == 1:
            for param in self.fc1.parameters():
                param.requires_grad = False
        elif idx == 2:
            for param in self.fc2.parameters():
                param.requires_grad = False

    def forward(self, inputs, labels):
        out = torch.relu(self.fc1(inputs[0]))
        out = torch.relu(self.fc2(out))
        out = torch.sigmoid(self.fc3(out))
        loss = self.loss(out, labels[0])
        return out, loss

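# Each record_v2 call below runs the given model for the specified number of iterations with
# randomly generated inputs and labels (or with the provided input_label_reader) and records
# the results into a "<name>.nnmodelgolden" file used by the nntrainer model unittests.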
if __name__ == "__main__":
    record_v2(
        ReduceMeanLast(),
        iteration=2,
        input_dims=[(3, 2,)],
        label_dims=[(3, 1,)],
        name="reduce_mean_last",
    )

    record_v2(
        MolAttention(query_size=6),
        iteration=2,
        input_dims=[(3,6), (3,4,6), (3,1,5), (3,)],
        input_dtype=[float, float, float, int],
        label_dims=[(3,1,6), (3,1,5)],
        name="mol_attention_masked",
    )

    record_v2(
        MolAttention(query_size=6),
        iteration=2,
        input_dims=[(3,6), (3,4,6), (3,1,5)],
        input_dtype=[float, float, float],
        label_dims=[(3,1,6), (3,1,5)],
        name="mol_attention",
    )

    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2, bias=False, need_weights=False),
        iteration=2,
        input_dims=[(3,3,6), (3,2,6), (3,2,6)],
        label_dims=[(3,3,6)],
        input_dtype=[float, float, float],
        name="multi_head_attention_disable_need_weights",
    )

    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2),
        iteration=2,
        input_dims=[(3,3,6), (3,2,6), (3,2,6)],
        label_dims=[(3,3,6), (3,3,2)],
        input_dtype=[float, float, float],
        name="multi_head_attention",
    )

    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2, kdim=4, vdim=5),
        iteration=2,
        input_dims=[(3,3,6), (3,2,4), (3,2,5)],
        label_dims=[(3,3,6), (3,3,2)],
        input_dtype=[float, float, float],
        name="multi_head_attention_kdim_vdim",
    )

    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,3,6), (3,2,6), (3,2,6), (6,3,2)],
        label_dims=[(3,3,6), (3,3,2)],
        input_dtype=[float, float, float, float],
        input_label_reader=MultiHeadAttention.input_label_reader,
        name="multi_head_attention_float_attn_mask",
    )

    # @todo: change this pseudo bool type tensor to actual bool tensor
    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,3,6), (3,2,6), (3,2,6), (6,3,2)],
        label_dims=[(3,3,6), (3,3,2)],
        input_dtype=[float, float, float, bool],
        input_label_reader=MultiHeadAttention.input_label_reader,
        name="multi_head_attention_pseudo_bool_attn_mask",
    )

    record_v2(
        MultiHeadAttention(embed_dim=6, num_heads=2),
        iteration=2,
        input_dims=[(3,3,6)],
        label_dims=[(3,3,6), (3,3,3)],
        input_dtype=[float],
        name="multi_head_attention_self_attention",
    )

    record_v2(
        PositionalEncoding(d_model=6, max_len=7),
        iteration=1,
        input_dims=[(3,5,6)],
        input_dtype=[float],
        label_dims=[(3,5,6)],
        name="positional_encoding",
    )

    record_v2(
        TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7),
        iteration=2,
        input_dims=[(3,5,6)],
        label_dims=[(3,5,6)],
        input_dtype=[float],
        name="transformer_encoder_layer",
    )

    record_v2(
        TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (6,5,5)],
        label_dims=[(3,5,6)],
        input_dtype=[float, float],
        input_label_reader=TransformerEncoderLayer.input_label_reader,
        name="transformer_encoder_layer_float_attn_mask",
    )

    record_v2(
        TransformerEncoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (6,5,5)],
        label_dims=[(3,5,6)],
        input_dtype=[float, bool],
        input_label_reader=TransformerEncoderLayer.input_label_reader,
        name="transformer_encoder_layer_pseudo_bool_attn_mask",
    )

    record_v2(
        TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6)],
        label_dims=[(3,5,6)],
        input_dtype=[float, float],
        name="transformer_decoder_layer",
    )

    record_v2(
        TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,5,4)],
        label_dims=[(3,5,6)],
        input_dtype=[float, float, float, float],
        input_label_reader=TransformerDecoderLayer.input_label_reader,
        name="transformer_decoder_layer_float_attn_mask",
    )

    record_v2(
        TransformerDecoderLayer(d_model=6, nhead=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,5,4)],
        label_dims=[(3,5,6)],
        input_dtype=[float, float, bool, bool],
        input_label_reader=TransformerDecoderLayer.input_label_reader,
        name="transformer_decoder_layer_pseudo_bool_attn_mask",
    )

    record_v2(
        Transformer(d_model=6, nhead=2, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=7),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6)],
        label_dims=[(3,4,6)],
        input_dtype=[float, float],
        name="transformer_single",
    )

    record_v2(
        Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6)],
        label_dims=[(3,4,6)],
        input_dtype=[float, float],
        name="transformer_stack",
    )

    record_v2(
        Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,4,4), (6,4,5)],
        label_dims=[(3,4,6)],
        input_dtype=[float, float, float, float, float],
        input_label_reader=Transformer.input_label_reader,
        name="transformer_float_attn_mask",
    )

    record_v2(
        Transformer(d_model=6, nhead=2, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=7, provide_attention_mask=True),
        iteration=2,
        input_dims=[(3,5,6), (3,4,6), (6,5,5), (6,4,4), (6,4,5)],
        label_dims=[(3,4,6)],
        input_dtype=[float, float, bool, bool, bool],
        input_label_reader=Transformer.input_label_reader,
        name="transformer_pseudo_bool_attn_mask",
    )

    fc_relu_decay = FCRelu(decay=True)
    record_v2(
        fc_relu_decay,
        iteration=2,
        input_dims=[(3,3)],
        input_dtype=[float],
        label_dims=[(3,2)],
        name="fc_relu_decay",
        optimizer=fc_relu_decay.getOptimizer()
    )

    non_trainable_fc_idx1 = NonTrainableFC(idx=1)
    record_v2(
        non_trainable_fc_idx1,
        iteration=2,
        input_dims=[(3,3)],
        input_dtype=[float],
        label_dims=[(3,2)],
        name="non_trainable_fc_idx1"
    )

    non_trainable_fc_idx2 = NonTrainableFC(idx=2)
    record_v2(
        non_trainable_fc_idx2,
        iteration=2,
        input_dims=[(3,3)],
        input_dtype=[float],
        label_dims=[(3,2)],
        name="non_trainable_fc_idx2"
    )

    # inspect the created golden test file as a sanity check
    inspect_file("fc_relu_decay.nnmodelgolden")