8cc69936120770bf48de3c7097e02b5c70281809
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Transpose.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Transpose.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26 namespace
27 {
28
29 using namespace testing;
30
31 template <typename T>
32 void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> perm_shape,
33            std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
34            std::initializer_list<int32_t> perm_data, std::initializer_list<T> output_data)
35 {
36   std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37   constexpr DataType element_type = getElementType<T>();
38   Tensor input_tensor =
39     makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
40   Tensor perm_tensor = makeInputTensor<DataType::S32>(perm_shape, perm_data, memory_manager.get());
41   Tensor output_tensor = makeOutputTensor(element_type);
42
43   Transpose kernel(&input_tensor, &perm_tensor, &output_tensor);
44   kernel.configure();
45   memory_manager->allocate_memory(output_tensor);
46   kernel.execute();
47
48   EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
49 }
50
51 template <typename T> class TransposeTest : public ::testing::Test
52 {
53 };
54
55 using DataTypes = ::testing::Types<float, uint8_t>;
56 TYPED_TEST_SUITE(TransposeTest, DataTypes);
57
58 TYPED_TEST(TransposeTest, Small3D)
59 {
60   Check<TypeParam>(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3}, /*output_shape=*/{4, 2, 3},
61                    /*input_data=*/{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
62                                    12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
63                    /*perm_data=*/{2, 0, 1},
64                    /*output_data=*/{0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
65                                     2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
66 }
67
68 TYPED_TEST(TransposeTest, Large4D)
69 {
70   Check<TypeParam>(
71     /*input_shape=*/{2, 3, 4, 5}, /*perm_shape=*/{4}, /*output_shape=*/{4, 2, 3, 5},
72     /*input_data=*/{0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
73                     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
74                     30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
75                     45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
76                     60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
77                     75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
78                     90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
79                     105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
80     /*perm_data=*/{2, 0, 1, 3},
81     /*output_data=*/{0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
82                      60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
83                      5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
84                      65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
85                      10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
86                      70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
87                      15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
88                      75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
89 }
90
91 TYPED_TEST(TransposeTest, Large2D)
92 {
93   Check<TypeParam>(
94     /*input_shape=*/{10, 12}, /*perm_shape=*/{2}, /*output_shape=*/{12, 10},
95     /*input_data=*/{0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
96                     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
97                     30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
98                     45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
99                     60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
100                     75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
101                     90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
102                     105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
103     /*perm_data=*/{1, 0},
104     /*output_data=*/{0,  12, 24, 36,  48,  60, 72, 84, 96,  108, 1,  13, 25, 37,  49,
105                      61, 73, 85, 97,  109, 2,  14, 26, 38,  50,  62, 74, 86, 98,  110,
106                      3,  15, 27, 39,  51,  63, 75, 87, 99,  111, 4,  16, 28, 40,  52,
107                      64, 76, 88, 100, 112, 5,  17, 29, 41,  53,  65, 77, 89, 101, 113,
108                      6,  18, 30, 42,  54,  66, 78, 90, 102, 114, 7,  19, 31, 43,  55,
109                      67, 79, 91, 103, 115, 8,  20, 32, 44,  56,  68, 80, 92, 104, 116,
110                      9,  21, 33, 45,  57,  69, 81, 93, 105, 117, 10, 22, 34, 46,  58,
111                      70, 82, 94, 106, 118, 11, 23, 35, 47,  59,  71, 83, 95, 107, 119});
112 }
113
114 } // namespace
115 } // namespace kernels
116 } // namespace luci_interpreter