arm_compute v18.02
[platform/upstream/armcl.git] / tests / validation / NEON / DirectConvolutionLayer.cpp
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
26 #include "arm_compute/runtime/Tensor.h"
27 #include "arm_compute/runtime/TensorAllocator.h"
28 #include "tests/NEON/Accessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ShapeDatasets.h"
31 #include "tests/framework/Asserts.h"
32 #include "tests/framework/Macros.h"
33 #include "tests/framework/datasets/Datasets.h"
34 #include "tests/validation/Validation.h"
35 #include "tests/validation/fixtures/DirectConvolutionLayerFixture.h"
36
37 namespace arm_compute
38 {
39 namespace test
40 {
41 namespace validation
42 {
43 namespace
44 {
45 constexpr AbsoluteTolerance<float> tolerance_qs(1.f); /**< Tolerance for fixed point tests */
46 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
47 constexpr AbsoluteTolerance<float> tolerance_fp16(0.01f);  /**< Tolerance for half precision floating point tests */
48 #endif                                                     /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
49 constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
50
51 /** Direct convolution data set. */
52 const auto data_pad_f32 = concat(concat(combine(framework::dataset::make("PadX", 0, 1),
53                                                 combine(framework::dataset::make("PadY", 0, 1),
54                                                         framework::dataset::make("KernelSize", 1))),
55                                         combine(framework::dataset::make("PadX", 0, 2),
56                                                 combine(framework::dataset::make("PadY", 0, 2),
57                                                         framework::dataset::make("KernelSize", 3)))),
58                                  combine(framework::dataset::make("PadX", 0, 3),
59                                          combine(framework::dataset::make("PadY", 0, 3),
60                                                  framework::dataset::make("KernelSize", 5))));
61
62 const auto data_pad_qs8 = concat(combine(framework::dataset::make("PadX", 0),
63                                          combine(framework::dataset::make("PadY", 0),
64                                                  framework::dataset::make("KernelSize", 1))),
65                                  combine(framework::dataset::make("PadX", 0, 2),
66                                          combine(framework::dataset::make("PadY", 0, 2),
67                                                  framework::dataset::make("KernelSize", 3))));
68
69 const auto data_f32 = combine(datasets::SmallDirectConvolutionShapes(),
70                               combine(framework::dataset::make("StrideX", 1, 3),
71                                       combine(framework::dataset::make("StrideY", 1, 3),
72                                               combine(data_pad_f32,
73                                                       framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
74
75 const auto data_qs8 = combine(datasets::TinyDirectConvolutionShapes(),
76                               combine(framework::dataset::make("StrideX", 1, 3),
77                                       combine(framework::dataset::make("StrideY", 1, 3),
78                                               combine(data_pad_qs8,
79                                                       framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
80
81 /** Direct convolution QS16 data set. */
82 const auto data_qs16 = combine(datasets::TinyDirectConvolutionShapes(),
83                                combine(framework::dataset::make("StrideX", 1, 3),
84                                        combine(framework::dataset::make("StrideY", 1, 3),
85                                                combine(framework::dataset::make("PadX", 0),
86                                                        combine(framework::dataset::make("PadY", 0),
87                                                                combine(framework::dataset::make("KernelSize", 1),
88                                                                        framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))))));
89 } // namespace
90
91 TEST_SUITE(NEON)
92 TEST_SUITE(DirectConvolutionLayer)
93
94 // *INDENT-OFF*
95 // clang-format off
96 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
97         framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Mismatching data type input/weights
98                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Mismatching input feature maps
99                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Unsupported kernel width
100                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Non-rectangular weights dimensions
101                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid weights dimensions
102                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid stride
103                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid biases size
104                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid biases dimensions
105                                                 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid output size
106                                               }),
107         framework::dataset::make("WeightsInfo",{ TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F16, 0),
108                                                  TensorInfo(TensorShape(3U, 3U, 3U, 4U), 1, DataType::F32, 0),
109                                                  TensorInfo(TensorShape(9U, 9U, 2U, 4U), 1, DataType::F32, 0),
110                                                  TensorInfo(TensorShape(5U, 3U, 2U, 4U), 1, DataType::F32, 0),
111                                                  TensorInfo(TensorShape(3U, 3U, 2U, 4U, 3U), 1, DataType::F32, 0),
112                                                  TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
113                                                  TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
114                                                  TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
115                                                  TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
116                                               })),
117         framework::dataset::make("BiasesInfo",{ TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
118                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
119                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
120                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
121                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
122                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
123                                                 TensorInfo(TensorShape(3U), 1, DataType::F32, 0),
124                                                 TensorInfo(TensorShape(4U, 2U), 1, DataType::F32, 0),
125                                                 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
126                                               })),
127         framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
128                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
129                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
130                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
131                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
132                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
133                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
134                                                 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
135                                                 TensorInfo(TensorShape(26U, 11U, 4U), 1, DataType::F32, 0),
136                                               })),
137         framework::dataset::make("ConvInfo",  { PadStrideInfo(1, 1, 0, 0),
138                                                 PadStrideInfo(1, 1, 0, 0),
139                                                 PadStrideInfo(1, 1, 0, 0),
140                                                 PadStrideInfo(1, 1, 0, 0),
141                                                 PadStrideInfo(1, 1, 0, 0),
142                                                 PadStrideInfo(3, 3, 0, 0),
143                                                 PadStrideInfo(1, 1, 0, 0),
144                                                 PadStrideInfo(1, 1, 0, 0),
145                                                 PadStrideInfo(1, 1, 0, 0),
146                                                })),
147         framework::dataset::make("Expected", { false, false, false, false, false, false, false, false, false })),
148         input_info, weights_info, biases_info, output_info, conv_info, expected)
149 {
150         bool is_valid = bool(NEDirectConvolutionLayer::validate(&input_info.clone()->set_is_resizable(false), &weights_info.clone()->set_is_resizable(false), &biases_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), conv_info));
151         ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
152 }
153 // clang-format on
154 // *INDENT-ON*
155
156 template <typename T>
157 using NEDirectConvolutionLayerFixture = DirectConvolutionValidationFixture<Tensor, Accessor, NEDirectConvolutionLayer, T>;
158
159 TEST_SUITE(Float)
160 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
161 TEST_SUITE(FP16)
162 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(data_f32, framework::dataset::make("DataType", DataType::F16)))
163 {
164     // Validate output
165     validate(Accessor(_target), _reference, tolerance_fp16);
166 }
167 TEST_SUITE_END()
168 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
169
170 TEST_SUITE(FP32)
171 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(data_f32, framework::dataset::make("DataType", DataType::F32)))
172 {
173     // Validate output
174     validate(Accessor(_target), _reference, tolerance_fp32);
175 }
176 TEST_SUITE_END()
177 TEST_SUITE_END()
178
179 template <typename T>
180 using NEDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<Tensor, Accessor, NEDirectConvolutionLayer, T>;
181
182 TEST_SUITE(Quantized)
183 TEST_SUITE(QS8)
184 // We test for fixed point precision [4,6]
185 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data_qs8, framework::dataset::make("DataType", DataType::QS8)),
186                                                                                                                     framework::dataset::make("FractionalBits", 4, 7)))
187 {
188     // Validate output
189     validate(Accessor(_target), _reference, tolerance_qs);
190 }
191 TEST_SUITE_END()
192
193 TEST_SUITE(QS16)
194 // We test for fixed point precision [4,13]
195 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data_qs16, framework::dataset::make("DataType", DataType::QS16)),
196                                                                                                                      framework::dataset::make("FractionalBits", 4, 14)))
197 {
198     // Validate output
199     validate(Accessor(_target), _reference, tolerance_qs);
200 }
201 TEST_SUITE_END()
202 TEST_SUITE_END()
203
204 TEST_SUITE_END()
205 TEST_SUITE_END()
206 } // namespace validation
207 } // namespace test
208 } // namespace arm_compute