/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/TransposeConv.h"
#include "kernels/TestUtils.h"

#include <initializer_list>
20 namespace luci_interpreter
27 using namespace testing;
30 void Check(std::initializer_list<int32_t> output_shape_shape,
31 std::initializer_list<int32_t> weight_shape,
32 std::initializer_list<int32_t> input_data_shape,
33 std::initializer_list<int32_t> output_shape,
34 std::initializer_list<int32_t> output_shape_data, std::initializer_list<T> weight_data,
35 std::initializer_list<T> input_data_data, std::initializer_list<T> output_data,
36 luci::Padding padding, int32_t stride_height, int32_t stride_width,
37 DataType element_type)
39 Tensor output_shape_tensor{element_type, output_shape_shape, {}, ""};
40 output_shape_tensor.writeData(output_shape_data.begin(), output_shape_data.size() * sizeof(T));
41 Tensor weight_tensor{element_type, weight_shape, {}, ""};
42 weight_tensor.writeData(weight_data.begin(), weight_data.size() * sizeof(T));
43 Tensor input_data_tensor{element_type, input_data_shape, {}, ""};
44 input_data_tensor.writeData(input_data_data.begin(), input_data_data.size() * sizeof(T));
46 Tensor output_tensor = makeOutputTensor(element_type);
48 TransposeConvParams params{};
49 params.padding = padding;
50 params.stride_height = stride_height;
51 params.stride_width = stride_width;
53 TransposeConv kernel(&output_shape_tensor, &weight_tensor, &input_data_tensor, &output_tensor,
58 EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
61 TEST(TransposeConvTest, FloatSimple)
64 /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 1}, /*input_shape=*/{1, 4, 4, 1},
65 /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
66 /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9},
67 /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
68 /*output_data=*/{29, 62, 83, 75, 99, 192, 237, 198, 207, 372, 417, 330, 263, 446, 485, 365},
69 /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
70 getElementType<float>());
75 TEST(TransposeConvTest, FloatTwoFiltersTest)
78 /*outputShape_shape=*/{4}, /*weight_shape=*/{1, 3, 3, 2}, /*input_shape=*/{1, 4, 4, 2},
79 /*output_shape=*/{1, 4, 4, 1}, /*outputShape_data=*/{1, 4, 4, 1},
80 /*weight_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18},
81 /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
82 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
83 /*output_data=*/{184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968,
85 /*params.padding=*/luci::Padding::SAME, /*stride_height=*/1, /*stride_width=*/1,
86 getElementType<float>());
92 // Implement GetDequantizedOutput Function.
93 // Create Test for Uint8 Case
95 // TODO Uint8FiltersTest
96 // Implement GetDequantizedOutput Function.
97 // Create Test for Uint8 Case
100 } // namespace kernels
101 } // namespace luci_interpreter