arm_compute v18.02
[platform/upstream/armcl.git] / tests / validation / CL / PixelWiseMultiplication.cpp
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
25 #include "tests/CL/CLAccessor.h"
26 #include "tests/PaddingCalculator.h"
27 #include "tests/datasets/ConvertPolicyDataset.h"
28 #include "tests/datasets/ShapeDatasets.h"
29 #include "tests/framework/Macros.h"
30 #include "tests/validation/Validation.h"
31 #include "tests/validation/fixtures/FixedPointPixelWiseMultiplicationFixture.h"
32 #include "tests/validation/fixtures/PixelWiseMultiplicationFixture.h"
33
34 namespace arm_compute
35 {
36 namespace test
37 {
38 namespace validation
39 {
namespace
{
/** Scale factor that leaves the product unscaled. */
constexpr float scale_unity = 1.f;
/** Scale factor of 1/255. */
constexpr float scale_255 = 1.f / 255.f;

// *INDENT-OFF*
// clang-format off

/* VALIDATE(TOL_TYPE, ABS_TOL)
 *   Compares the CL target tensor against the reference using AbsoluteTolerance<TOL_TYPE>(ABS_TOL)
 *   and a final argument of 0.f (presumably the tolerated fraction of mismatching elements - confirm
 *   in tests/validation/Validation.h). The expansion already ends in ';', so it is a full statement.
 */
#define VALIDATE(TOL_TYPE, ABS_TOL) validate(CLAccessor(_target), _reference, AbsoluteTolerance<TOL_TYPE>(ABS_TOL), 0.f);

/* PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(NAME, FIXTURE_SUFFIX, DATASET_MODE, SHAPE_DATASET,
 *                                                  TYPE_A, TYPE_B, SCALE_VALUE, ROUNDING, VALIDATION_STATEMENT)
 *   Declares fixture test NAME on CLPixelWiseMultiplication<FIXTURE_SUFFIX>, run in DatasetMode DATASET_MODE,
 *   over the dataset obtained by nesting combine() on: datasets::SHAPE_DATASET, DataType1 = TYPE_A,
 *   DataType2 = TYPE_B, Scale = SCALE_VALUE, datasets::ConvertPolicies() and RoundingPolicy = ROUNDING.
 *   VALIDATION_STATEMENT becomes the test body (every use in this file passes a VALIDATE(...) expansion).
 */
#define PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(NAME, FIXTURE_SUFFIX, DATASET_MODE, SHAPE_DATASET, TYPE_A, TYPE_B, SCALE_VALUE, ROUNDING, VALIDATION_STATEMENT) \
    FIXTURE_DATA_TEST_CASE(NAME,                                                                              \
                           CLPixelWiseMultiplication##FIXTURE_SUFFIX,                                         \
                           framework::DatasetMode::DATASET_MODE,                                              \
                           combine(combine(combine(combine(combine(datasets::SHAPE_DATASET,                   \
                                   framework::dataset::make("DataType1", DataType::TYPE_A)),                  \
                                   framework::dataset::make("DataType2", DataType::TYPE_B)),                  \
                                   framework::dataset::make("Scale", std::move(SCALE_VALUE))),                \
                                   datasets::ConvertPolicies()),                                              \
                                   framework::dataset::make("RoundingPolicy", RoundingPolicy::ROUNDING)))     \
    {                                                                                                         \
        VALIDATION_STATEMENT                                                                                  \
    }

/* FP_PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(NAME, FIXTURE_SUFFIX, DATASET_MODE, SHAPE_DATASET,
 *                                                     TYPE_A, TYPE_B, SCALE_VALUE, ROUNDING,
 *                                                     FIXED_POINT_BEGIN, FIXED_POINT_END)
 *   Fixed-point counterpart of the macro above, built on CLFixedPointPixelWiseMultiplication<FIXTURE_SUFFIX>.
 *   The combined dataset additionally includes
 *   framework::dataset::make("FixedPointPosition", FIXED_POINT_BEGIN, FIXED_POINT_END)
 *   (presumably the half-open range [BEGIN, END) - confirm in the dataset framework).
 *   The body validates with the default tolerance.
 */
#define FP_PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(NAME, FIXTURE_SUFFIX, DATASET_MODE, SHAPE_DATASET, TYPE_A, TYPE_B, SCALE_VALUE, ROUNDING, FIXED_POINT_BEGIN, FIXED_POINT_END) \
    FIXTURE_DATA_TEST_CASE(NAME,                                                                              \
                           CLFixedPointPixelWiseMultiplication##FIXTURE_SUFFIX,                               \
                           framework::DatasetMode::DATASET_MODE,                                              \
                           combine(combine(combine(combine(combine(combine(datasets::SHAPE_DATASET,           \
                                   framework::dataset::make("DataType1", DataType::TYPE_A)),                  \
                                   framework::dataset::make("DataType2", DataType::TYPE_B)),                  \
                                   framework::dataset::make("Scale", std::move(SCALE_VALUE))),                \
                                   datasets::ConvertPolicies()),                                              \
                                   framework::dataset::make("RoundingPolicy", RoundingPolicy::ROUNDING)),     \
                                   framework::dataset::make("FixedPointPosition", FIXED_POINT_BEGIN, FIXED_POINT_END))) \
    {                                                                                                         \
        validate(CLAccessor(_target), _reference);                                                            \
    }
// clang-format on
// *INDENT-ON*
} // namespace
78
79 template <typename T>
80 using CLPixelWiseMultiplicationToF16Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, half_float::half>;
81 template <typename T>
82 using CLPixelWiseMultiplicationToF32Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>;
83 template <typename T>
84 using CLPixelWiseMultiplicationToQS8Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, qint8_t>;
85 template <typename T>
86 using CLPixelWiseMultiplicationToQS16Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, qint16_t>;
87 template <typename T>
88 using CLFixedPointPixelWiseMultiplicationFixture = FixedPointPixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T>;
89 template <typename T>
90 using CLPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>;
91
92 TEST_SUITE(CL)
93 TEST_SUITE(PixelWiseMultiplication)
94
95 // *INDENT-OFF*
96 // clang-format off
97 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
98                framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
99                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
100                                                         TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),      // Window shrink
101                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),      // Invalid scale
102                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),      // Invalid data type combination
103                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),     // Mismatching shapes
104                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),  // Mismatching data type
105                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),  // Mismatching fixed point
106                                                         TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),  // Invalid scale
107                                                       }),
108                framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
109                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
110                                                        TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
111                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
112                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
113                                                        TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
114                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS16, 2),
115                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
116                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
117                                                      })),
118                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
119                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
120                                                        TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
121                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
122                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
123                                                        TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
124                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
125                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
126                                                        TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
127                                                      })),
128                framework::dataset::make("Scale",{  2.f, 2.f, 2.f, -1.f, 1.f, 1.f, 1.f, 1.f, 3.f})),
129                framework::dataset::make("Expected", { true, true, false, false, false, false, false, false, false })),
130                input1_info, input2_info, output_info, scale, expected)
131 {
132     bool has_error = bool(CLPixelWiseMultiplication::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), scale, ConvertPolicy::WRAP, RoundingPolicy::TO_ZERO));
133     ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS);
134 }
135 // clang-format on
136 // *INDENT-ON*
137
138 TEST_SUITE(F16toF16)
139
140 TEST_SUITE(Scale255)
141 PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF16Fixture<half_float::half>, PRECOMMIT, SmallShapes(), F16, F16, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
142 TEST_SUITE_END() // Scale255
143
144 TEST_SUITE_END() // F16toF16
145
146 TEST_SUITE(F32toF32)
147
148 TEST_SUITE(Scale255)
149 PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
150 TEST_SUITE_END() // Scale255
151
152 TEST_SUITE_END() // F32toF32
153
154 TEST_SUITE_END() // PixelWiseMultiplication
155
156 TEST_SUITE(FixedPointPixelWiseMultiplication)
157
158 TEST_SUITE(QS8)
159
160 TEST_SUITE(ScaleUnity)
161 FP_PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunTiny, Fixture<qint8_t>, PRECOMMIT, TinyShapes(), QS8, QS8, scale_unity, TO_ZERO, 1, 7)
162 TEST_SUITE_END() // ScaleUnity
163
164 TEST_SUITE_END() // QS8
165
166 TEST_SUITE(QS16)
167
168 TEST_SUITE(ScaleUnity)
169 FP_PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunTiny, Fixture<qint16_t>, PRECOMMIT, TinyShapes(), QS16, QS16, scale_unity, TO_ZERO, 1, 15)
170 TEST_SUITE_END() // ScaleUnity
171
172 TEST_SUITE_END() // QS16
173
174 TEST_SUITE(Broadcast)
175 PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, BroadcastFixture<float>, PRECOMMIT, SmallShapesBroadcast(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
176 TEST_SUITE_END() // Broadcast
177
178 TEST_SUITE_END() // FixedPointPixelWiseMultiplication
179 TEST_SUITE_END()
180 } // namespace validation
181 } // namespace test
182 } // namespace arm_compute