b8c0ac497510ee3a7dec68e523a68309378118b8
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / TransposeConv.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/TransposeConv.h"
18 #include "kernels/TestUtils.h"
19
20 namespace luci_interpreter
21 {
22 namespace kernels
23 {
24 namespace
25 {
26
27 using namespace testing;
28
29 template <typename T>
30 void Check(std::initializer_list<int32_t> output_shape_shape,
31            std::initializer_list<int32_t> weight_shape,
32            std::initializer_list<int32_t> input_data_shape,
33            std::initializer_list<int32_t> output_shape,
34            std::initializer_list<int32_t> output_shape_data, std::initializer_list<T> weight_data,
35            std::initializer_list<T> input_data_data, std::initializer_list<T> output_data,
36            luci::Padding padding, int32_t stride_height, int32_t stride_width,
37            DataType element_type)
38 {
39   Tensor output_shape_tensor{element_type, output_shape_shape, {}, ""};
40   output_shape_tensor.writeData(output_shape_data.begin(), output_shape_data.size() * sizeof(T));
41   Tensor weight_tensor{element_type, weight_shape, {}, ""};
42   weight_tensor.writeData(weight_data.begin(), weight_data.size() * sizeof(T));
43   Tensor input_data_tensor{element_type, input_data_shape, {}, ""};
44   input_data_tensor.writeData(input_data_data.begin(), input_data_data.size() * sizeof(T));
45
46   Tensor output_tensor = makeOutputTensor(element_type);
47
48   TransposeConvParams params{};
49   params.padding = padding;
50   params.stride_height = stride_height;
51   params.stride_width = stride_width;
52
53   TransposeConv kernel(&output_shape_tensor, &weight_tensor, &input_data_tensor, &output_tensor,
54                        params);
55   kernel.configure();
56   kernel.execute();
57
58   EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
59 }
60
61 TEST(TransposeConvTest, FloatSimple)
62 {
63   Check<float>(
64       /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 1}, /*input_shape=*/{1, 4, 4, 1},
65       /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
66       /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9},
67       /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
68       /*output_data=*/{29, 62, 83, 75, 99, 192, 237, 198, 207, 372, 417, 330, 263, 446, 485, 365},
69       /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
70       getElementType<float>());
71
72   SUCCEED();
73 }
74
75 TEST(TransposeConvTest, FloatTwoFiltersTest)
76 {
77   Check<float>(
78       /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 2}, /*input_shape=*/{1, 4, 4, 2},
79       /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
80       /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
81       /*input_data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
82                       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
83       /*output_data=*/{184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968,
84                        3352, 3652, 2760},
85       /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
86       getElementType<float>());
87
88   SUCCEED();
89 }
90
91 // TODO Uint8Simple
92 // Implement GetDequantizedOutput Function.
93 // Create Test for Uint8 Case
94
95 // TODO Uint8FiltersTest
96 // Implement GetDequantizedOutput Function.
97 // Create Test for Uint8 Case
98
99 } // namespace
100 } // namespace kernels
101 } // namespace luci_interpreter