Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / loco / src / IR / PermutingCodec.test.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "loco/IR/PermutingCodec.h"
18
19 #include <gtest/gtest.h>
20
21 using namespace loco;
22
23 TEST(PemutationTest, feature)
24 {
25   Permutation<Domain::Feature> perm;
26
27   // All values are invalid at the beginning
28   ASSERT_FALSE(perm.mapped(FeatureAxis::Count));
29   ASSERT_FALSE(perm.mapped(FeatureAxis::Depth));
30   ASSERT_FALSE(perm.mapped(FeatureAxis::Height));
31   ASSERT_FALSE(perm.mapped(FeatureAxis::Width));
32
33   // Update mapping
34   perm[FeatureAxis::Count] = 5;
35   perm[FeatureAxis::Depth] = 6;
36   perm[FeatureAxis::Height] = 7;
37   perm[FeatureAxis::Width] = 8;
38
39   // Now perm has a mapping for all the axes
40   ASSERT_TRUE(perm.mapped(FeatureAxis::Count));
41   ASSERT_TRUE(perm.mapped(FeatureAxis::Depth));
42   ASSERT_TRUE(perm.mapped(FeatureAxis::Height));
43   ASSERT_TRUE(perm.mapped(FeatureAxis::Width));
44
45   // Check the value
46   ASSERT_EQ(5, perm[FeatureAxis::Count]);
47   ASSERT_EQ(6, perm[FeatureAxis::Depth]);
48   ASSERT_EQ(7, perm[FeatureAxis::Height]);
49   ASSERT_EQ(8, perm[FeatureAxis::Width]);
50 }
51
52 TEST(PemutationTest, filter)
53 {
54   Permutation<Domain::Filter> perm;
55
56   // All values are invalid at the beginning
57   ASSERT_FALSE(perm.mapped(FilterAxis::Count));
58   ASSERT_FALSE(perm.mapped(FilterAxis::Depth));
59   ASSERT_FALSE(perm.mapped(FilterAxis::Height));
60   ASSERT_FALSE(perm.mapped(FilterAxis::Width));
61
62   // Update mapping
63   perm[FilterAxis::Count] = 5;
64   perm[FilterAxis::Depth] = 6;
65   perm[FilterAxis::Height] = 7;
66   perm[FilterAxis::Width] = 8;
67
68   // Now perm has a mapping for all the axes
69   ASSERT_TRUE(perm.mapped(FilterAxis::Count));
70   ASSERT_TRUE(perm.mapped(FilterAxis::Depth));
71   ASSERT_TRUE(perm.mapped(FilterAxis::Height));
72   ASSERT_TRUE(perm.mapped(FilterAxis::Width));
73
74   // Check the value
75   ASSERT_EQ(5, perm[FilterAxis::Count]);
76   ASSERT_EQ(6, perm[FilterAxis::Depth]);
77   ASSERT_EQ(7, perm[FilterAxis::Height]);
78   ASSERT_EQ(8, perm[FilterAxis::Width]);
79 }
80
81 TEST(PemutationTest, depthwise_filter)
82 {
83   Permutation<Domain::DepthwiseFilter> perm;
84
85   // All values are invalid at the beginning
86   ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Depth));
87   ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Multiplier));
88   ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Height));
89   ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Width));
90
91   // Update mapping
92   perm[DepthwiseFilterAxis::Depth] = 5;
93   perm[DepthwiseFilterAxis::Multiplier] = 6;
94   perm[DepthwiseFilterAxis::Height] = 7;
95   perm[DepthwiseFilterAxis::Width] = 8;
96
97   // Now perm has a mapping for all the axes
98   ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Depth));
99   ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Multiplier));
100   ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Height));
101   ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Width));
102
103   // Check the value
104   ASSERT_EQ(5, perm[DepthwiseFilterAxis::Depth]);
105   ASSERT_EQ(6, perm[DepthwiseFilterAxis::Multiplier]);
106   ASSERT_EQ(7, perm[DepthwiseFilterAxis::Height]);
107   ASSERT_EQ(8, perm[DepthwiseFilterAxis::Width]);
108 }
109
110 TEST(PermutingEncoderTest, feature)
111 {
112   PermutingEncoder<Domain::Feature> enc;
113
114   // Encoder is invalid at the beginning
115   ASSERT_FALSE(enc.valid());
116
117   // Set "invalid" mapping
118   enc.perm()->axis(FeatureAxis::Count) = 0;
119   enc.perm()->axis(FeatureAxis::Depth) = 6;
120   enc.perm()->axis(FeatureAxis::Height) = 1;
121   enc.perm()->axis(FeatureAxis::Width) = 2;
122
123   // Encoder is still invalid
124   ASSERT_FALSE(enc.valid());
125
126   // Set another "invalid" mapping
127   enc.perm()->axis(FeatureAxis::Depth) = 1;
128
129   // Encoder is still invalid
130   ASSERT_FALSE(enc.valid());
131
132   // Set "valid" mapping
133   enc.perm()->axis(FeatureAxis::Depth) = 3;
134
135   // Encoder is now valid
136   ASSERT_TRUE(enc.valid());
137
138   // Let's test with a HD (1280x720) RGB image
139   TensorShape tensor_shape;
140
141   tensor_shape.rank(4);
142   tensor_shape.dim(0) = 1;    // COUNT
143   tensor_shape.dim(1) = 720;  // HEIGHT
144   tensor_shape.dim(2) = 1280; // WIDTH
145   tensor_shape.dim(3) = 3;    // DEPTH
146
147   // Get the feature shape corresponding to a given image
148   auto feature_shape = enc.shape(tensor_shape);
149
150   ASSERT_EQ(1, feature_shape.count());
151   ASSERT_EQ(3, feature_shape.depth());
152   ASSERT_EQ(720, feature_shape.height());
153   ASSERT_EQ(1280, feature_shape.width());
154
155   // Let's find a source tensor index!
156   FeatureIndex feature_index;
157
158   feature_index.batch() = 0;
159   feature_index.channel() = 1;
160   feature_index.row() = 2;
161   feature_index.column() = 3;
162
163   auto tensor_index = enc.value(feature_index);
164
165   ASSERT_EQ(0, tensor_index.at(0)); // BATCH(COUNT)
166   ASSERT_EQ(2, tensor_index.at(1)); // ROW(HEIGHT)
167   ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
168   ASSERT_EQ(1, tensor_index.at(3)); // CHANNEL(DEPTH)
169 }
170
171 TEST(PermutingEncoderTest, feature_clone)
172 {
173   PermutingEncoder<Domain::Feature> src_enc;
174
175   auto src_perm = src_enc.perm();
176
177   src_perm->axis(FeatureAxis::Count) = 0;
178   src_perm->axis(FeatureAxis::Depth) = 3;
179   src_perm->axis(FeatureAxis::Height) = 1;
180   src_perm->axis(FeatureAxis::Width) = 2;
181
182   auto dst_enc = src_enc.clone();
183   auto dst_perm = loco::must_cast<PermutingEncoder<Domain::Feature> *>(dst_enc.get())->perm();
184
185   EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
186   EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
187   EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
188   EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
189
190   // Update on cloned encoder SHOULD NOT affect the original encoder
191   dst_perm->axis(FeatureAxis::Height) += 1;
192
193   EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
194   EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
195 }
196
197 TEST(PermutingEncoderTest, filter)
198 {
199   PermutingEncoder<Domain::Filter> enc;
200
201   // Encoder is invalid at the beginning
202   ASSERT_FALSE(enc.valid());
203
204   // Set "invalid" mapping
205   enc.perm()->axis(FilterAxis::Count) = 0;
206   enc.perm()->axis(FilterAxis::Depth) = 6;
207   enc.perm()->axis(FilterAxis::Height) = 1;
208   enc.perm()->axis(FilterAxis::Width) = 2;
209
210   // Encoder is still invalid
211   ASSERT_FALSE(enc.valid());
212
213   // Set another "invalid" mapping
214   enc.perm()->axis(FilterAxis::Depth) = 1;
215
216   // Encoder is still invalid
217   ASSERT_FALSE(enc.valid());
218
219   // Set "valid" mapping
220   enc.perm()->axis(FilterAxis::Depth) = 3;
221
222   // Encoder is now valid
223   ASSERT_TRUE(enc.valid());
224
225   TensorShape tensor_shape;
226
227   tensor_shape.rank(4);
228   tensor_shape.dim(0) = 8; // COUNT
229   tensor_shape.dim(1) = 1; // HEIGHT
230   tensor_shape.dim(2) = 7; // WIDTH
231   tensor_shape.dim(3) = 4; // DEPTH
232
233   // Get the corresponding filter shape
234   auto filter_shape = enc.shape(tensor_shape);
235
236   ASSERT_EQ(8, filter_shape.count());
237   ASSERT_EQ(4, filter_shape.depth());
238   ASSERT_EQ(1, filter_shape.height());
239   ASSERT_EQ(7, filter_shape.width());
240
241   // Let's find a source tensor index!
242   FilterIndex filter_index;
243
244   filter_index.nth() = 1;
245   filter_index.channel() = 2;
246   filter_index.row() = 0;
247   filter_index.column() = 3;
248
249   auto tensor_index = enc.value(filter_index);
250
251   ASSERT_EQ(1, tensor_index.at(0)); // NTH(COUNT)
252   ASSERT_EQ(0, tensor_index.at(1)); // ROW(HEIGHT)
253   ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
254   ASSERT_EQ(2, tensor_index.at(3)); // CHANNEL(DEPTH)
255 }
256
257 TEST(PermutingEncoderTest, depthwise_filter)
258 {
259   PermutingEncoder<Domain::DepthwiseFilter> enc;
260
261   // Encoder is invalid at the beginning
262   ASSERT_FALSE(enc.valid());
263
264   // Set "invalid" mapping
265   enc.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
266   enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
267   enc.perm()->axis(DepthwiseFilterAxis::Height) = 1;
268   enc.perm()->axis(DepthwiseFilterAxis::Width) = 2;
269
270   // Encoder is still invalid
271   ASSERT_FALSE(enc.valid());
272
273   // Set another "invalid" mapping
274   enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
275
276   // Encoder is still invalid
277   ASSERT_FALSE(enc.valid());
278
279   // Set "valid" mapping
280   enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
281
282   // Encoder is now valid
283   ASSERT_TRUE(enc.valid());
284
285   TensorShape tensor_shape;
286
287   tensor_shape.rank(4);
288   tensor_shape.dim(0) = 8; // DEPTH
289   tensor_shape.dim(1) = 1; // HEIGHT
290   tensor_shape.dim(2) = 7; // WIDTH
291   tensor_shape.dim(3) = 4; // MULTIPLIER
292
293   // Get the corresponding depthwise filter shape
294   auto filter_shape = enc.shape(tensor_shape);
295
296   ASSERT_EQ(8, filter_shape.depth());
297   ASSERT_EQ(4, filter_shape.multiplier());
298   ASSERT_EQ(1, filter_shape.height());
299   ASSERT_EQ(7, filter_shape.width());
300
301   // Let's find a source tensor index!
302   DepthwiseFilterIndex filter_index;
303
304   filter_index.channel() = 1;
305   filter_index.nth() = 2;
306   filter_index.row() = 0;
307   filter_index.column() = 3;
308
309   auto tensor_index = enc.value(filter_index);
310
311   ASSERT_EQ(1, tensor_index.at(0)); // CHANNEL(DEPTH)
312   ASSERT_EQ(0, tensor_index.at(1)); // ROW(HEIGHT)
313   ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
314   ASSERT_EQ(2, tensor_index.at(3)); // NTH(MULTIPLIER)
315 }
316
317 TEST(PermutingEncoderTest, depthwisefilter_init)
318 {
319   Permutation<Domain::DepthwiseFilter> src_perm;
320
321   src_perm.axis(DepthwiseFilterAxis::Multiplier) = 0;
322   src_perm.axis(DepthwiseFilterAxis::Depth) = 3;
323   src_perm.axis(DepthwiseFilterAxis::Height) = 1;
324   src_perm.axis(DepthwiseFilterAxis::Width) = 2;
325
326   PermutingEncoder<Domain::DepthwiseFilter> dst_enc{src_perm};
327   auto dst_perm = dst_enc.perm();
328
329   EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Multiplier),
330             src_perm.axis(DepthwiseFilterAxis::Multiplier));
331   EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Depth), src_perm.axis(DepthwiseFilterAxis::Depth));
332   EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height),
333             src_perm.axis(DepthwiseFilterAxis::Height));
334   EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Width), src_perm.axis(DepthwiseFilterAxis::Width));
335
336   // Update on dst perm SHOULD NOT affect the src perm
337   dst_perm->axis(DepthwiseFilterAxis::Height) += 1;
338
339   EXPECT_EQ(src_perm.axis(DepthwiseFilterAxis::Height), 1);
340   EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height), 2);
341 }
342
343 TEST(PermutingDecoderTest, feature)
344 {
345   PermutingDecoder<Domain::Feature> dec;
346
347   // Decoder is invalid at the beginning
348   ASSERT_FALSE(dec.valid());
349
350   // Set "invalid" mapping
351   dec.perm()->axis(FeatureAxis::Count) = 0;
352   dec.perm()->axis(FeatureAxis::Depth) = 6;
353   dec.perm()->axis(FeatureAxis::Height) = 1;
354   dec.perm()->axis(FeatureAxis::Width) = 2;
355
356   // Decoder is still invalid
357   ASSERT_FALSE(dec.valid());
358
359   // Set another "invalid" mapping
360   dec.perm()->axis(FeatureAxis::Depth) = 1;
361
362   // Decoder is still invalid
363   ASSERT_FALSE(dec.valid());
364
365   // Set "valid" mapping
366   dec.perm()->axis(FeatureAxis::Depth) = 3;
367
368   // Decoder is now valid
369   ASSERT_TRUE(dec.valid());
370
371   // Let's test with a HD (1280x720) RGB image
372   FeatureShape feature_shape;
373
374   feature_shape.count() = 1;
375   feature_shape.depth() = 3;
376   feature_shape.height() = 720;
377   feature_shape.width() = 1280;
378
379   // Get the tensor shape corresponding to a given image
380   auto tensor_shape = dec.shape(feature_shape);
381
382   ASSERT_EQ(4, tensor_shape.rank());
383   ASSERT_EQ(1, tensor_shape.dim(0));    // COUNT
384   ASSERT_EQ(720, tensor_shape.dim(1));  // HEIGHT
385   ASSERT_EQ(1280, tensor_shape.dim(2)); // WIDTH
386   ASSERT_EQ(3, tensor_shape.dim(3));    // DEPTH
387
388   // Let's find a source feature index!
389   TensorIndex tensor_index;
390
391   tensor_index.resize(4);
392
393   tensor_index.at(0) = 0; // BATCH(COUNT)
394   tensor_index.at(3) = 1; // CHANNEL(DEPTH)
395   tensor_index.at(1) = 2; // ROW(HEIGHT)
396   tensor_index.at(2) = 3; // COLUMN(WIDTH)
397
398   auto feature_index = dec.value(tensor_index);
399
400   ASSERT_EQ(0, feature_index.batch());
401   ASSERT_EQ(1, feature_index.channel());
402   ASSERT_EQ(2, feature_index.row());
403   ASSERT_EQ(3, feature_index.column());
404 }
405
406 TEST(PermutingDecoderTest, feature_clone)
407 {
408   PermutingDecoder<Domain::Feature> src_enc;
409
410   auto src_perm = src_enc.perm();
411
412   src_perm->axis(FeatureAxis::Count) = 0;
413   src_perm->axis(FeatureAxis::Depth) = 3;
414   src_perm->axis(FeatureAxis::Height) = 1;
415   src_perm->axis(FeatureAxis::Width) = 2;
416
417   auto dst_enc = src_enc.clone();
418   auto dst_perm = loco::must_cast<PermutingDecoder<Domain::Feature> *>(dst_enc.get())->perm();
419
420   EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
421   EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
422   EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
423   EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
424
425   // Update on cloned decoder SHOULD NOT affect the original decoder
426   dst_perm->axis(FeatureAxis::Height) += 1;
427
428   EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
429   EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
430 }
431
432 TEST(PermutingDecoderTest, filter)
433 {
434   PermutingDecoder<Domain::Filter> dec;
435
436   // Decoder is invalid at the beginning
437   ASSERT_FALSE(dec.valid());
438
439   // Set "invalid" mapping
440   dec.perm()->axis(FilterAxis::Count) = 0;
441   dec.perm()->axis(FilterAxis::Depth) = 6;
442   dec.perm()->axis(FilterAxis::Height) = 1;
443   dec.perm()->axis(FilterAxis::Width) = 2;
444
445   // Decoder is still invalid
446   ASSERT_FALSE(dec.valid());
447
448   // Set another "invalid" mapping
449   dec.perm()->axis(FilterAxis::Depth) = 1;
450
451   // Decoder is still invalid
452   ASSERT_FALSE(dec.valid());
453
454   // Set "valid" mapping
455   dec.perm()->axis(FilterAxis::Depth) = 3;
456
457   // Decoder is now valid
458   ASSERT_TRUE(dec.valid());
459
460   // Let's test with a small filter
461   FilterShape filter_shape;
462
463   filter_shape.count() = 10;
464   filter_shape.depth() = 3;
465   filter_shape.height() = 6;
466   filter_shape.width() = 8;
467
468   // Get the tensor shape corresponding to a given image
469   auto tensor_shape = dec.shape(filter_shape);
470
471   ASSERT_EQ(4, tensor_shape.rank());
472   ASSERT_EQ(10, tensor_shape.dim(0)); // COUNT
473   ASSERT_EQ(6, tensor_shape.dim(1));  // HEIGHT
474   ASSERT_EQ(8, tensor_shape.dim(2));  // WIDTH
475   ASSERT_EQ(3, tensor_shape.dim(3));  // DEPTH
476
477   // Let's find a source filter index!
478   TensorIndex tensor_index;
479
480   tensor_index.resize(4);
481
482   tensor_index.at(0) = 0; // BATCH(COUNT)
483   tensor_index.at(3) = 1; // CHANNEL(DEPTH)
484   tensor_index.at(1) = 2; // ROW(HEIGHT)
485   tensor_index.at(2) = 3; // COLUMN(WIDTH)
486
487   auto filter_index = dec.value(tensor_index);
488
489   ASSERT_EQ(0, filter_index.nth());
490   ASSERT_EQ(1, filter_index.channel());
491   ASSERT_EQ(2, filter_index.row());
492   ASSERT_EQ(3, filter_index.column());
493 }
494
495 TEST(PermutingDecoderTest, depthwise_filter)
496 {
497   PermutingDecoder<Domain::DepthwiseFilter> dec;
498
499   // Decoder is invalid at the beginning
500   ASSERT_FALSE(dec.valid());
501
502   // Set "invalid" mapping
503   dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
504   dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
505   dec.perm()->axis(DepthwiseFilterAxis::Height) = 1;
506   dec.perm()->axis(DepthwiseFilterAxis::Width) = 2;
507
508   // Decoder is still invalid
509   ASSERT_FALSE(dec.valid());
510
511   // Set another "invalid" mapping
512   dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
513
514   // Decoder is still invalid
515   ASSERT_FALSE(dec.valid());
516
517   // Set "valid" mapping
518   dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
519
520   // Decoder is now valid
521   ASSERT_TRUE(dec.valid());
522
523   DepthwiseFilterShape dw_filter_shape;
524
525   dw_filter_shape.depth() = 8;
526   dw_filter_shape.multiplier() = 1;
527   dw_filter_shape.height() = 7;
528   dw_filter_shape.width() = 4;
529
530   // Get the corresponding depthwise filter shape
531   auto tensor_shape = dec.shape(dw_filter_shape);
532
533   ASSERT_EQ(8, tensor_shape.dim(0).value());
534   ASSERT_EQ(7, tensor_shape.dim(1).value());
535   ASSERT_EQ(4, tensor_shape.dim(2).value());
536   ASSERT_EQ(1, tensor_shape.dim(3).value());
537
538   // Let's find a source tensor index!
539   TensorIndex tensor_index;
540   tensor_index.resize(4);
541
542   tensor_index.at(0) = 4;
543   tensor_index.at(1) = 2;
544   tensor_index.at(2) = 1;
545   tensor_index.at(3) = 0;
546
547   auto dw_filter_index = dec.value(tensor_index);
548
549   ASSERT_EQ(4, dw_filter_index.channel());
550   ASSERT_EQ(0, dw_filter_index.nth());
551   ASSERT_EQ(2, dw_filter_index.row());
552   ASSERT_EQ(1, dw_filter_index.column());
553 }