arm_compute v17.03.1
src/core/NEON/kernels/NEConvolutionKernel.cpp
/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEConvolutionKernel.h"

#include "arm_compute/core/Coordinates.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"

#include <algorithm>
#include <arm_neon.h>
#include <array>
#include <cstdint>
#include <cstring>
#include <tuple>

using namespace arm_compute;

namespace
{
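// Upper bound used to clamp unsigned intermediate results before they are
// reinterpreted and stored as signed 16-bit output.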
const uint16x8_t max_int16 = vdupq_n_u16(INT16_MAX);

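// The store_results() overloads below saturate and narrow a pair of wide
// accumulator vectors to the output type (U8 or S16) and store them
// contiguously; unsigned values headed for S16 are clamped with max_int16.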
inline void store_results(const int32x4_t &out, const int32x4_t &out2, int16_t *output)
{
    const int16x8_t s16results = vcombine_s16(vqmovn_s32(out),
                                              vqmovn_s32(out2));
    vst1q_s16(output, s16results);
}

inline void store_results(const int32x4_t &out, const int32x4_t &out2, uint8_t *output)
{
    const uint8x8_t u8results = vqmovn_u16(vcombine_u16(vqmovun_s32(out),
                                                        vqmovun_s32(out2)));
    vst1_u8(output, u8results);
}

inline void store_results(const uint32x4_t &out, const uint32x4_t &out2, int16_t *output)
{
    const uint16x8_t u16results = vcombine_u16(vqmovn_u32(out), vqmovn_u32(out2));
    const int16x8_t  s16results = vreinterpretq_s16_u16(vminq_u16(u16results, max_int16));
    vst1q_s16(output, s16results);
}

inline void store_results(const uint32x4_t &out, const uint32x4_t &out2, uint8_t *output)
{
    const uint8x8_t u8results = vqmovn_u16(vcombine_u16(vqmovn_u32(out),
                                                        vqmovn_u32(out2)));
    vst1_u8(output, u8results);
}

inline void store_results(const int16x8_t &out, const int16x8_t &out2, int16_t *output)
{
    vst1q_s16(output, out);
    vst1q_s16(output + 8, out2);
}

inline void store_results(const int16x8_t &out, const int16x8_t &out2, uint8_t *output)
{
    const uint8x16_t u8results = vcombine_u8(vqmovun_s16(out),
                                             vqmovun_s16(out2));
    vst1q_u8(output, u8results);
}

inline void store_results(const uint16x8_t &out, const uint16x8_t &out2, uint8_t *output)
{
    const uint8x16_t u8results = vcombine_u8(vqmovn_u16(out),
                                             vqmovn_u16(out2));
    vst1q_u8(output, u8results);
}

inline void store_results(const uint16x8_t &out, const uint16x8_t &out2, int16_t *output)
{
    vst1q_s16(output, vreinterpretq_s16_u16(vminq_u16(out, max_int16)));
    vst1q_s16(output + 8, vreinterpretq_s16_u16(vminq_u16(out2, max_int16)));
}

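// Each convolve_rowNx1() helper multiply-accumulates one kernel row into two
// 4-wide S32 accumulators (pixels [0,3] in 'out', pixels [4,7] in 'out2'). The
// 16 input bytes are widened to S16 and vext_s16 forms the shifted views of the
// row that each coefficient multiplies.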
inline void convolve_row3x1_unrolled(int32x4_t &out, int32x4_t &out2, const uint8x16_t &row_data, const int16x4_t &mat0, const int16x4_t &mat1, const int16x4_t &mat2)
{
    // Convert to s16 and split in blocks of 4 values:
    const int16x8_t s16_tmp0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(row_data)));
    const int16x8_t s16_tmp1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(row_data)));

    const int16x4x3_t row =
    {
        {
            vget_low_s16(s16_tmp0),
            vget_high_s16(s16_tmp0),
            vget_low_s16(s16_tmp1)
        }
    };

    // Calculate row left value for pixels [0,3]
    out = vmlal_s16(out, row.val[0], mat0);
    // Calculate row middle value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 1), mat1);
    // Calculate row right value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 2), mat2);

    // Calculate row left value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[1], mat0);
    // Calculate row middle value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 1), mat1);
    // Calculate row right value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 2), mat2);
}

inline void convolve_row3x1(int32x4_t &out, int32x4_t &out2, const uint8x16_t &row_data, const int16_t *convolution)
{
    const int16x4_t mat0 = vld1_dup_s16(convolution);
    const int16x4_t mat1 = vld1_dup_s16(convolution + 1);
    const int16x4_t mat2 = vld1_dup_s16(convolution + 2);

    convolve_row3x1_unrolled(out, out2, row_data, mat0, mat1, mat2);
}

inline void convolve_row5x1(int32x4_t &out, int32x4_t &out2, const uint8x16_t &row_data, const int16_t *convolution)
{
    const int16x4_t mat0 = vld1_dup_s16(convolution);
    const int16x4_t mat1 = vld1_dup_s16(convolution + 1);
    const int16x4_t mat2 = vld1_dup_s16(convolution + 2);
    const int16x4_t mat3 = vld1_dup_s16(convolution + 3);
    const int16x4_t mat4 = vld1_dup_s16(convolution + 4);

    // Convert to s16 and split in blocks of 4 values:
    const int16x8_t s16_tmp0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(row_data)));
    const int16x8_t s16_tmp1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(row_data)));

    const int16x4x3_t row =
    {
        {
            vget_low_s16(s16_tmp0),
            vget_high_s16(s16_tmp0),
            vget_low_s16(s16_tmp1)
        }
    };

    // Calculate row left 2 value for pixels [0,3]
    out = vmlal_s16(out, row.val[0], mat0);
    // Calculate row left 1 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 1), mat1);
    // Calculate row middle value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 2), mat2);
    // Calculate row right +1 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 3), mat3);
    // Calculate row right +2 value for pixels [0,3]
    out = vmlal_s16(out, row.val[1], mat4);

    // Calculate row left 2 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[1], mat0);
    // Calculate row left 1 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 1), mat1);
    // Calculate row middle value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 2), mat2);
    // Calculate row right +1 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 3), mat3);
    // Calculate row right +2 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[2], mat4);
}

inline void convolve_row7x1(int32x4_t &out, int32x4_t &out2, const uint8x16_t &row_data, const int16_t *convolution)
{
    const int16x4_t mat0 = vld1_dup_s16(convolution);
    const int16x4_t mat1 = vld1_dup_s16(convolution + 1);
    const int16x4_t mat2 = vld1_dup_s16(convolution + 2);
    const int16x4_t mat3 = vld1_dup_s16(convolution + 3);
    const int16x4_t mat4 = vld1_dup_s16(convolution + 4);
    const int16x4_t mat5 = vld1_dup_s16(convolution + 5);
    const int16x4_t mat6 = vld1_dup_s16(convolution + 6);

    // Convert to s16 and split in blocks of 4 values:
    const int16x8_t s16_tmp0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(row_data)));
    const int16x8_t s16_tmp1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(row_data)));

    const int16x4x4_t row =
    {
        {
            vget_low_s16(s16_tmp0),
            vget_high_s16(s16_tmp0),
            vget_low_s16(s16_tmp1),
            vget_high_s16(s16_tmp1)
        }
    };

    // Calculate row left 3 value for pixels [0,3]
    out = vmlal_s16(out, row.val[0], mat0);
    // Calculate row left 2 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 1), mat1);
    // Calculate row left 1 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 2), mat2);
    // Calculate row middle value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 3), mat3);
    // Calculate row right +1 value for pixels [0,3]
    out = vmlal_s16(out, row.val[1], mat4);
    // Calculate row right +2 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[1], row.val[2], 1), mat5);
    // Calculate row right +3 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[1], row.val[2], 2), mat6);

    // Calculate row left 3 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[1], mat0);
    // Calculate row left 2 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 1), mat1);
    // Calculate row left 1 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 2), mat2);
    // Calculate row middle value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 3), mat3);
    // Calculate row right +1 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[2], mat4);
    // Calculate row right +2 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[2], row.val[3], 1), mat5);
    // Calculate row right +3 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[2], row.val[3], 2), mat6);
}

inline void convolve_row9x1(int32x4_t &out, int32x4_t &out2, const uint8x16_t &row_data, const int16_t *convolution)
{
    const int16x4_t mat0 = vld1_dup_s16(convolution);
    const int16x4_t mat1 = vld1_dup_s16(convolution + 1);
    const int16x4_t mat2 = vld1_dup_s16(convolution + 2);
    const int16x4_t mat3 = vld1_dup_s16(convolution + 3);
    const int16x4_t mat4 = vld1_dup_s16(convolution + 4);
    const int16x4_t mat5 = vld1_dup_s16(convolution + 5);
    const int16x4_t mat6 = vld1_dup_s16(convolution + 6);
    const int16x4_t mat7 = vld1_dup_s16(convolution + 7);
    const int16x4_t mat8 = vld1_dup_s16(convolution + 8);

    // Convert to s16 and split in blocks of 4 values:
    const int16x8_t s16_tmp0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(row_data)));
    const int16x8_t s16_tmp1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(row_data)));

    const int16x4x4_t row =
    {
        {
            vget_low_s16(s16_tmp0),
            vget_high_s16(s16_tmp0),
            vget_low_s16(s16_tmp1),
            vget_high_s16(s16_tmp1)
        }
    };

    // Calculate row left 4 value for pixels [0,3]
    out = vmlal_s16(out, row.val[0], mat0);
    // Calculate row left 3 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 1), mat1);
    // Calculate row left 2 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 2), mat2);
    // Calculate row left 1 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[0], row.val[1], 3), mat3);
    // Calculate row middle value for pixels [0,3]
    out = vmlal_s16(out, row.val[1], mat4);
    // Calculate row right +1 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[1], row.val[2], 1), mat5);
    // Calculate row right +2 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[1], row.val[2], 2), mat6);
    // Calculate row right +3 value for pixels [0,3]
    out = vmlal_s16(out, vext_s16(row.val[1], row.val[2], 3), mat7);
    // Calculate row right +4 value for pixels [0,3]
    out = vmlal_s16(out, row.val[2], mat8);

    // Calculate row left 4 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[1], mat0);
    // Calculate row left 3 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 1), mat1);
    // Calculate row left 2 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 2), mat2);
    // Calculate row left 1 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[1], row.val[2], 3), mat3);
    // Calculate row middle value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[2], mat4);
    // Calculate row right +1 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[2], row.val[3], 1), mat5);
    // Calculate row right +2 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[2], row.val[3], 2), mat6);
    // Calculate row right +3 value for pixels [4,7]
    out2 = vmlal_s16(out2, vext_s16(row.val[2], row.val[3], 3), mat7);
    // Calculate row right +4 value for pixels [4,7]
    out2 = vmlal_s16(out2, row.val[3], mat8);
}
} // namespace

/****************************************************************************************\
 *                                    Square Convolution                                *
\****************************************************************************************/

template <unsigned int matrix_size>
NEConvolutionKernel<matrix_size>::NEConvolutionKernel()
    : INESimpleKernel(), _scale(0), _convolution{ {} }
{
}

template <unsigned int matrix_size>
BorderSize             NEConvolutionKernel<matrix_size>::border_size() const
{
    return BorderSize(matrix_size / 2);
}

template <unsigned int matrix_size>
void NEConvolutionKernel<matrix_size>::configure(const ITensor *input, ITensor *output, const int16_t *conv, uint32_t scale, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16);
    ARM_COMPUTE_ERROR_ON(conv == nullptr);

    _input  = input;
    _output = output;

    std::copy_n(conv, _convolution.size(), _convolution.begin());

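    // A scale of 0 requests automatic normalisation: derive the scale from the
    // matrix coefficients rather than using a caller-provided value.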
    if(scale == 0)
    {
        _scale = calculate_matrix_scale(_convolution.data(), matrix_size);
    }
    else
    {
        _scale = scale;
    }

    // Configure kernel window
    constexpr unsigned int processed_elements(8);
    constexpr unsigned int read_elements(16);
    constexpr unsigned int written_elements(8);

    Window                 win = calculate_max_window(*input->info(), Steps(processed_elements), border_undefined, border_size());
    AccessWindowHorizontal output_access(output->info(), 0, written_elements);

    update_window_and_padding(win,
                              AccessWindowRectangle(input->info(), -border_size().left, -border_size().top, read_elements, matrix_size),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());

    INEKernel::configure(win);
}

#ifndef DOXYGEN_SKIP_THIS /* Doxygen gets confused by the templates and can't match the implementation to the declaration */
template <>
template <typename OutputType>
void NEConvolutionKernel<3>::convolution(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    Iterator input(_input, win);
    Iterator output(_output, win);

    // Load the matrix's coefficients into NEON registers:
    const int16x4_t   mat00     = vld1_dup_s16(_convolution.data());
    const int16x4_t   mat01     = vld1_dup_s16(_convolution.data() + 1);
    const int16x4_t   mat02     = vld1_dup_s16(_convolution.data() + 2);
    const int16x4_t   mat10     = vld1_dup_s16(_convolution.data() + 3);
    const int16x4_t   mat11     = vld1_dup_s16(_convolution.data() + 4);
    const int16x4_t   mat12     = vld1_dup_s16(_convolution.data() + 5);
    const int16x4_t   mat20     = vld1_dup_s16(_convolution.data() + 6);
    const int16x4_t   mat21     = vld1_dup_s16(_convolution.data() + 7);
    const int16x4_t   mat22     = vld1_dup_s16(_convolution.data() + 8);
    const float32x4_t scale_val = vdupq_n_f32(1.0f / _scale);

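    // Base pointers to the three input rows used by the 3x3 convolution; the x
    // offset of -1 also covers the left neighbour of the first pixel.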
    const unsigned char *input_top_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-1, -1));
    const unsigned char *input_mid_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-1, 0));
    const unsigned char *input_low_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-1, 1));

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4_t out  = vdupq_n_s32(0);
        int32x4_t out2 = vdupq_n_s32(0);

        // Load 16 bytes from the top row:
        const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
        convolve_row3x1_unrolled(out, out2, top_data, mat00, mat01, mat02);

        // Load 16 bytes from the middle row:
        const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
        convolve_row3x1_unrolled(out, out2, mid_data, mat10, mat11, mat12);

        // Load 16 bytes from the low row:
        const uint8x16_t low_data = vld1q_u8(input_low_ptr + input.offset());
        convolve_row3x1_unrolled(out, out2, low_data, mat20, mat21, mat22);

        // Apply scale
        if(_scale != 1)
        {
            // Convert to F32, scale and convert back to S32
            out  = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out), scale_val));
            out2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out2), scale_val));
        }

        // Clamp and store as U8 or S16:
        store_results(out, out2, reinterpret_cast<OutputType *>(output.ptr()));
    },
    input, output);
}

template <>
template <typename OutputType>
void NEConvolutionKernel<5>::convolution(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    Iterator input(_input, win);
    Iterator output(_output, win);

    const float32x4_t scale_val = vdupq_n_f32(1.0f / _scale);

    const unsigned char *input_top2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-2, -2));
    const unsigned char *input_top1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-2, -1));
    const unsigned char *input_mid_ptr  = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-2, 0));
    const unsigned char *input_low1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-2, 1));
    const unsigned char *input_low2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-2, 2));

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4_t out  = vdupq_n_s32(0);
        int32x4_t out2 = vdupq_n_s32(0);

        // Load 16 bytes from the top2 row:
        const uint8x16_t data_t2 = vld1q_u8(input_top2_ptr + input.offset());
        convolve_row5x1(out, out2, data_t2, _convolution.data());

        // Load 16 bytes from the top1 row:
        const uint8x16_t data_t1 = vld1q_u8(input_top1_ptr + input.offset());
        convolve_row5x1(out, out2, data_t1, _convolution.data() + 5);

        // Load 16 bytes from the middle row:
        const uint8x16_t data_m = vld1q_u8(input_mid_ptr + input.offset());
        convolve_row5x1(out, out2, data_m, _convolution.data() + 10);

        // Load 16 bytes from the low1 row:
        const uint8x16_t data_b1 = vld1q_u8(input_low1_ptr + input.offset());
        convolve_row5x1(out, out2, data_b1, _convolution.data() + 15);

        // Load 16 bytes from the low2 row:
        const uint8x16_t data_b2 = vld1q_u8(input_low2_ptr + input.offset());
        convolve_row5x1(out, out2, data_b2, _convolution.data() + 20);

        // Apply scale
        if(_scale != 1)
        {
            // Convert to F32, scale and convert back to S32
            out  = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out), scale_val));
            out2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out2), scale_val));
        }

        // Clamp and store as U8 or S16:
        store_results(out, out2, reinterpret_cast<OutputType *>(output.ptr()));
    },
    input, output);
}

template <>
template <typename OutputType>
void NEConvolutionKernel<7>::convolution(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    Iterator input(_input, win);
    Iterator output(_output, win);

    const float32x4_t scale_val = vdupq_n_f32(1.0f / _scale);

    const unsigned char *input_top3_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, -3));
    const unsigned char *input_top2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, -2));
    const unsigned char *input_top1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, -1));
    const unsigned char *input_mid_ptr  = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, 0));
    const unsigned char *input_low1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, 1));
    const unsigned char *input_low2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, 2));
    const unsigned char *input_low3_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-3, 3));

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4_t out  = vdupq_n_s32(0);
        int32x4_t out2 = vdupq_n_s32(0);

        // Load 16 bytes from the top3 row:
        const uint8x16_t data_t3 = vld1q_u8(input_top3_ptr + input.offset());
        convolve_row7x1(out, out2, data_t3, _convolution.data());

        // Load 16 bytes from the top2 row:
        const uint8x16_t data_t2 = vld1q_u8(input_top2_ptr + input.offset());
        convolve_row7x1(out, out2, data_t2, _convolution.data() + 7);

        // Load 16 bytes from the top1 row:
        const uint8x16_t data_t1 = vld1q_u8(input_top1_ptr + input.offset());
        convolve_row7x1(out, out2, data_t1, _convolution.data() + 14);

        // Load 16 bytes from the middle row:
        const uint8x16_t data_m = vld1q_u8(input_mid_ptr + input.offset());
        convolve_row7x1(out, out2, data_m, _convolution.data() + 21);

        // Load 16 bytes from the low1 row:
        const uint8x16_t data_b1 = vld1q_u8(input_low1_ptr + input.offset());
        convolve_row7x1(out, out2, data_b1, _convolution.data() + 28);

        // Load 16 bytes from the low2 row:
        const uint8x16_t data_b2 = vld1q_u8(input_low2_ptr + input.offset());
        convolve_row7x1(out, out2, data_b2, _convolution.data() + 35);

        // Load 16 bytes from the low3 row:
        const uint8x16_t data_b3 = vld1q_u8(input_low3_ptr + input.offset());
        convolve_row7x1(out, out2, data_b3, _convolution.data() + 42);

        // Apply scale
        if(_scale != 1)
        {
            // Convert to F32, scale and convert back to S32
            out  = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out), scale_val));
            out2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out2), scale_val));
        }

        // Clamp and store as U8 or S16:
        store_results(out, out2, reinterpret_cast<OutputType *>(output.ptr()));
    },
    input, output);
}

template <>
template <typename OutputType>
void NEConvolutionKernel<9>::convolution(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    Iterator input(_input, win);
    Iterator output(_output, win);

    const float32x4_t scale_val = vdupq_n_f32(1.0f / _scale);

    const unsigned char *input_top4_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, -4));
    const unsigned char *input_top3_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, -3));
    const unsigned char *input_top2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, -2));
    const unsigned char *input_top1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, -1));
    const unsigned char *input_mid_ptr  = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, 0));
    const unsigned char *input_low1_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, 1));
    const unsigned char *input_low2_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, 2));
    const unsigned char *input_low3_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, 3));
    const unsigned char *input_low4_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-4, 4));

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4_t out  = vdupq_n_s32(0);
        int32x4_t out2 = vdupq_n_s32(0);

        // Load 16 bytes from the top4 row:
        const uint8x16_t data_t4 = vld1q_u8(input_top4_ptr + input.offset());
        convolve_row9x1(out, out2, data_t4, _convolution.data());

        // Load 16 bytes from the top3 row:
        const uint8x16_t data_t3 = vld1q_u8(input_top3_ptr + input.offset());
        convolve_row9x1(out, out2, data_t3, _convolution.data() + 9);

        // Load 16 bytes from the top2 row:
        const uint8x16_t data_t2 = vld1q_u8(input_top2_ptr + input.offset());
        convolve_row9x1(out, out2, data_t2, _convolution.data() + 18);

        // Load 16 bytes from the top1 row:
        const uint8x16_t data_t1 = vld1q_u8(input_top1_ptr + input.offset());
        convolve_row9x1(out, out2, data_t1, _convolution.data() + 27);

        // Load 16 bytes from the middle row:
        const uint8x16_t data_m = vld1q_u8(input_mid_ptr + input.offset());
        convolve_row9x1(out, out2, data_m, _convolution.data() + 36);

        // Load 16 bytes from the low1 row:
        const uint8x16_t data_b1 = vld1q_u8(input_low1_ptr + input.offset());
        convolve_row9x1(out, out2, data_b1, _convolution.data() + 45);

        // Load 16 bytes from the low2 row:
        const uint8x16_t data_b2 = vld1q_u8(input_low2_ptr + input.offset());
        convolve_row9x1(out, out2, data_b2, _convolution.data() + 54);

        // Load 16 bytes from the low3 row:
        const uint8x16_t data_b3 = vld1q_u8(input_low3_ptr + input.offset());
        convolve_row9x1(out, out2, data_b3, _convolution.data() + 63);

        // Load 16 bytes from the low4 row:
        const uint8x16_t data_b4 = vld1q_u8(input_low4_ptr + input.offset());
        convolve_row9x1(out, out2, data_b4, _convolution.data() + 72);

        // Apply scale
        if(_scale != 1)
        {
            // Convert to F32, scale and convert back to S32
            out  = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out), scale_val));
            out2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out2), scale_val));
        }

        // Clamp and store as U8 or S16:
        store_results(out, out2, reinterpret_cast<OutputType *>(output.ptr()));
    },
    input, output);
}
#endif /* DOXYGEN_SKIP_THIS */

template <unsigned int matrix_size>
void NEConvolutionKernel<matrix_size>::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    switch(_output->info()->format())
    {
        case Format::U8:
            convolution<uint8_t>(window);
            break;
        case Format::S16:
            convolution<int16_t>(window);
            break;
        default:
            ARM_COMPUTE_ERROR("Not supported");
    }
}

template class arm_compute::NEConvolutionKernel<3>;
template class arm_compute::NEConvolutionKernel<5>;
template class arm_compute::NEConvolutionKernel<7>;
template class arm_compute::NEConvolutionKernel<9>;
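// Usage sketch (illustrative only, not part of this file): a caller configures a
// kernel once, then runs it over its window. The tensors, coefficients and the
// border flag below are placeholders.
//
//   const int16_t gaussian3x3[] = { 1, 2, 1, 2, 4, 2, 1, 2, 1 };
//   NEConvolutionKernel<3> conv3x3;
//   conv3x3.configure(&src, &dst, gaussian3x3, 0 /* scale 0 = auto */, false);
//   conv3x3.run(conv3x3.window());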

/****************************************************************************************\
 *                              Separable Square Convolution                            *
\****************************************************************************************/

template <unsigned int matrix_size>
NESeparableConvolutionHorKernel<matrix_size>::NESeparableConvolutionHorKernel()
    : _conv_row{ { 0 } }, _border_size(0)
{
}

template <unsigned int matrix_size>
BorderSize             NESeparableConvolutionHorKernel<matrix_size>::border_size() const
{
    return _border_size;
}

template <unsigned int matrix_size>
void NESeparableConvolutionHorKernel<matrix_size>::configure(const ITensor *input, ITensor *output, const int16_t *conv_row, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U16, DataType::S16, DataType::S32);
    ARM_COMPUTE_ERROR_ON(conv_row == nullptr);

    _input  = input;
    _output = output;
    std::copy_n(conv_row, _conv_row.size(), _conv_row.begin());
    _border_size = BorderSize(border_undefined ? 0 : matrix_size / 2, matrix_size / 2);

    // Configure kernel window
    constexpr unsigned int processed_elements(8);
    constexpr unsigned int read_elements(16);
    constexpr unsigned int written_elements(8);

    Window                 win = calculate_max_window_horizontal(*input->info(), Steps(processed_elements), border_undefined, border_size());
    AccessWindowHorizontal output_access(output->info(), 0, written_elements);

    update_window_and_padding(win,
                              AccessWindowHorizontal(input->info(), -border_size().left, read_elements),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());

    INEKernel::configure(win);
}

template <unsigned int matrix_size>
void NESeparableConvolutionHorKernel<matrix_size>::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    switch(_output->info()->data_type())
    {
        case DataType::U16:
            convolve<uint16_t>(window);
            break;
        case DataType::S16:
            convolve<int16_t>(window);
            break;
        case DataType::S32:
            convolve<int32_t>(window);
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported intermediate data type!");
            break;
    }
}

#ifndef DOXYGEN_SKIP_THIS /* Doxygen gets confused by the templates and can't match the implementation to the declaration */
namespace arm_compute
{
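// The horizontal pass widens the U8 input into the intermediate type requested
// by the output tensor: the U16/S16 specializations accumulate in 16 bits, while
// the S32 ones use widening multiplies to accumulate in 32 bits.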
template <>
template <>
inline void NESeparableConvolutionHorKernel<5>::convolve<uint16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -2);
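    // Shift the input window left by the kernel radius (2 for a 5-tap row) so
    // the 16-byte load also covers the left neighbours of the first output pixel.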

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const uint16x8x2_t data_u16 =
        {
            {
                vmovl_u8(vget_low_u8(data)),
                vmovl_u8(vget_high_u8(data))
            }
        };

        uint16x8_t out = vmulq_n_u16(data_u16.val[0], _conv_row[0]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 1), _conv_row[1]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 2), _conv_row[2]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 3), _conv_row[3]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 4), _conv_row[4]);

        vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
inline void NESeparableConvolutionHorKernel<5>::convolve<int16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -2);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        int16x8_t out = vmulq_n_s16(data_s16.val[0], _conv_row[0]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 1), _conv_row[1]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 2), _conv_row[2]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 3), _conv_row[3]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 4), _conv_row[4]);

        vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
void NESeparableConvolutionHorKernel<5>::convolve<int32_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -2);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        const int16x8_t data_s16_l1 = vextq_s16(data_s16.val[0], data_s16.val[1], 1);
        const int16x8_t data_s16_m  = vextq_s16(data_s16.val[0], data_s16.val[1], 2);
        const int16x8_t data_s16_r1 = vextq_s16(data_s16.val[0], data_s16.val[1], 3);
        const int16x8_t data_s16_r2 = vextq_s16(data_s16.val[0], data_s16.val[1], 4);

        int32x4_t out_low = vmull_n_s16(vget_low_s16(data_s16.val[0]), _conv_row[0]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l1), _conv_row[1]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_m), _conv_row[2]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r1), _conv_row[3]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r2), _conv_row[4]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), out_low);

        int32x4_t out_high = vmull_n_s16(vget_high_s16(data_s16.val[0]), _conv_row[0]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l1), _conv_row[1]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_m), _conv_row[2]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r1), _conv_row[3]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r2), _conv_row[4]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, out_high);
    },
    input, output);
}

template <>
template <>
inline void NESeparableConvolutionHorKernel<7>::convolve<uint16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -3);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const uint16x8x2_t data_u16 =
        {
            {
                vmovl_u8(vget_low_u8(data)),
                vmovl_u8(vget_high_u8(data))
            }
        };

        uint16x8_t out = vmulq_n_u16(data_u16.val[0], _conv_row[0]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 1), _conv_row[1]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 2), _conv_row[2]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 3), _conv_row[3]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 4), _conv_row[4]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 5), _conv_row[5]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 6), _conv_row[6]);

        vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
inline void NESeparableConvolutionHorKernel<7>::convolve<int16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -3);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        int16x8_t out = vmulq_n_s16(data_s16.val[0], _conv_row[0]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 1), _conv_row[1]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 2), _conv_row[2]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 3), _conv_row[3]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 4), _conv_row[4]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 5), _conv_row[5]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 6), _conv_row[6]);

        vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
void NESeparableConvolutionHorKernel<7>::convolve<int32_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -3);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        const int16x8_t data_s16_l2 = vextq_s16(data_s16.val[0], data_s16.val[1], 1);
        const int16x8_t data_s16_l1 = vextq_s16(data_s16.val[0], data_s16.val[1], 2);
        const int16x8_t data_s16_m  = vextq_s16(data_s16.val[0], data_s16.val[1], 3);
        const int16x8_t data_s16_r1 = vextq_s16(data_s16.val[0], data_s16.val[1], 4);
        const int16x8_t data_s16_r2 = vextq_s16(data_s16.val[0], data_s16.val[1], 5);
        const int16x8_t data_s16_r3 = vextq_s16(data_s16.val[0], data_s16.val[1], 6);

        int32x4_t out_low = vmull_n_s16(vget_low_s16(data_s16.val[0]), _conv_row[0]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l2), _conv_row[1]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l1), _conv_row[2]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_m), _conv_row[3]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r1), _conv_row[4]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r2), _conv_row[5]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r3), _conv_row[6]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), out_low);

        int32x4_t out_high = vmull_n_s16(vget_high_s16(data_s16.val[0]), _conv_row[0]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l2), _conv_row[1]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l1), _conv_row[2]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_m), _conv_row[3]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r1), _conv_row[4]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r2), _conv_row[5]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r3), _conv_row[6]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, out_high);
    },
    input, output);
}

template <>
template <>
inline void NESeparableConvolutionHorKernel<9>::convolve<uint16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -4);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const uint16x8x2_t data_u16 =
        {
            {
                vmovl_u8(vget_low_u8(data)),
                vmovl_u8(vget_high_u8(data))
            }
        };

        uint16x8_t out = vmulq_n_u16(data_u16.val[0], _conv_row[0]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 1), _conv_row[1]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 2), _conv_row[2]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 3), _conv_row[3]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 4), _conv_row[4]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 5), _conv_row[5]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 6), _conv_row[6]);
        out            = vmlaq_n_u16(out, vextq_u16(data_u16.val[0], data_u16.val[1], 7), _conv_row[7]);
        out            = vmlaq_n_u16(out, data_u16.val[1], _conv_row[8]);

        vst1q_u16(reinterpret_cast<uint16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
inline void NESeparableConvolutionHorKernel<9>::convolve<int16_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -4);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        int16x8_t out = vmulq_n_s16(data_s16.val[0], _conv_row[0]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 1), _conv_row[1]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 2), _conv_row[2]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 3), _conv_row[3]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 4), _conv_row[4]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 5), _conv_row[5]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 6), _conv_row[6]);
        out           = vmlaq_n_s16(out, vextq_s16(data_s16.val[0], data_s16.val[1], 7), _conv_row[7]);
        out           = vmlaq_n_s16(out, data_s16.val[1], _conv_row[8]);

        vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), out);
    },
    input, output);
}

template <>
template <>
void NESeparableConvolutionHorKernel<9>::convolve<int32_t>(const Window &window)
{
    Window win_in(window);
    win_in.shift(Window::DimX, -4);

    Iterator input(_input, win_in);
    Iterator output(_output, window);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8x16_t data = vld1q_u8(input.ptr());

        const int16x8x2_t data_s16 =
        {
            {
                vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(data))),
                vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(data)))
            }
        };

        const int16x8_t data_s16_l3 = vextq_s16(data_s16.val[0], data_s16.val[1], 1);
        const int16x8_t data_s16_l2 = vextq_s16(data_s16.val[0], data_s16.val[1], 2);
        const int16x8_t data_s16_l1 = vextq_s16(data_s16.val[0], data_s16.val[1], 3);
        const int16x8_t data_s16_m  = vextq_s16(data_s16.val[0], data_s16.val[1], 4);
        const int16x8_t data_s16_r1 = vextq_s16(data_s16.val[0], data_s16.val[1], 5);
        const int16x8_t data_s16_r2 = vextq_s16(data_s16.val[0], data_s16.val[1], 6);
        const int16x8_t data_s16_r3 = vextq_s16(data_s16.val[0], data_s16.val[1], 7);

        int32x4_t out_low = vmull_n_s16(vget_low_s16(data_s16.val[0]), _conv_row[0]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l3), _conv_row[1]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l2), _conv_row[2]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_l1), _conv_row[3]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_m), _conv_row[4]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r1), _conv_row[5]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r2), _conv_row[6]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16_r3), _conv_row[7]);
        out_low           = vmlal_n_s16(out_low, vget_low_s16(data_s16.val[1]), _conv_row[8]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()), out_low);

        int32x4_t out_high = vmull_n_s16(vget_high_s16(data_s16.val[0]), _conv_row[0]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l3), _conv_row[1]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l2), _conv_row[2]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_l1), _conv_row[3]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_m), _conv_row[4]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r1), _conv_row[5]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r2), _conv_row[6]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16_r3), _conv_row[7]);
        out_high           = vmlal_n_s16(out_high, vget_high_s16(data_s16.val[1]), _conv_row[8]);

        vst1q_s32(reinterpret_cast<int32_t *>(output.ptr()) + 4, out_high);
    },
    input, output);
}
} // namespace arm_compute
#endif

template class arm_compute::NESeparableConvolutionHorKernel<5>;
template class arm_compute::NESeparableConvolutionHorKernel<7>;
template class arm_compute::NESeparableConvolutionHorKernel<9>;

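// The vertical pass consumes the widened intermediate produced by the horizontal
// kernel, accumulates down each column and applies the final scale.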
template <unsigned int matrix_size>
NESeparableConvolutionVertKernel<matrix_size>::NESeparableConvolutionVertKernel()
    : _conv_col{ { 0 } }, _scale(0)
{
}

template <unsigned int matrix_size>
BorderSize             NESeparableConvolutionVertKernel<matrix_size>::border_size() const
{
    return BorderSize(matrix_size / 2, 0);
}

template <unsigned int matrix_size>
void NESeparableConvolutionVertKernel<matrix_size>::configure(const ITensor *input, ITensor *output, const int16_t *conv_col, uint32_t scale, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U16, DataType::S16, DataType::S32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16);
    ARM_COMPUTE_ERROR_ON(conv_col == nullptr);
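    // Unlike NEConvolutionKernel::configure(), no automatic scale is derived
    // here: the caller must pass a non-zero scale.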
    ARM_COMPUTE_ERROR_ON(scale == 0);

    _input  = input;
    _output = output;
    std::copy_n(conv_col, _conv_col.size(), _conv_col.begin());
    _scale = scale;

    // Configure kernel window
    constexpr unsigned int processed_elements(16);
    constexpr unsigned int read_elements(16);
    constexpr unsigned int written_elements(16);

    Window                 win = calculate_max_window(*input->info(), Steps(processed_elements), border_undefined, border_size());
    AccessWindowHorizontal output_access(output->info(), 0, written_elements);

    update_window_and_padding(win,
                              AccessWindowRectangle(input->info(), 0, -border_size().top, read_elements, matrix_size),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());

    INEKernel::configure(win);
}

template <unsigned int matrix_size>
void NESeparableConvolutionVertKernel<matrix_size>::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    switch(_input->info()->data_type())
    {
        case DataType::U16:
            switch(_output->info()->data_type())
            {
                case DataType::U8:
                    convolution_u16<uint8_t>(window);
                    break;
                case DataType::S16:
                    convolution_u16<int16_t>(window);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
            break;
        case DataType::S16:
            switch(_output->info()->data_type())
            {
                case DataType::U8:
                    convolution_s16<uint8_t>(window);
                    break;
                case DataType::S16:
                    convolution_s16<int16_t>(window);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
            break;
        case DataType::S32:
            switch(_output->info()->data_type())
            {
                case DataType::U8:
                    convolution_s32<uint8_t>(window);
                    break;
                case DataType::S16:
                    convolution_s32<int16_t>(window);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported intermediate data type!");
            break;
    }
}

template <unsigned int matrix_size>
template <typename OutputType>
void NESeparableConvolutionVertKernel<matrix_size>::convolution_u16(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");

    Window win_in(win);
    win_in.set_dimension_step(Window::DimX, 8);

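    // The kernel window steps 16 output pixels per iteration while the input
    // iterator steps 8, so each iteration accumulates two 8-wide halves (note
    // the in.increment(Window::DimX) between the two loops).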
1190     Iterator in(_input, win_in);
1191     Iterator out(_output, win);
1192
1193     std::array<unsigned char *, matrix_size> input_ptrs{ {} };
1194     const float32x4_t oneoverscale = vdupq_n_f32(1.0f / _scale);
1195     const int         k_half       = matrix_size / 2;
1196
1197     // Set row pointers
1198     for(int i = -k_half; i <= k_half; ++i)
1199     {
1200         input_ptrs[k_half + i] = _input->ptr_to_element(Coordinates(0, i));
1201     }
1202
1203     execute_window_loop(win, [&](const Coordinates & id)
1204     {
1205         uint16x8_t out0 = vdupq_n_u16(0);
1206         uint16x8_t out1 = vdupq_n_u16(0);
1207
1208         // First half
1209         for(unsigned int r = 0; r < matrix_size; ++r)
1210         {
1211             const uint16x8_t data = vld1q_u16(reinterpret_cast<const uint16_t *>(input_ptrs[r] + in.offset()));
1212             out0                  = vmlaq_n_u16(out0, data, _conv_col[r]);
1213         }
1214
1215         in.increment(Window::DimX);
1216
1217         // Second half
1218         for(unsigned int r = 0; r < matrix_size; ++r)
1219         {
1220             const uint16x8_t data = vld1q_u16(reinterpret_cast<const uint16_t *>(input_ptrs[r] + in.offset()));
1221             out1                  = vmlaq_n_u16(out1, data, _conv_col[r]);
1222         }
1223
1224         //scale the result if needed
1225         if(_scale != 1)
1226         {
1227             float32x4_t out0_f32_high = vcvtq_f32_u32(vmovl_u16(vget_high_u16(out0)));
1228             float32x4_t out0_f32_low  = vcvtq_f32_u32(vmovl_u16(vget_low_u16(out0)));
1229             out0_f32_high             = vmulq_f32(out0_f32_high, oneoverscale);
1230             out0_f32_low              = vmulq_f32(out0_f32_low, oneoverscale);
1231             store_results(vcvtq_u32_f32(out0_f32_low), vcvtq_u32_f32(out0_f32_high), reinterpret_cast<OutputType *>(out.ptr()));
1232
1233             float32x4_t out1_f32_high = vcvtq_f32_u32(vmovl_u16(vget_high_u16(out1)));
1234             float32x4_t out1_f32_low  = vcvtq_f32_u32(vmovl_u16(vget_low_u16(out1)));
1235             out1_f32_high             = vmulq_f32(out1_f32_high, oneoverscale);
1236             out1_f32_low              = vmulq_f32(out1_f32_low, oneoverscale);
1237             store_results(vcvtq_u32_f32(out1_f32_low), vcvtq_u32_f32(out1_f32_high), reinterpret_cast<OutputType *>(out.ptr()) + 8);
1238         }
1239         else
1240         {
1241             store_results(out0, out1, reinterpret_cast<OutputType *>(out.ptr()));
1242         }
1243     },
1244     in, out);
1245 }
1246
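// Vertical pass over an S16 intermediate image. Same structure as the U16
// variant above, but the accumulation is signed and the optional scale
// widens through S32 before the multiply by 1 / _scale.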
template <unsigned int matrix_size>
template <typename OutputType>
void NESeparableConvolutionVertKernel<matrix_size>::convolution_s16(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");

    Window win_in(win);
    win_in.set_dimension_step(Window::DimX, 8);

    Iterator in(_input, win_in);
    Iterator out(_output, win);

    std::array<unsigned char *, matrix_size> input_ptrs{ {} };
    const float32x4_t oneoverscale = vdupq_n_f32(1.0f / _scale);
    const int         k_half       = matrix_size / 2;

    // Set row pointers
    for(int i = -k_half; i <= k_half; ++i)
    {
        input_ptrs[k_half + i] = _input->ptr_to_element(Coordinates(0, i));
    }

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int16x8_t out0 = vdupq_n_s16(0);
        int16x8_t out1 = vdupq_n_s16(0);

        // First half
        for(unsigned int r = 0; r < matrix_size; ++r)
        {
            const int16x8_t data = vld1q_s16(reinterpret_cast<const int16_t *>(input_ptrs[r] + in.offset()));
            out0                 = vmlaq_n_s16(out0, data, _conv_col[r]);
        }

        in.increment(Window::DimX);

        // Second half
        for(unsigned int r = 0; r < matrix_size; ++r)
        {
            const int16x8_t data = vld1q_s16(reinterpret_cast<const int16_t *>(input_ptrs[r] + in.offset()));
            out1                 = vmlaq_n_s16(out1, data, _conv_col[r]);
        }

        // Scale the result if needed
        if(_scale != 1)
        {
            float32x4_t out0_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(out0)));
            float32x4_t out0_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(out0)));
            out0_f32_high             = vmulq_f32(out0_f32_high, oneoverscale);
            out0_f32_low              = vmulq_f32(out0_f32_low, oneoverscale);
            store_results(vcvtq_s32_f32(out0_f32_low), vcvtq_s32_f32(out0_f32_high), reinterpret_cast<OutputType *>(out.ptr()));

            float32x4_t out1_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(out1)));
            float32x4_t out1_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(out1)));
            out1_f32_high             = vmulq_f32(out1_f32_high, oneoverscale);
            out1_f32_low              = vmulq_f32(out1_f32_low, oneoverscale);
            store_results(vcvtq_s32_f32(out1_f32_low), vcvtq_s32_f32(out1_f32_high), reinterpret_cast<OutputType *>(out.ptr()) + 8);
        }
        else
        {
            store_results(out0, out1, reinterpret_cast<OutputType *>(out.ptr()));
        }
    },
    in, out);
}

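// Vertical pass over an S32 intermediate image. Each half loads eight S32
// values with vld2q_s32, which de-interleaves them into two 4-lane registers
// so both can be accumulated with vmlaq_n_s32. After the optional scale,
// vzipq_s32 restores the original element order before the saturating store
// to U8 or S16.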
template <unsigned int matrix_size>
template <typename OutputType>
void NESeparableConvolutionVertKernel<matrix_size>::convolution_s32(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");

    Window win_in(win);
    win_in.set_dimension_step(Window::DimX, 8);

    Iterator in(_input, win_in);
    Iterator out(_output, win);

    std::array<unsigned char *, matrix_size> input_ptrs{ {} };
    const float32x4_t oneoverscale = vdupq_n_f32(1.0f / _scale);
    const int         k_half       = matrix_size / 2;

    // Set row pointers
    for(int i = -k_half; i <= k_half; ++i)
    {
        input_ptrs[k_half + i] = _input->ptr_to_element(Coordinates(0, i));
    }

    const int32x4_t zero = vdupq_n_s32(0);

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4x2_t out0 = { { zero, zero } };
        int32x4x2_t out1 = { { zero, zero } };

        // First half
        for(unsigned int r = 0; r < matrix_size; ++r)
        {
            const int32x4x2_t data = vld2q_s32(reinterpret_cast<const int32_t *>(input_ptrs[r] + in.offset()));
            out0.val[0]            = vmlaq_n_s32(out0.val[0], data.val[0], _conv_col[r]);
            out0.val[1]            = vmlaq_n_s32(out0.val[1], data.val[1], _conv_col[r]);
        }

        in.increment(Window::DimX);

        // Second half
        for(unsigned int r = 0; r < matrix_size; ++r)
        {
            const int32x4x2_t data = vld2q_s32(reinterpret_cast<const int32_t *>(input_ptrs[r] + in.offset()));
            out1.val[0]            = vmlaq_n_s32(out1.val[0], data.val[0], _conv_col[r]);
            out1.val[1]            = vmlaq_n_s32(out1.val[1], data.val[1], _conv_col[r]);
        }

        // Scale the result if needed
        if(_scale != 1)
        {
            float32x4_t out0_f32_odd  = vcvtq_f32_s32(out0.val[0]);
            float32x4_t out0_f32_even = vcvtq_f32_s32(out0.val[1]);
            out0_f32_odd              = vmulq_f32(out0_f32_odd, oneoverscale);
            out0_f32_even             = vmulq_f32(out0_f32_even, oneoverscale);
            out0.val[0]               = vcvtq_s32_f32(out0_f32_odd);
            out0.val[1]               = vcvtq_s32_f32(out0_f32_even);

            float32x4_t out1_f32_odd  = vcvtq_f32_s32(out1.val[0]);
            float32x4_t out1_f32_even = vcvtq_f32_s32(out1.val[1]);
            out1_f32_odd              = vmulq_f32(out1_f32_odd, oneoverscale);
            out1_f32_even             = vmulq_f32(out1_f32_even, oneoverscale);
            out1.val[0]               = vcvtq_s32_f32(out1_f32_odd);
            out1.val[1]               = vcvtq_s32_f32(out1_f32_even);
        }

        const int32x4x2_t out0_s32 = vzipq_s32(out0.val[0], out0.val[1]);
        store_results(out0_s32.val[0], out0_s32.val[1], reinterpret_cast<OutputType *>(out.ptr()));

        const int32x4x2_t out1_s32 = vzipq_s32(out1.val[0], out1.val[1]);
        store_results(out1_s32.val[0], out1_s32.val[1], reinterpret_cast<OutputType *>(out.ptr()) + 8);
    },
    in, out);
}

template class arm_compute::NESeparableConvolutionVertKernel<5>;
template class arm_compute::NESeparableConvolutionVertKernel<7>;
template class arm_compute::NESeparableConvolutionVertKernel<9>;

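// A minimal usage sketch for the separable path (illustrative only: the
// configure() signatures below are assumptions based on this file's
// conventions, and in practice the NEConvolution5x5/7x7/9x9 runtime
// functions allocate the intermediate tensor and chain the two passes):
//
//     NESeparableConvolutionHorKernel<5>  hor;   // U8 in, U16/S16/S32 out
//     NESeparableConvolutionVertKernel<5> vert;  // U16/S16/S32 in, U8/S16 out
//     hor.configure(&src, &tmp, conv_row, border_undefined);
//     vert.configure(&tmp, &dst, conv_col, scale, border_undefined);
//     hor.run(hor.window());
//     vert.run(vert.window());
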
/****************************************************************************************\
 *                                 Rectangle Convolution                                *
\****************************************************************************************/

NEConvolutionRectangleKernel::NEConvolutionRectangleKernel()
    : _input(nullptr), _output(nullptr), _scale(0), _convolution(), _border_size(), _func_idx(0)
{
}

BorderSize NEConvolutionRectangleKernel::border_size() const
{
    return _border_size;
}

void NEConvolutionRectangleKernel::configure(const ITensor *input, ITensor *output, const int16_t *conv, uint32_t width, uint32_t height, uint32_t scale, bool border_undefined)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16);
    ARM_COMPUTE_ERROR_ON(nullptr == conv);
    ARM_COMPUTE_ERROR_ON(3 != width && 5 != width && 7 != width && 9 != width);
    ARM_COMPUTE_ERROR_ON(3 != height && 5 != height && 7 != height && 9 != height);
    ARM_COMPUTE_ERROR_ON(0 == scale);

    _input       = input;
    _output      = output;
    _scale       = scale;
    _border_size = BorderSize(height / 2, width / 2);
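    // e.g. width == 9 and height == 5 give BorderSize(2, 4): two border rows
    // above and below, four border columns left and right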

    // Set up the convolution matrix
    const uint32_t nr_elements = width * height;
    _convolution.resize(nr_elements);
    std::copy_n(conv, nr_elements, _convolution.begin());

    // Set function index to help choose appropriate function in run()
    _func_idx = get_index(height) * 4 + get_index(width);
    ARM_COMPUTE_ERROR_ON(_func_idx >= (_nr_supported_sizes * _nr_supported_sizes));
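    // e.g. height == 5 and width == 7: _func_idx = 1 * 4 + 2 = 6, which picks
    // convolution<OutputType, 5, 7> from the function tables in run()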

    // Configure kernel window
    constexpr unsigned int processed_elements(8);
    constexpr unsigned int read_elements(16);
    constexpr unsigned int written_elements(8);

    Window                 win           = calculate_max_window(*input->info(), Steps(processed_elements), border_undefined, _border_size);
    AccessWindowHorizontal output_access = AccessWindowHorizontal(output->info(), 0, written_elements);

    update_window_and_padding(win,
                              AccessWindowRectangle(input->info(), -_border_size.left, -_border_size.top, read_elements, height),
                              output_access);

    output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, _border_size);

    INEKernel::configure(win);
}

void NEConvolutionRectangleKernel::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    using ConvolutionRectangleFunction = void (NEConvolutionRectangleKernel::*)(const Window &window);

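    // The tables below are 4x4 grids of member-function pointers, laid out
    // row-major by (rows, cols) over the supported sizes {3, 5, 7, 9};
    // _func_idx computed in configure() indexes straight into them.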
    // uint8_t function table
    static const std::array<ConvolutionRectangleFunction, 16> func_table_u8 =
    {
        {
            &NEConvolutionRectangleKernel::convolution<uint8_t, 3, 3>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 3, 5>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 3, 7>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 3, 9>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 5, 3>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 5, 5>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 5, 7>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 5, 9>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 7, 3>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 7, 5>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 7, 7>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 7, 9>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 9, 3>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 9, 5>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 9, 7>,
            &NEConvolutionRectangleKernel::convolution<uint8_t, 9, 9>
        }
    };
    // int16_t function table
    static const std::array<ConvolutionRectangleFunction, 16> func_table_s16 =
    {
        {
            &NEConvolutionRectangleKernel::convolution<int16_t, 3, 3>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 3, 5>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 3, 7>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 3, 9>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 5, 3>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 5, 5>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 5, 7>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 5, 9>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 7, 3>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 7, 5>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 7, 7>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 7, 9>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 9, 3>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 9, 5>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 9, 7>,
            &NEConvolutionRectangleKernel::convolution<int16_t, 9, 9>
        }
    };

    // Run appropriate function
    switch(_output->info()->format())
    {
        case Format::U8:
            ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_u8.size());
            (this->*func_table_u8[_func_idx])(window);
            break;
        case Format::S16:
            ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_s16.size());
            (this->*func_table_s16[_func_idx])(window);
            break;
        default:
            ARM_COMPUTE_ERROR("Not supported");
    }
}

unsigned int NEConvolutionRectangleKernel::get_index(uint32_t val)
{
    switch(val)
    {
        case 3:
            return 0;
        case 5:
            return 1;
        case 7:
            return 2;
        case 9:
            return 3;
        default:
            ARM_COMPUTE_ERROR("Unsupported dimension size");
            return 0;
    }
}

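// Generic rows x cols rectangle convolution over a U8 input, producing eight
// output pixels per iteration. Each row pointer is pre-offset by -k_col_half
// so that vld1q_u8 reads the 16 bytes starting at the left edge of the
// kernel window; the convolve_rowNx1() helpers accumulate one row of
// coefficients into the two 4-lane S32 accumulators. The optional scale
// multiplies by 1 / _scale in F32, and store_results() saturates to U8 or
// S16.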
template <typename OutputType, unsigned int rows, unsigned int cols>
void NEConvolutionRectangleKernel::convolution(const Window &win)
{
    static_assert(sizeof(OutputType) == sizeof(uint8_t) || sizeof(OutputType) == sizeof(int16_t), "The output buffer can only be u8 or s16");
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    Iterator input(_input, win);
    Iterator output(_output, win);

    std::array<unsigned char *, rows> input_ptrs{ {} };
    const int16_t    *conv       = _convolution.data();
    const float32x4_t scale_val  = vdupq_n_f32(1.0f / _scale);
    const int         k_row_half = rows / 2;
    const int         k_col_half = cols / 2;

    // Set row pointers
    for(int i = -k_row_half; i <= k_row_half; ++i)
    {
        input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
    }

    execute_window_loop(win, [&](const Coordinates & id)
    {
        int32x4_t out  = vdupq_n_s32(0);
        int32x4_t out2 = vdupq_n_s32(0);

        // Perform appropriate convolution
        for(unsigned int r = 0; r < rows; ++r)
        {
            const uint8x16_t data = vld1q_u8(input_ptrs[r] + input.offset());
            if(3 == cols)
            {
                convolve_row3x1(out, out2, data, conv + r * cols);
            }
            else if(5 == cols)
            {
                convolve_row5x1(out, out2, data, conv + r * cols);
            }
            else if(7 == cols)
            {
                convolve_row7x1(out, out2, data, conv + r * cols);
            }
            else if(9 == cols)
            {
                convolve_row9x1(out, out2, data, conv + r * cols);
            }
            else
            {
                ARM_COMPUTE_ERROR("Unsupported number of columns");
            }
        }

        // Apply scale
        if(_scale != 1)
        {
            // Convert to F32, scale and convert back to S32
            out  = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out), scale_val));
            out2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(out2), scale_val));
        }

        // Clamp and store as U8 or S16
        store_results(out, out2, reinterpret_cast<OutputType *>(output.ptr()));
    },
    input, output);
}