Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / TransposeConv.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/TransposeConv.h"
18 #include "kernels/TestUtils.h"
19
20 namespace luci_interpreter
21 {
22 namespace kernels
23 {
24 namespace
25 {
26
27 using namespace testing;
28
29 template <typename T, typename B>
30 void Check(std::initializer_list<int32_t> output_shape_shape,
31            std::initializer_list<int32_t> weight_shape,
32            std::initializer_list<int32_t> input_data_shape,
33            std::initializer_list<int32_t> bias_shape, std::initializer_list<int32_t> output_shape,
34            std::initializer_list<int32_t> output_shape_data, std::initializer_list<T> weight_data,
35            std::initializer_list<T> input_data_data, std::initializer_list<B> bias_data,
36            std::initializer_list<T> output_data, luci::Padding padding, int32_t stride_height,
37            int32_t stride_width, DataType element_type)
38 {
39   Tensor output_shape_tensor{element_type, output_shape_shape, {}, ""};
40   output_shape_tensor.writeData(output_shape_data.begin(), output_shape_data.size() * sizeof(T));
41   Tensor weight_tensor{element_type, weight_shape, {}, ""};
42   weight_tensor.writeData(weight_data.begin(), weight_data.size() * sizeof(T));
43   Tensor input_data_tensor{element_type, input_data_shape, {}, ""};
44   input_data_tensor.writeData(input_data_data.begin(), input_data_data.size() * sizeof(T));
45
46   Tensor output_tensor = makeOutputTensor(element_type);
47
48   TransposeConvParams params{};
49   params.padding = padding;
50   params.stride_height = stride_height;
51   params.stride_width = stride_width;
52
53   if (bias_data.size() != 0)
54   {
55     Tensor bias_tensor = makeInputTensor<getElementType<B>()>(bias_shape, bias_data);
56     TransposeConv kernel(&output_shape_tensor, &weight_tensor, &input_data_tensor, &bias_tensor,
57                          &output_tensor, params);
58     kernel.configure();
59     kernel.execute();
60   }
61   else
62   {
63     TransposeConv kernel(&output_shape_tensor, &weight_tensor, &input_data_tensor, nullptr,
64                          &output_tensor, params);
65     kernel.configure();
66     kernel.execute();
67   }
68   EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
69 }
70
71 TEST(TransposeConvTest, FloatSimple)
72 {
73   Check<float, float>(
74       /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 1}, /*input_shape=*/{1, 4, 4, 1},
75       /*bias_shape=*/{}, /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
76       /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9},
77       /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
78       /*bias_data=*/{},
79       /*output_data=*/{29, 62, 83, 75, 99, 192, 237, 198, 207, 372, 417, 330, 263, 446, 485, 365},
80       /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
81       getElementType<float>());
82
83   SUCCEED();
84 }
85
86 TEST(TransposeConvTest, FloatTwoFiltersTest)
87 {
88   Check<float, float>(
89       /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 2}, /*input_shape=*/{1, 4, 4, 2},
90       /*bias_shape=*/{}, /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
91       /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
92       /*input_data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
93                       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
94       /*bias_data=*/{},
95       /*output_data=*/{184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968,
96                        3352, 3652, 2760},
97       /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
98       getElementType<float>());
99
100   SUCCEED();
101 }
102
103 TEST(TransposeConvTest, SimpleBiasTest)
104 {
105   Check<float, float>(
106       /*outputShape_shape=*/{4}, /*weight_shape=*/{2, 3, 3, 1},
107       /*input_shape=*/{1, 2, 2, 1},
108       /*bias_shape=*/{2}, /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 5, 5, 2},
109       /*weight_data=*/{1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18},
110       /*input_data=*/{1, 2, 3, 4},
111       /*bias_data=*/{3, 4},
112       /*output_data=*/{4,  6,  6,  8,  10, 14, 9,  12, 13, 16, 10,  12,  12, 14, 28, 32, 21,
113                        24, 25, 28, 19, 24, 27, 32, 65, 76, 45, 52,  57,  64, 24, 28, 30, 34,
114                        64, 72, 39, 44, 47, 52, 42, 46, 48, 52, 106, 114, 63, 68, 71, 76},
115       /*params.padding=*/luci::Padding::VALID, /*stride_height=*/2, /*stride_width=*/2,
116       getElementType<float>());
117
118   SUCCEED();
119 }
120
121 // TODO Uint8Simple
122 // Implement GetDequantizedOutput Function.
123 // Create Test for Uint8 Case
124
125 // TODO Uint8FiltersTest
126 // Implement GetDequantizedOutput Function.
127 // Create Test for Uint8 Case
128
129 } // namespace
130 } // namespace kernels
131 } // namespace luci_interpreter