2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "loco/IR/PermutingCodec.h"
19 #include <gtest/gtest.h>
23 TEST(PemutationTest, feature)
25 Permutation<Domain::Feature> perm;
27 // All values are invalid at the beginning
28 ASSERT_FALSE(perm.mapped(FeatureAxis::Count));
29 ASSERT_FALSE(perm.mapped(FeatureAxis::Depth));
30 ASSERT_FALSE(perm.mapped(FeatureAxis::Height));
31 ASSERT_FALSE(perm.mapped(FeatureAxis::Width));
34 perm[FeatureAxis::Count] = 5;
35 perm[FeatureAxis::Depth] = 6;
36 perm[FeatureAxis::Height] = 7;
37 perm[FeatureAxis::Width] = 8;
39 // Now perm has a mapping for all the axes
40 ASSERT_TRUE(perm.mapped(FeatureAxis::Count));
41 ASSERT_TRUE(perm.mapped(FeatureAxis::Depth));
42 ASSERT_TRUE(perm.mapped(FeatureAxis::Height));
43 ASSERT_TRUE(perm.mapped(FeatureAxis::Width));
46 ASSERT_EQ(5, perm[FeatureAxis::Count]);
47 ASSERT_EQ(6, perm[FeatureAxis::Depth]);
48 ASSERT_EQ(7, perm[FeatureAxis::Height]);
49 ASSERT_EQ(8, perm[FeatureAxis::Width]);
52 TEST(PemutationTest, filter)
54 Permutation<Domain::Filter> perm;
56 // All values are invalid at the beginning
57 ASSERT_FALSE(perm.mapped(FilterAxis::Count));
58 ASSERT_FALSE(perm.mapped(FilterAxis::Depth));
59 ASSERT_FALSE(perm.mapped(FilterAxis::Height));
60 ASSERT_FALSE(perm.mapped(FilterAxis::Width));
63 perm[FilterAxis::Count] = 5;
64 perm[FilterAxis::Depth] = 6;
65 perm[FilterAxis::Height] = 7;
66 perm[FilterAxis::Width] = 8;
68 // Now perm has a mapping for all the axes
69 ASSERT_TRUE(perm.mapped(FilterAxis::Count));
70 ASSERT_TRUE(perm.mapped(FilterAxis::Depth));
71 ASSERT_TRUE(perm.mapped(FilterAxis::Height));
72 ASSERT_TRUE(perm.mapped(FilterAxis::Width));
75 ASSERT_EQ(5, perm[FilterAxis::Count]);
76 ASSERT_EQ(6, perm[FilterAxis::Depth]);
77 ASSERT_EQ(7, perm[FilterAxis::Height]);
78 ASSERT_EQ(8, perm[FilterAxis::Width]);
81 TEST(PemutationTest, depthwise_filter)
83 Permutation<Domain::DepthwiseFilter> perm;
85 // All values are invalid at the beginning
86 ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Depth));
87 ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Multiplier));
88 ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Height));
89 ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Width));
92 perm[DepthwiseFilterAxis::Depth] = 5;
93 perm[DepthwiseFilterAxis::Multiplier] = 6;
94 perm[DepthwiseFilterAxis::Height] = 7;
95 perm[DepthwiseFilterAxis::Width] = 8;
97 // Now perm has a mapping for all the axes
98 ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Depth));
99 ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Multiplier));
100 ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Height));
101 ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Width));
104 ASSERT_EQ(5, perm[DepthwiseFilterAxis::Depth]);
105 ASSERT_EQ(6, perm[DepthwiseFilterAxis::Multiplier]);
106 ASSERT_EQ(7, perm[DepthwiseFilterAxis::Height]);
107 ASSERT_EQ(8, perm[DepthwiseFilterAxis::Width]);
110 TEST(PermutingEncoderTest, feature)
112 PermutingEncoder<Domain::Feature> enc;
114 // Encoder is invalid at the beginning
115 ASSERT_FALSE(enc.valid());
117 // Set "invalid" mapping
118 enc.perm()->axis(FeatureAxis::Count) = 0;
119 enc.perm()->axis(FeatureAxis::Depth) = 6;
120 enc.perm()->axis(FeatureAxis::Height) = 1;
121 enc.perm()->axis(FeatureAxis::Width) = 2;
123 // Encoder is still invalid
124 ASSERT_FALSE(enc.valid());
126 // Set another "invalid" mapping
127 enc.perm()->axis(FeatureAxis::Depth) = 1;
129 // Encoder is still invalid
130 ASSERT_FALSE(enc.valid());
132 // Set "valid" mapping
133 enc.perm()->axis(FeatureAxis::Depth) = 3;
135 // Encoder is now valid
136 ASSERT_TRUE(enc.valid());
138 // Let's test with a HD (1280x720) RGB image
139 TensorShape tensor_shape;
141 tensor_shape.rank(4);
142 tensor_shape.dim(0) = 1; // COUNT
143 tensor_shape.dim(1) = 720; // HEIGHT
144 tensor_shape.dim(2) = 1280; // WIDTH
145 tensor_shape.dim(3) = 3; // DEPTH
147 // Get the feature shape corresponding to a given image
148 auto feature_shape = enc.shape(tensor_shape);
150 ASSERT_EQ(1, feature_shape.count());
151 ASSERT_EQ(3, feature_shape.depth());
152 ASSERT_EQ(720, feature_shape.height());
153 ASSERT_EQ(1280, feature_shape.width());
155 // Let's find a source tensor index!
156 FeatureIndex feature_index;
158 feature_index.batch() = 0;
159 feature_index.channel() = 1;
160 feature_index.row() = 2;
161 feature_index.column() = 3;
163 auto tensor_index = enc.value(feature_index);
165 ASSERT_EQ(0, tensor_index.at(0)); // BATCH(COUNT)
166 ASSERT_EQ(2, tensor_index.at(1)); // ROW(HEIGHT)
167 ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
168 ASSERT_EQ(1, tensor_index.at(3)); // CHANNEL(DEPTH)
171 TEST(PermutingEncoderTest, feature_clone)
173 PermutingEncoder<Domain::Feature> src_enc;
175 auto src_perm = src_enc.perm();
177 src_perm->axis(FeatureAxis::Count) = 0;
178 src_perm->axis(FeatureAxis::Depth) = 3;
179 src_perm->axis(FeatureAxis::Height) = 1;
180 src_perm->axis(FeatureAxis::Width) = 2;
182 auto dst_enc = src_enc.clone();
183 auto dst_perm = loco::must_cast<PermutingEncoder<Domain::Feature> *>(dst_enc.get())->perm();
185 EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
186 EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
187 EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
188 EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
190 // Update on cloned encoder SHOULD NOT affect the original encoder
191 dst_perm->axis(FeatureAxis::Height) += 1;
193 EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
194 EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
197 TEST(PermutingEncoderTest, filter)
199 PermutingEncoder<Domain::Filter> enc;
201 // Encoder is invalid at the beginning
202 ASSERT_FALSE(enc.valid());
204 // Set "invalid" mapping
205 enc.perm()->axis(FilterAxis::Count) = 0;
206 enc.perm()->axis(FilterAxis::Depth) = 6;
207 enc.perm()->axis(FilterAxis::Height) = 1;
208 enc.perm()->axis(FilterAxis::Width) = 2;
210 // Encoder is still invalid
211 ASSERT_FALSE(enc.valid());
213 // Set another "invalid" mapping
214 enc.perm()->axis(FilterAxis::Depth) = 1;
216 // Encoder is still invalid
217 ASSERT_FALSE(enc.valid());
219 // Set "valid" mapping
220 enc.perm()->axis(FilterAxis::Depth) = 3;
222 // Encoder is now valid
223 ASSERT_TRUE(enc.valid());
225 TensorShape tensor_shape;
227 tensor_shape.rank(4);
228 tensor_shape.dim(0) = 8; // COUNT
229 tensor_shape.dim(1) = 1; // HEIGHT
230 tensor_shape.dim(2) = 7; // WIDTH
231 tensor_shape.dim(3) = 4; // DEPTH
233 // Get the corresponding filter shape
234 auto filter_shape = enc.shape(tensor_shape);
236 ASSERT_EQ(8, filter_shape.count());
237 ASSERT_EQ(4, filter_shape.depth());
238 ASSERT_EQ(1, filter_shape.height());
239 ASSERT_EQ(7, filter_shape.width());
241 // Let's find a source tensor index!
242 FilterIndex filter_index;
244 filter_index.nth() = 1;
245 filter_index.channel() = 2;
246 filter_index.row() = 0;
247 filter_index.column() = 3;
249 auto tensor_index = enc.value(filter_index);
251 ASSERT_EQ(1, tensor_index.at(0)); // NTH(COUNT)
252 ASSERT_EQ(0, tensor_index.at(1)); // ROW(HEIGHT)
253 ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
254 ASSERT_EQ(2, tensor_index.at(3)); // CHANNEL(DEPTH)
257 TEST(PermutingEncoderTest, depthwise_filter)
259 PermutingEncoder<Domain::DepthwiseFilter> enc;
261 // Encoder is invalid at the beginning
262 ASSERT_FALSE(enc.valid());
264 // Set "invalid" mapping
265 enc.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
266 enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
267 enc.perm()->axis(DepthwiseFilterAxis::Height) = 1;
268 enc.perm()->axis(DepthwiseFilterAxis::Width) = 2;
270 // Encoder is still invalid
271 ASSERT_FALSE(enc.valid());
273 // Set another "invalid" mapping
274 enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
276 // Encoder is still invalid
277 ASSERT_FALSE(enc.valid());
279 // Set "valid" mapping
280 enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
282 // Encoder is now valid
283 ASSERT_TRUE(enc.valid());
285 TensorShape tensor_shape;
287 tensor_shape.rank(4);
288 tensor_shape.dim(0) = 8; // DEPTH
289 tensor_shape.dim(1) = 1; // HEIGHT
290 tensor_shape.dim(2) = 7; // WIDTH
291 tensor_shape.dim(3) = 4; // MULTIPLIER
293 // Get the corresponding depthwise filter shape
294 auto filter_shape = enc.shape(tensor_shape);
296 ASSERT_EQ(8, filter_shape.depth());
297 ASSERT_EQ(4, filter_shape.multiplier());
298 ASSERT_EQ(1, filter_shape.height());
299 ASSERT_EQ(7, filter_shape.width());
301 // Let's find a source tensor index!
302 DepthwiseFilterIndex filter_index;
304 filter_index.channel() = 1;
305 filter_index.nth() = 2;
306 filter_index.row() = 0;
307 filter_index.column() = 3;
309 auto tensor_index = enc.value(filter_index);
311 ASSERT_EQ(1, tensor_index.at(0)); // CHANNEL(DEPTH)
312 ASSERT_EQ(0, tensor_index.at(1)); // ROW(HEIGHT)
313 ASSERT_EQ(3, tensor_index.at(2)); // COLUMN(WIDTH)
314 ASSERT_EQ(2, tensor_index.at(3)); // NTH(MULTIPLIER)
317 TEST(PermutingEncoderTest, depthwisefilter_init)
319 Permutation<Domain::DepthwiseFilter> src_perm;
321 src_perm.axis(DepthwiseFilterAxis::Multiplier) = 0;
322 src_perm.axis(DepthwiseFilterAxis::Depth) = 3;
323 src_perm.axis(DepthwiseFilterAxis::Height) = 1;
324 src_perm.axis(DepthwiseFilterAxis::Width) = 2;
326 PermutingEncoder<Domain::DepthwiseFilter> dst_enc{src_perm};
327 auto dst_perm = dst_enc.perm();
329 EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Multiplier),
330 src_perm.axis(DepthwiseFilterAxis::Multiplier));
331 EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Depth), src_perm.axis(DepthwiseFilterAxis::Depth));
332 EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height),
333 src_perm.axis(DepthwiseFilterAxis::Height));
334 EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Width), src_perm.axis(DepthwiseFilterAxis::Width));
336 // Update on dst perm SHOULD NOT affect the src perm
337 dst_perm->axis(DepthwiseFilterAxis::Height) += 1;
339 EXPECT_EQ(src_perm.axis(DepthwiseFilterAxis::Height), 1);
340 EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height), 2);
343 TEST(PermutingDecoderTest, feature)
345 PermutingDecoder<Domain::Feature> dec;
347 // Decoder is invalid at the beginning
348 ASSERT_FALSE(dec.valid());
350 // Set "invalid" mapping
351 dec.perm()->axis(FeatureAxis::Count) = 0;
352 dec.perm()->axis(FeatureAxis::Depth) = 6;
353 dec.perm()->axis(FeatureAxis::Height) = 1;
354 dec.perm()->axis(FeatureAxis::Width) = 2;
356 // Decoder is still invalid
357 ASSERT_FALSE(dec.valid());
359 // Set another "invalid" mapping
360 dec.perm()->axis(FeatureAxis::Depth) = 1;
362 // Decoder is still invalid
363 ASSERT_FALSE(dec.valid());
365 // Set "valid" mapping
366 dec.perm()->axis(FeatureAxis::Depth) = 3;
368 // Decoder is now valid
369 ASSERT_TRUE(dec.valid());
371 // Let's test with a HD (1280x720) RGB image
372 FeatureShape feature_shape;
374 feature_shape.count() = 1;
375 feature_shape.depth() = 3;
376 feature_shape.height() = 720;
377 feature_shape.width() = 1280;
379 // Get the tensor shape corresponding to a given image
380 auto tensor_shape = dec.shape(feature_shape);
382 ASSERT_EQ(4, tensor_shape.rank());
383 ASSERT_EQ(1, tensor_shape.dim(0)); // COUNT
384 ASSERT_EQ(720, tensor_shape.dim(1)); // HEIGHT
385 ASSERT_EQ(1280, tensor_shape.dim(2)); // WIDTH
386 ASSERT_EQ(3, tensor_shape.dim(3)); // DEPTH
388 // Let's find a source feature index!
389 TensorIndex tensor_index;
391 tensor_index.resize(4);
393 tensor_index.at(0) = 0; // BATCH(COUNT)
394 tensor_index.at(3) = 1; // CHANNEL(DEPTH)
395 tensor_index.at(1) = 2; // ROW(HEIGHT)
396 tensor_index.at(2) = 3; // COLUMN(WIDTH)
398 auto feature_index = dec.value(tensor_index);
400 ASSERT_EQ(0, feature_index.batch());
401 ASSERT_EQ(1, feature_index.channel());
402 ASSERT_EQ(2, feature_index.row());
403 ASSERT_EQ(3, feature_index.column());
406 TEST(PermutingDecoderTest, feature_clone)
408 PermutingDecoder<Domain::Feature> src_enc;
410 auto src_perm = src_enc.perm();
412 src_perm->axis(FeatureAxis::Count) = 0;
413 src_perm->axis(FeatureAxis::Depth) = 3;
414 src_perm->axis(FeatureAxis::Height) = 1;
415 src_perm->axis(FeatureAxis::Width) = 2;
417 auto dst_enc = src_enc.clone();
418 auto dst_perm = loco::must_cast<PermutingDecoder<Domain::Feature> *>(dst_enc.get())->perm();
420 EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
421 EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
422 EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
423 EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
425 // Update on cloned decoder SHOULD NOT affect the original decoder
426 dst_perm->axis(FeatureAxis::Height) += 1;
428 EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
429 EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
432 TEST(PermutingDecoderTest, filter)
434 PermutingDecoder<Domain::Filter> dec;
436 // Decoder is invalid at the beginning
437 ASSERT_FALSE(dec.valid());
439 // Set "invalid" mapping
440 dec.perm()->axis(FilterAxis::Count) = 0;
441 dec.perm()->axis(FilterAxis::Depth) = 6;
442 dec.perm()->axis(FilterAxis::Height) = 1;
443 dec.perm()->axis(FilterAxis::Width) = 2;
445 // Decoder is still invalid
446 ASSERT_FALSE(dec.valid());
448 // Set another "invalid" mapping
449 dec.perm()->axis(FilterAxis::Depth) = 1;
451 // Decoder is still invalid
452 ASSERT_FALSE(dec.valid());
454 // Set "valid" mapping
455 dec.perm()->axis(FilterAxis::Depth) = 3;
457 // Decoder is now valid
458 ASSERT_TRUE(dec.valid());
460 // Let's test with a small filter
461 FilterShape filter_shape;
463 filter_shape.count() = 10;
464 filter_shape.depth() = 3;
465 filter_shape.height() = 6;
466 filter_shape.width() = 8;
468 // Get the tensor shape corresponding to a given image
469 auto tensor_shape = dec.shape(filter_shape);
471 ASSERT_EQ(4, tensor_shape.rank());
472 ASSERT_EQ(10, tensor_shape.dim(0)); // COUNT
473 ASSERT_EQ(6, tensor_shape.dim(1)); // HEIGHT
474 ASSERT_EQ(8, tensor_shape.dim(2)); // WIDTH
475 ASSERT_EQ(3, tensor_shape.dim(3)); // DEPTH
477 // Let's find a source filter index!
478 TensorIndex tensor_index;
480 tensor_index.resize(4);
482 tensor_index.at(0) = 0; // BATCH(COUNT)
483 tensor_index.at(3) = 1; // CHANNEL(DEPTH)
484 tensor_index.at(1) = 2; // ROW(HEIGHT)
485 tensor_index.at(2) = 3; // COLUMN(WIDTH)
487 auto filter_index = dec.value(tensor_index);
489 ASSERT_EQ(0, filter_index.nth());
490 ASSERT_EQ(1, filter_index.channel());
491 ASSERT_EQ(2, filter_index.row());
492 ASSERT_EQ(3, filter_index.column());
495 TEST(PermutingDecoderTest, depthwise_filter)
497 PermutingDecoder<Domain::DepthwiseFilter> dec;
499 // Decoder is invalid at the beginning
500 ASSERT_FALSE(dec.valid());
502 // Set "invalid" mapping
503 dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
504 dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
505 dec.perm()->axis(DepthwiseFilterAxis::Height) = 1;
506 dec.perm()->axis(DepthwiseFilterAxis::Width) = 2;
508 // Decoder is still invalid
509 ASSERT_FALSE(dec.valid());
511 // Set another "invalid" mapping
512 dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
514 // Decoder is still invalid
515 ASSERT_FALSE(dec.valid());
517 // Set "valid" mapping
518 dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
520 // Decoder is now valid
521 ASSERT_TRUE(dec.valid());
523 DepthwiseFilterShape dw_filter_shape;
525 dw_filter_shape.depth() = 8;
526 dw_filter_shape.multiplier() = 1;
527 dw_filter_shape.height() = 7;
528 dw_filter_shape.width() = 4;
530 // Get the corresponding depthwise filter shape
531 auto tensor_shape = dec.shape(dw_filter_shape);
533 ASSERT_EQ(8, tensor_shape.dim(0).value());
534 ASSERT_EQ(7, tensor_shape.dim(1).value());
535 ASSERT_EQ(4, tensor_shape.dim(2).value());
536 ASSERT_EQ(1, tensor_shape.dim(3).value());
538 // Let's find a source tensor index!
539 TensorIndex tensor_index;
540 tensor_index.resize(4);
542 tensor_index.at(0) = 4;
543 tensor_index.at(1) = 2;
544 tensor_index.at(2) = 1;
545 tensor_index.at(3) = 0;
547 auto dw_filter_index = dec.value(tensor_index);
549 ASSERT_EQ(4, dw_filter_index.channel());
550 ASSERT_EQ(0, dw_filter_index.nth());
551 ASSERT_EQ(2, dw_filter_index.row());
552 ASSERT_EQ(1, dw_filter_index.column());