/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #include "kernels/Transpose.h"
19 #include "kernels/TestUtils.h"
20 #include "luci_interpreter/TestMemoryManager.h"
22 namespace luci_interpreter
29 using namespace testing;
32 void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> perm_shape,
33 std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
34 std::initializer_list<int32_t> perm_data, std::initializer_list<T> output_data)
36 std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
37 constexpr DataType element_type = getElementType<T>();
39 makeInputTensor<element_type>(input_shape, input_data, memory_manager.get());
40 Tensor perm_tensor = makeInputTensor<DataType::S32>(perm_shape, perm_data, memory_manager.get());
41 Tensor output_tensor = makeOutputTensor(element_type);
43 Transpose kernel(&input_tensor, &perm_tensor, &output_tensor);
45 memory_manager->allocate_memory(output_tensor);
48 EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
51 template <typename T> class TransposeTest : public ::testing::Test
55 using DataTypes = ::testing::Types<float, uint8_t>;
56 TYPED_TEST_SUITE(TransposeTest, DataTypes);
58 TYPED_TEST(TransposeTest, Small3D)
60 Check<TypeParam>(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3}, /*output_shape=*/{4, 2, 3},
61 /*input_data=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
62 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
63 /*perm_data=*/{2, 0, 1},
64 /*output_data=*/{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
65 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
68 TYPED_TEST(TransposeTest, Large4D)
71 /*input_shape=*/{2, 3, 4, 5}, /*perm_shape=*/{4}, /*output_shape=*/{4, 2, 3, 5},
72 /*input_data=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
73 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
74 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
75 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
76 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
77 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
78 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
79 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
80 /*perm_data=*/{2, 0, 1, 3},
81 /*output_data=*/{0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
82 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
83 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
84 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
85 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54,
86 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
87 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59,
88 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
91 TYPED_TEST(TransposeTest, Large2D)
94 /*input_shape=*/{10, 12}, /*perm_shape=*/{2}, /*output_shape=*/{12, 10},
95 /*input_data=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
96 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
97 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
98 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
99 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
100 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
101 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
102 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
103 /*perm_data=*/{1, 0},
104 /*output_data=*/{0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 1, 13, 25, 37, 49,
105 61, 73, 85, 97, 109, 2, 14, 26, 38, 50, 62, 74, 86, 98, 110,
106 3, 15, 27, 39, 51, 63, 75, 87, 99, 111, 4, 16, 28, 40, 52,
107 64, 76, 88, 100, 112, 5, 17, 29, 41, 53, 65, 77, 89, 101, 113,
108 6, 18, 30, 42, 54, 66, 78, 90, 102, 114, 7, 19, 31, 43, 55,
109 67, 79, 91, 103, 115, 8, 20, 32, 44, 56, 68, 80, 92, 104, 116,
110 9, 21, 33, 45, 57, 69, 81, 93, 105, 117, 10, 22, 34, 46, 58,
111 70, 82, 94, 106, 118, 11, 23, 35, 47, 59, 71, 83, 95, 107, 119});
115 } // namespace kernels
116 } // namespace luci_interpreter