2 * Copyright (c) 2017-2018 ARM Limited.
4 * SPDX-License-Identifier: MIT
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
26 #include "arm_compute/runtime/Tensor.h"
27 #include "arm_compute/runtime/TensorAllocator.h"
28 #include "tests/NEON/Accessor.h"
29 #include "tests/PaddingCalculator.h"
30 #include "tests/datasets/ShapeDatasets.h"
31 #include "tests/framework/Asserts.h"
32 #include "tests/framework/Macros.h"
33 #include "tests/framework/datasets/Datasets.h"
34 #include "tests/validation/Validation.h"
35 #include "tests/validation/fixtures/DirectConvolutionLayerFixture.h"
// Absolute tolerances handed to validate() by the Run test cases in this file.
// tolerance_fp16 only exists when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined,
// matching the guard around the half precision test case.
45 constexpr AbsoluteTolerance<float> tolerance_qs(1.f); /**< Tolerance for fixed point tests (used by the QS8 and QS16 cases) */
46 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
47 constexpr AbsoluteTolerance<float> tolerance_fp16(0.01f); /**< Tolerance for half precision floating point tests */
48 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
49 constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
51 /** Direct convolution padding / kernel-size data set (named for F32).
 *
 * Three groups are concatenated, one per kernel size (1, 3 and 5). Each group
 * combines a PadX dataset and a PadY dataset with that kernel size.
 *
 * NOTE(review): assuming make(name, a, b) yields the half-open range [a, b)
 * (the "[4,6]" comment on the FractionalBits 4, 7 dataset in this file suggests so),
 * the pads would be {0} for kernel 1, {0, 1} for kernel 3 and {0, 1, 2} for
 * kernel 5 - confirm against the dataset framework.
 */
52 const auto data_pad_f32 = concat(concat(combine(framework::dataset::make("PadX", 0, 1),
53 combine(framework::dataset::make("PadY", 0, 1),
54 framework::dataset::make("KernelSize", 1))),
55 combine(framework::dataset::make("PadX", 0, 2),
56 combine(framework::dataset::make("PadY", 0, 2),
57 framework::dataset::make("KernelSize", 3)))),
58 combine(framework::dataset::make("PadX", 0, 3),
59 combine(framework::dataset::make("PadY", 0, 3),
60 framework::dataset::make("KernelSize", 5))));
/** Direct convolution padding / kernel-size data set (named for QS8).
 *
 * Two groups are concatenated: kernel size 1 with PadX = PadY = 0, and
 * kernel size 3 with PadX / PadY taken from make(..., 0, 2).
 *
 * NOTE(review): make(name, 0, 2) presumably yields {0, 1} (half-open range) -
 * confirm against the dataset framework.
 */
62 const auto data_pad_qs8 = concat(combine(framework::dataset::make("PadX", 0),
63 combine(framework::dataset::make("PadY", 0),
64 framework::dataset::make("KernelSize", 1))),
65 combine(framework::dataset::make("PadX", 0, 2),
66 combine(framework::dataset::make("PadY", 0, 2),
67 framework::dataset::make("KernelSize", 3))));
/** Configurations for the floating point Run cases (used with both DataType::F16 and DataType::F32).
 *
 * Combines SmallDirectConvolutionShapes() with StrideX, StrideY and
 * NumKernels { 1, 4, 8, 16 }.
 *
 * NOTE(review): as visible here the statement closes one more ')' than it opens
 * and the embedded numbering jumps from 71 to 73, so a line appears to be missing
 * between StrideY and NumKernels - presumably the one combining in data_pad_f32,
 * which is otherwise unreferenced in the visible lines. Confirm against upstream.
 * NOTE(review): make("StrideX", 1, 3) presumably yields strides {1, 2} - confirm.
 */
69 const auto data_f32 = combine(datasets::SmallDirectConvolutionShapes(),
70 combine(framework::dataset::make("StrideX", 1, 3),
71 combine(framework::dataset::make("StrideY", 1, 3),
73 framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
/** Configurations for the QS8 fixed point Run case.
 *
 * Combines TinyDirectConvolutionShapes() with StrideX, StrideY and
 * NumKernels { 1, 4, 8, 16 }.
 *
 * NOTE(review): as visible here the statement closes one more ')' than it opens
 * and the embedded numbering jumps from 77 to 79, so a line appears to be missing
 * between StrideY and NumKernels - presumably the one combining in data_pad_qs8,
 * which is otherwise unreferenced in the visible lines. Confirm against upstream.
 */
75 const auto data_qs8 = combine(datasets::TinyDirectConvolutionShapes(),
76 combine(framework::dataset::make("StrideX", 1, 3),
77 combine(framework::dataset::make("StrideY", 1, 3),
79 framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
81 /** Direct convolution QS16 data set.
 *
 * Combines TinyDirectConvolutionShapes() with StrideX, StrideY, a single
 * padding configuration (PadX = 0, PadY = 0), a single kernel size (1) and
 * NumKernels { 1, 4, 8, 16 }. Unlike the other data sets in this file, padding
 * and kernel size are spelled out inline rather than taken from a data_pad_* set.
 */
82 const auto data_qs16 = combine(datasets::TinyDirectConvolutionShapes(),
83 combine(framework::dataset::make("StrideX", 1, 3),
84 combine(framework::dataset::make("StrideY", 1, 3),
85 combine(framework::dataset::make("PadX", 0),
86 combine(framework::dataset::make("PadY", 0),
87 combine(framework::dataset::make("KernelSize", 1),
88 framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))))));
92 TEST_SUITE(DirectConvolutionLayer)
// Negative tests for NEDirectConvolutionLayer::validate().
//
// Six datasets (InputInfo, WeightsInfo, BiasesInfo, OutputInfo, ConvInfo, Expected)
// are zipped together; the trailing comment on each InputInfo row names the defect
// carried by the same row of one of the other datasets:
//   row 0: weights are F16 while the input is F32
//   row 1: weights have 3 input feature maps, the input has 2
//   row 2: 9x9 kernel
//   row 3: 5x3 (non-square) kernel
//   row 4: weights shape has a fifth dimension
//   row 5: PadStrideInfo(3, 3, 0, 0)
//   row 6: biases of size 3 for 4 kernels
//   row 7: two-dimensional biases (4, 2)
//   row 8: output width 26 instead of 25
// Every row has Expected == false; the visible data contains no accepting row.
//
// NOTE(review): zip is assumed to pair the i-th entries of its operands - confirm
// against the dataset framework.
// NOTE(review): the closing lines of each dataset initializer list and the braces
// around the test body are absent from this view (the embedded numbering skips them).
96 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
97 framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Mismatching data type input/weights
98 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Mismatching input feature maps
99 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Unsupported kernel width
100 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Non-rectangular weights dimensions
101 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid weights dimensions
102 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid stride
103 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid biases size
104 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid biases dimensions
105 TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32, 0), // Invalid output size
107 framework::dataset::make("WeightsInfo",{ TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F16, 0),
108 TensorInfo(TensorShape(3U, 3U, 3U, 4U), 1, DataType::F32, 0),
109 TensorInfo(TensorShape(9U, 9U, 2U, 4U), 1, DataType::F32, 0),
110 TensorInfo(TensorShape(5U, 3U, 2U, 4U), 1, DataType::F32, 0),
111 TensorInfo(TensorShape(3U, 3U, 2U, 4U, 3U), 1, DataType::F32, 0),
112 TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
113 TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
114 TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
115 TensorInfo(TensorShape(3U, 3U, 2U, 4U), 1, DataType::F32, 0),
117 framework::dataset::make("BiasesInfo",{ TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
118 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
119 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
120 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
121 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
122 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
123 TensorInfo(TensorShape(3U), 1, DataType::F32, 0),
124 TensorInfo(TensorShape(4U, 2U), 1, DataType::F32, 0),
125 TensorInfo(TensorShape(4U), 1, DataType::F32, 0),
127 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
128 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
129 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
130 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
131 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
132 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
133 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
134 TensorInfo(TensorShape(25U, 11U, 4U), 1, DataType::F32, 0),
135 TensorInfo(TensorShape(26U, 11U, 4U), 1, DataType::F32, 0),
137 framework::dataset::make("ConvInfo", { PadStrideInfo(1, 1, 0, 0),
138 PadStrideInfo(1, 1, 0, 0),
139 PadStrideInfo(1, 1, 0, 0),
140 PadStrideInfo(1, 1, 0, 0),
141 PadStrideInfo(1, 1, 0, 0),
142 PadStrideInfo(3, 3, 0, 0),
143 PadStrideInfo(1, 1, 0, 0),
144 PadStrideInfo(1, 1, 0, 0),
145 PadStrideInfo(1, 1, 0, 0),
147 framework::dataset::make("Expected", { false, false, false, false, false, false, false, false, false })),
148 input_info, weights_info, biases_info, output_info, conv_info, expected)
// Each info is cloned and marked non-resizable before being handed to validate();
// the returned status is converted to bool and compared with the expected flag.
150 bool is_valid = bool(NEDirectConvolutionLayer::validate(&input_info.clone()->set_is_resizable(false), &weights_info.clone()->set_is_resizable(false), &biases_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), conv_info));
151 ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
/** Fixture alias for the floating point Run cases.
 *
 * Binds DirectConvolutionValidationFixture to Tensor, Accessor and
 * NEDirectConvolutionLayer, leaving only the element type T open
 * (instantiated with half and float by the test cases in this file).
 */
156 template <typename T>
157 using NEDirectConvolutionLayerFixture = DirectConvolutionValidationFixture<Tensor, Accessor, NEDirectConvolutionLayer, T>;
160 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Half precision run: only compiled when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined.
// Reuses data_f32 as the configuration set, paired with DataType::F16.
162 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(data_f32, framework::dataset::make("DataType", DataType::F16)))
// Compare _target (wrapped in an Accessor) against _reference within tolerance_fp16.
165 validate(Accessor(_target), _reference, tolerance_fp16);
168 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
// Single precision run: data_f32 configurations paired with DataType::F32.
171 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(data_f32, framework::dataset::make("DataType", DataType::F32)))
// Compare _target (wrapped in an Accessor) against _reference within tolerance_fp32.
174 validate(Accessor(_target), _reference, tolerance_fp32);
/** Fixture alias for the fixed point Run cases.
 *
 * Binds DirectConvolutionValidationFixedPointFixture to Tensor, Accessor and
 * NEDirectConvolutionLayer, leaving only the element type T open
 * (instantiated with int8_t and int16_t by the test cases in this file).
 */
179 template <typename T>
180 using NEDirectConvolutionLayerFixedPointFixture = DirectConvolutionValidationFixedPointFixture<Tensor, Accessor, NEDirectConvolutionLayer, T>;
182 TEST_SUITE(Quantized)
// QS8 run: data_qs8 configurations paired with DataType::QS8 and a FractionalBits
// dataset built from make("FractionalBits", 4, 7).
184 // We test for fixed point precision [4,6]
185 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixedPointFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(data_qs8, framework::dataset::make("DataType", DataType::QS8)),
186 framework::dataset::make("FractionalBits", 4, 7)))
// Compare _target (wrapped in an Accessor) against _reference within tolerance_qs.
189 validate(Accessor(_target), _reference, tolerance_qs);
// QS16 run: data_qs16 configurations paired with DataType::QS16 and a FractionalBits
// dataset built from make("FractionalBits", 4, 14).
194 // We test for fixed point precision [4,13]
195 FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixedPointFixture<int16_t>, framework::DatasetMode::ALL, combine(combine(data_qs16, framework::dataset::make("DataType", DataType::QS16)),
196 framework::dataset::make("FractionalBits", 4, 14)))
// Compare _target (wrapped in an Accessor) against _reference within tolerance_qs.
199 validate(Accessor(_target), _reference, tolerance_qs);
206 } // namespace validation
208 } // namespace arm_compute