arm_compute v17.04
[platform/upstream/armcl.git] / src / core / NEON / kernels / NENonLinearFilterKernel.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NENonLinearFilterKernel.h"
25
26 #include "arm_compute/core/Coordinates.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/TensorInfo.h"
31 #include "arm_compute/core/Validate.h"
32
33 #include <algorithm>
34 #include <arm_neon.h>
35 #include <array>
36 #include <tuple>
37 #include <utility>
38
39 using namespace arm_compute;
40
41 namespace
42 {
43 const uint8x16_t zero_u8 = vdupq_n_u8(0);
44
45 template <size_t columns>
46 inline uint8x8_t min_row(uint8x16_t row_data)
47 {
48     uint8x8_t min = vget_low_u8(row_data);
49
50     for(size_t c = 1; c < columns; ++c)
51     {
52         row_data = vextq_u8(row_data, zero_u8, 1);
53         min      = vmin_u8(min, vget_low_u8(row_data));
54     }
55
56     return min;
57 }
58
59 template <size_t columns>
60 inline uint8x8_t max_row(uint8x16_t row_data)
61 {
62     uint8x8_t max = vget_low_u8(row_data);
63
64     for(size_t c = 1; c < columns; ++c)
65     {
66         row_data = vextq_u8(row_data, zero_u8, 1);
67         max      = vmax_u8(max, vget_low_u8(row_data));
68     }
69
70     return max;
71 }
72
73 inline void sort(uint8x8_t &a, uint8x8_t &b)
74 {
75     const uint8x8_t min = vmin_u8(a, b);
76     const uint8x8_t max = vmax_u8(a, b);
77     a                   = min;
78     b                   = max;
79 }
80
81 // Sorting networks below were generated using http://pages.ripco.net/~jgamble/nw.html
82 // Calculations that do not affect the median were removed.
83 inline void sort5(uint8x8_t &p0, uint8x8_t &p1, uint8x8_t &p2, uint8x8_t &p3, uint8x8_t &p4)
84 {
85     sort(p0, p1);
86     sort(p2, p3);
87     sort(p0, p2);
88     sort(p1, p3);
89     sort(p1, p2);
90     sort(p0, p4);
91     sort(p1, p4);
92     sort(p2, p4);
93 }
94
95 inline void sort9(uint8x8_t &p0, uint8x8_t &p1, uint8x8_t &p2,
96                   uint8x8_t &p3, uint8x8_t &p4, uint8x8_t &p5,
97                   uint8x8_t &p6, uint8x8_t &p7, uint8x8_t &p8)
98 {
99     sort(p1, p2);
100     sort(p4, p5);
101     sort(p7, p8);
102     sort(p0, p1);
103     sort(p3, p4);
104     sort(p6, p7);
105     sort(p1, p2);
106     sort(p4, p5);
107     sort(p7, p8);
108     sort(p0, p3);
109     sort(p5, p8);
110     sort(p4, p7);
111     sort(p3, p6);
112     sort(p1, p4);
113     sort(p2, p5);
114     sort(p4, p7);
115     sort(p4, p2);
116     sort(p6, p4);
117     sort(p4, p2);
118 }
119
120 inline void sort21(uint8x8_t p[21])
121 {
122     sort(p[0], p[1]);
123     sort(p[2], p[3]);
124     sort(p[4], p[5]);
125     sort(p[6], p[7]);
126     sort(p[8], p[9]);
127     sort(p[10], p[11]);
128     sort(p[12], p[13]);
129     sort(p[14], p[15]);
130     sort(p[16], p[17]);
131     sort(p[18], p[19]);
132     sort(p[0], p[2]);
133     sort(p[1], p[3]);
134     sort(p[4], p[6]);
135     sort(p[5], p[7]);
136     sort(p[8], p[10]);
137     sort(p[9], p[11]);
138     sort(p[12], p[14]);
139     sort(p[13], p[15]);
140     sort(p[16], p[18]);
141     sort(p[17], p[19]);
142     sort(p[1], p[2]);
143     sort(p[5], p[6]);
144     sort(p[0], p[4]);
145     sort(p[3], p[7]);
146     sort(p[9], p[10]);
147     sort(p[13], p[14]);
148     sort(p[8], p[12]);
149     sort(p[11], p[15]);
150     sort(p[17], p[18]);
151     sort(p[16], p[20]);
152     sort(p[1], p[5]);
153     sort(p[2], p[6]);
154     sort(p[9], p[13]);
155     sort(p[10], p[14]);
156     sort(p[0], p[8]);
157     sort(p[7], p[15]);
158     sort(p[17], p[20]);
159     sort(p[1], p[4]);
160     sort(p[3], p[6]);
161     sort(p[9], p[12]);
162     sort(p[11], p[14]);
163     sort(p[18], p[20]);
164     sort(p[0], p[16]);
165     sort(p[2], p[4]);
166     sort(p[3], p[5]);
167     sort(p[10], p[12]);
168     sort(p[11], p[13]);
169     sort(p[1], p[9]);
170     sort(p[6], p[14]);
171     sort(p[19], p[20]);
172     sort(p[3], p[4]);
173     sort(p[11], p[12]);
174     sort(p[1], p[8]);
175     sort(p[2], p[10]);
176     sort(p[5], p[13]);
177     sort(p[7], p[14]);
178     sort(p[3], p[11]);
179     sort(p[2], p[8]);
180     sort(p[4], p[12]);
181     sort(p[7], p[13]);
182     sort(p[1], p[17]);
183     sort(p[3], p[10]);
184     sort(p[5], p[12]);
185     sort(p[1], p[16]);
186     sort(p[2], p[18]);
187     sort(p[3], p[9]);
188     sort(p[6], p[12]);
189     sort(p[2], p[16]);
190     sort(p[3], p[8]);
191     sort(p[7], p[12]);
192     sort(p[5], p[9]);
193     sort(p[6], p[10]);
194     sort(p[4], p[8]);
195     sort(p[7], p[11]);
196     sort(p[3], p[19]);
197     sort(p[5], p[8]);
198     sort(p[7], p[10]);
199     sort(p[3], p[18]);
200     sort(p[4], p[20]);
201     sort(p[6], p[8]);
202     sort(p[7], p[9]);
203     sort(p[3], p[17]);
204     sort(p[5], p[20]);
205     sort(p[7], p[8]);
206     sort(p[3], p[16]);
207     sort(p[6], p[20]);
208     sort(p[5], p[17]);
209     sort(p[7], p[20]);
210     sort(p[4], p[16]);
211     sort(p[6], p[18]);
212     sort(p[5], p[16]);
213     sort(p[7], p[19]);
214     sort(p[7], p[18]);
215     sort(p[6], p[16]);
216     sort(p[7], p[17]);
217     sort(p[10], p[18]);
218     sort(p[7], p[16]);
219     sort(p[9], p[17]);
220     sort(p[8], p[16]);
221     sort(p[9], p[16]);
222     sort(p[10], p[16]);
223 }
224
225 inline void sort25(uint8x8_t p[25])
226 {
227     sort(p[1], p[2]);
228     sort(p[0], p[1]);
229     sort(p[1], p[2]);
230     sort(p[4], p[5]);
231     sort(p[3], p[4]);
232     sort(p[4], p[5]);
233     sort(p[0], p[3]);
234     sort(p[2], p[5]);
235     sort(p[2], p[3]);
236     sort(p[1], p[4]);
237     sort(p[1], p[2]);
238     sort(p[3], p[4]);
239     sort(p[7], p[8]);
240     sort(p[6], p[7]);
241     sort(p[7], p[8]);
242     sort(p[10], p[11]);
243     sort(p[9], p[10]);
244     sort(p[10], p[11]);
245     sort(p[6], p[9]);
246     sort(p[8], p[11]);
247     sort(p[8], p[9]);
248     sort(p[7], p[10]);
249     sort(p[7], p[8]);
250     sort(p[9], p[10]);
251     sort(p[0], p[6]);
252     sort(p[4], p[10]);
253     sort(p[4], p[6]);
254     sort(p[2], p[8]);
255     sort(p[2], p[4]);
256     sort(p[6], p[8]);
257     sort(p[1], p[7]);
258     sort(p[5], p[11]);
259     sort(p[5], p[7]);
260     sort(p[3], p[9]);
261     sort(p[3], p[5]);
262     sort(p[7], p[9]);
263     sort(p[1], p[2]);
264     sort(p[3], p[4]);
265     sort(p[5], p[6]);
266     sort(p[7], p[8]);
267     sort(p[9], p[10]);
268     sort(p[13], p[14]);
269     sort(p[12], p[13]);
270     sort(p[13], p[14]);
271     sort(p[16], p[17]);
272     sort(p[15], p[16]);
273     sort(p[16], p[17]);
274     sort(p[12], p[15]);
275     sort(p[14], p[17]);
276     sort(p[14], p[15]);
277     sort(p[13], p[16]);
278     sort(p[13], p[14]);
279     sort(p[15], p[16]);
280     sort(p[19], p[20]);
281     sort(p[18], p[19]);
282     sort(p[19], p[20]);
283     sort(p[21], p[22]);
284     sort(p[23], p[24]);
285     sort(p[21], p[23]);
286     sort(p[22], p[24]);
287     sort(p[22], p[23]);
288     sort(p[18], p[21]);
289     sort(p[20], p[23]);
290     sort(p[20], p[21]);
291     sort(p[19], p[22]);
292     sort(p[22], p[24]);
293     sort(p[19], p[20]);
294     sort(p[21], p[22]);
295     sort(p[23], p[24]);
296     sort(p[12], p[18]);
297     sort(p[16], p[22]);
298     sort(p[16], p[18]);
299     sort(p[14], p[20]);
300     sort(p[20], p[24]);
301     sort(p[14], p[16]);
302     sort(p[18], p[20]);
303     sort(p[22], p[24]);
304     sort(p[13], p[19]);
305     sort(p[17], p[23]);
306     sort(p[17], p[19]);
307     sort(p[15], p[21]);
308     sort(p[15], p[17]);
309     sort(p[19], p[21]);
310     sort(p[13], p[14]);
311     sort(p[15], p[16]);
312     sort(p[17], p[18]);
313     sort(p[19], p[20]);
314     sort(p[21], p[22]);
315     sort(p[23], p[24]);
316     sort(p[0], p[12]);
317     sort(p[8], p[20]);
318     sort(p[8], p[12]);
319     sort(p[4], p[16]);
320     sort(p[16], p[24]);
321     sort(p[12], p[16]);
322     sort(p[2], p[14]);
323     sort(p[10], p[22]);
324     sort(p[10], p[14]);
325     sort(p[6], p[18]);
326     sort(p[6], p[10]);
327     sort(p[10], p[12]);
328     sort(p[1], p[13]);
329     sort(p[9], p[21]);
330     sort(p[9], p[13]);
331     sort(p[5], p[17]);
332     sort(p[13], p[17]);
333     sort(p[3], p[15]);
334     sort(p[11], p[23]);
335     sort(p[11], p[15]);
336     sort(p[7], p[19]);
337     sort(p[7], p[11]);
338     sort(p[11], p[13]);
339     sort(p[11], p[12]);
340 }
341 } // namespace
342
343 NENonLinearFilterKernel::NENonLinearFilterKernel()
344     : _border_width(0), _input(nullptr), _output(nullptr), _mask(nullptr), _pattern(MatrixPattern::BOX), _function(NonLinearFilterFunction::MIN), _func_idx(0), _border_size()
345 {
346 }
347
348 BorderSize NENonLinearFilterKernel::border_size() const
349 {
350     return _border_size;
351 }
352
353 void NENonLinearFilterKernel::configure(const ITensor *input, ITensor *output, NonLinearFilterFunction function, unsigned int mask_size, MatrixPattern pattern, const uint8_t *mask,
354                                         bool border_undefined)
355 {
356     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
357     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
358     ARM_COMPUTE_ERROR_ON(3 != mask_size && 5 != mask_size);
359     ARM_COMPUTE_ERROR_ON(MatrixPattern::OTHER == pattern && nullptr == mask);
360
361     // Set class variables
362     _border_size = BorderSize(mask_size / 2);
363     _input       = input;
364     _output      = output;
365     _mask        = mask;
366     _pattern     = pattern;
367     _function    = function;
368
369     // Configure kernel window
370     const unsigned int     num_elems_processed_per_iteration = (MatrixPattern::OTHER == pattern) ? 1 : 8;
371     constexpr unsigned int num_elems_read_per_iteration      = 16;
372
373     Window                 win = calculate_max_window(*input->info(), num_elems_processed_per_iteration, border_undefined, border_size());
374     AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
375     update_window_and_padding(win,
376                               AccessWindowRectangle(input->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, mask_size),
377                               output_access);
378     output_access.set_valid_region(win, input->info()->valid_region(), border_undefined, border_size());
379
380     INEKernel::configure(win);
381
382     // Define function index
383     _func_idx = (3 == mask_size) ? 0 : 1;
384
385     if(MatrixPattern::OTHER != pattern)
386     {
387         _func_idx = (_func_idx) * 3 + static_cast<unsigned int>(function);
388     }
389 }
390
391 void NENonLinearFilterKernel::fill_mask(uint8_t *mask, int cols, int rows, MatrixPattern pattern)
392 {
393     unsigned int v = 0;
394
395     for(int r = 0; r < rows; ++r)
396     {
397         for(int c = 0; c < cols; ++c, ++v)
398         {
399             uint8_t val = 0;
400
401             switch(pattern)
402             {
403                 case MatrixPattern::BOX:
404                     val = 255;
405                     break;
406                 case MatrixPattern::CROSS:
407                     val = ((r == (rows / 2)) || (c == (cols / 2))) ? 255 : 0;
408                     break;
409                 case MatrixPattern::DISK:
410                     val = (((r - rows / 2.0f + 0.5f) * (r - rows / 2.0f + 0.5f)) / ((rows / 2.0f) * (rows / 2.0f)) + ((c - cols / 2.0f + 0.5f) * (c - cols / 2.0f + 0.5f)) / ((cols / 2.0f) *
411                             (cols / 2.0f))) <= 1.0f ? 255 : 0;
412                     break;
413                 default:
414                     return;
415             }
416
417             mask[v] = val;
418         }
419     }
420 }
421
422 #ifndef DOXYGEN_SKIP_THIS /* Doxygen gets confused by the templates and can't match the implementation to the declaration */
423 namespace arm_compute
424 {
425 template <>
426 void NENonLinearFilterKernel::median_filter_box<3, 3>(const Window &win)
427 {
428     Iterator input(_input, win);
429     Iterator output(_output, win);
430
431     const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, -1)));
432     const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 0)));
433     const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 1)));
434
435     execute_window_loop(win, [&](const Coordinates & id)
436     {
437         const uint8x16_t top_data = vld1q_u8(input_top_ptr + input.offset());
438         const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
439         const uint8x16_t bot_data = vld1q_u8(input_bot_ptr + input.offset());
440
441         uint8x8_t p0 = vget_low_u8(top_data);
442         uint8x8_t p1 = vext_u8(vget_low_u8(top_data), vget_high_u8(top_data), 1);
443         uint8x8_t p2 = vext_u8(vget_low_u8(top_data), vget_high_u8(top_data), 2);
444         uint8x8_t p3 = vget_low_u8(mid_data);
445         uint8x8_t p4 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
446         uint8x8_t p5 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
447         uint8x8_t p6 = vget_low_u8(bot_data);
448         uint8x8_t p7 = vext_u8(vget_low_u8(bot_data), vget_high_u8(bot_data), 1);
449         uint8x8_t p8 = vext_u8(vget_low_u8(bot_data), vget_high_u8(bot_data), 2);
450
451         sort9(p0, p1, p2, p3, p4, p5, p6, p7, p8);
452
453         vst1_u8(output.ptr(), p4);
454     },
455     input, output);
456 }
457 template <>
458 void NENonLinearFilterKernel::median_filter_box<5, 5>(const Window &win)
459 {
460     Iterator input(_input, win);
461     Iterator output(_output, win);
462
463     const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -2)));
464     const auto input_top_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
465     const auto input_mid_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
466     const auto input_bot_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
467     const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 2)));
468
469     execute_window_loop(win, [&](const Coordinates & id)
470     {
471         const uint8x16_t top2_data = vld1q_u8(input_top2_ptr + input.offset());
472         const uint8x16_t top_data  = vld1q_u8(input_top_ptr + input.offset());
473         const uint8x16_t mid_data  = vld1q_u8(input_mid_ptr + input.offset());
474         const uint8x16_t bot_data  = vld1q_u8(input_bot_ptr + input.offset());
475         const uint8x16_t bot2_data = vld1q_u8(input_bot2_ptr + input.offset());
476
477         const uint8x8_t d[] =
478         {
479             vget_low_u8(top2_data),
480             vget_high_u8(top2_data),
481             vget_low_u8(top_data),
482             vget_high_u8(top_data),
483             vget_low_u8(mid_data),
484             vget_high_u8(mid_data),
485             vget_low_u8(bot_data),
486             vget_high_u8(bot_data),
487             vget_low_u8(bot2_data),
488             vget_high_u8(bot2_data)
489         };
490
491         uint8x8_t p[25];
492         for(unsigned int i = 0; i < 5; ++i)
493         {
494             const unsigned int idx_d = i * 2;
495             const unsigned int idx_p = i * 5;
496
497             p[idx_p]     = d[idx_d];
498             p[idx_p + 1] = vext_u8(d[idx_d], d[idx_d + 1], 1);
499             p[idx_p + 2] = vext_u8(d[idx_d], d[idx_d + 1], 2);
500             p[idx_p + 3] = vext_u8(d[idx_d], d[idx_d + 1], 3);
501             p[idx_p + 4] = vext_u8(d[idx_d], d[idx_d + 1], 4);
502         }
503
504         sort25(p);
505
506         vst1_u8(output.ptr(), p[12]);
507     },
508     input, output);
509 }
510 } // namespace arm_compute
511 #endif
512
513 template <int mask_w, int mask_h>
514 void NENonLinearFilterKernel::min_filter_box(const Window &win)
515 {
516     static_assert(mask_w > 0, "Mask size must not be 0");
517     static_assert(mask_h > 0, "Mask size must not be 0");
518
519     Iterator input(_input, win);
520     Iterator output(_output, win);
521
522     const int k_row_half = mask_h / 2;
523     const int k_col_half = mask_w / 2;
524
525     // Set row pointers
526     std::array<const unsigned char *, mask_h> input_ptrs{ {} };
527     for(int i = -k_row_half; i <= k_row_half; ++i)
528     {
529         input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
530     }
531
532     execute_window_loop(win, [&](const Coordinates & id)
533     {
534         // Get min of rows
535         uint8x16_t rows_min = vld1q_u8(input_ptrs[0] + input.offset());
536
537         for(unsigned int r = 1; r < mask_h; ++r)
538         {
539             const uint8x16_t data = vld1q_u8(input_ptrs[r] + input.offset());
540             rows_min              = vminq_u8(rows_min, data);
541         }
542
543         const uint8x8_t out = min_row<mask_w>(rows_min);
544
545         // Store result as U8
546         vst1_u8(output.ptr(), out);
547     },
548     input, output);
549 }
550
551 template <int mask_w, int mask_h>
552 void NENonLinearFilterKernel::max_filter_box(const Window &win)
553 {
554     static_assert(mask_w > 0, "Mask size must not be 0");
555     static_assert(mask_h > 0, "Mask size must not be 0");
556     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
557
558     Iterator input(_input, win);
559     Iterator output(_output, win);
560
561     const int k_row_half = mask_h / 2;
562     const int k_col_half = mask_w / 2;
563
564     // Set row pointers
565     std::array<const unsigned char *, mask_h> input_ptrs{ {} };
566     for(int i = -k_row_half; i <= k_row_half; ++i)
567     {
568         input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
569     }
570
571     execute_window_loop(win, [&](const Coordinates & id)
572     {
573         uint8x16_t rows_max = vld1q_u8(input_ptrs[0] + input.offset());
574
575         // Get max of rows
576         for(unsigned int r = 1; r < mask_h; ++r)
577         {
578             const uint8x16_t data = vld1q_u8(input_ptrs[r] + input.offset());
579             rows_max              = vmaxq_u8(rows_max, data);
580         }
581
582         // Get max of columns
583         const uint8x8_t out = max_row<mask_w>(rows_max);
584
585         // Store result as U8
586         vst1_u8(output.ptr(), out);
587     },
588     input, output);
589 }
590
591 #ifndef DOXYGEN_SKIP_THIS /* Doxygen gets confused by the templates and can't match the implementation to the declaration */
592 namespace arm_compute
593 {
594 template <>
595 void NENonLinearFilterKernel::median_filter_cross<3, 3>(const Window &win)
596 {
597     Iterator input(_input, win);
598     Iterator output(_output, win);
599
600     const auto input_top_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -1)));
601     const auto input_mid_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 0)));
602     const auto input_bot_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 1)));
603
604     execute_window_loop(win, [&](const Coordinates & id)
605     {
606         const uint8x8_t  top_data = vld1_u8(input_top_ptr + input.offset());
607         const uint8x16_t mid_data = vld1q_u8(input_mid_ptr + input.offset());
608         const uint8x8_t  bot_data = vld1_u8(input_bot_ptr + input.offset());
609
610         uint8x8_t p0 = top_data;
611         uint8x8_t p1 = vget_low_u8(mid_data);
612         uint8x8_t p2 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
613         uint8x8_t p3 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
614         uint8x8_t p4 = bot_data;
615
616         sort5(p0, p1, p2, p3, p4);
617
618         vst1_u8(output.ptr(), p2);
619     },
620     input, output);
621 }
622
623 template <>
624 void NENonLinearFilterKernel::median_filter_cross<5, 5>(const Window &win)
625 {
626     Iterator input(_input, win);
627     Iterator output(_output, win);
628
629     const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -2)));
630     const auto input_top_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, -1)));
631     const auto input_mid_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
632     const auto input_bot_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 1)));
633     const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(0, 2)));
634
635     execute_window_loop(win, [&](const Coordinates & id)
636     {
637         const uint8x8_t  top2_data = vld1_u8(input_top2_ptr + input.offset());
638         const uint8x8_t  top_data  = vld1_u8(input_top_ptr + input.offset());
639         const uint8x16_t mid_data  = vld1q_u8(input_mid_ptr + input.offset());
640         const uint8x8_t  bot_data  = vld1_u8(input_bot_ptr + input.offset());
641         const uint8x8_t  bot2_data = vld1_u8(input_bot2_ptr + input.offset());
642
643         uint8x8_t p0 = top2_data;
644         uint8x8_t p1 = top_data;
645         uint8x8_t p2 = vget_low_u8(mid_data);
646         uint8x8_t p3 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 1);
647         uint8x8_t p4 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 2);
648         uint8x8_t p5 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 3);
649         uint8x8_t p6 = vext_u8(vget_low_u8(mid_data), vget_high_u8(mid_data), 4);
650         uint8x8_t p7 = bot_data;
651         uint8x8_t p8 = bot2_data;
652
653         sort9(p0, p1, p2, p3, p4, p5, p6, p7, p8);
654
655         vst1_u8(output.ptr(), p4);
656     },
657     input, output);
658 }
659 } // namespace arm_compute
660 #endif
661
662 template <int mask_w, int mask_h>
663 void NENonLinearFilterKernel::min_filter_cross(const Window &win)
664 {
665     static_assert(mask_w > 0, "Mask size must not be 0");
666     static_assert(mask_h > 0, "Mask size must not be 0");
667     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
668
669     Iterator input(_input, win);
670     Iterator output(_output, win);
671
672     const int k_row_half = mask_h / 2;
673     const int k_col_half = mask_w / 2;
674
675     const unsigned char *mid_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, 0));
676
677     // Set row pointers
678     std::array<const unsigned char *, mask_h> input_ptrs{ {} };
679     for(int i = -k_row_half; i <= k_row_half; ++i)
680     {
681         input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(0, i));
682     }
683
684     execute_window_loop(win, [&](const Coordinates & id)
685     {
686         uint8x8_t rows_min = vld1_u8(input_ptrs[0] + input.offset());
687
688         // Get min of rows
689         for(unsigned int r = 1; r < mask_h; ++r)
690         {
691             const uint8x8_t data = vld1_u8(input_ptrs[r] + input.offset());
692             rows_min             = vmin_u8(rows_min, data);
693         }
694
695         // Get min of middle row
696         const uint8x16_t data = vld1q_u8(mid_ptr + input.offset());
697         uint8x8_t        out  = min_row<mask_w>(data);
698
699         // Get final min
700         out = vmin_u8(out, rows_min);
701
702         // Store result as U8
703         vst1_u8(output.ptr(), out);
704     },
705     input, output);
706 }
707
708 template <int mask_w, int mask_h>
709 void NENonLinearFilterKernel::max_filter_cross(const Window &win)
710 {
711     static_assert(mask_w > 0, "Mask size must not be 0");
712     static_assert(mask_h > 0, "Mask size must not be 0");
713     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
714
715     Iterator input(_input, win);
716     Iterator output(_output, win);
717
718     const int k_row_half = mask_h / 2;
719     const int k_col_half = mask_w / 2;
720
721     const unsigned char *mid_ptr = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, 0));
722
723     // Set row pointers
724     std::array<unsigned char *, mask_h> input_ptrs{ {} };
725     for(int i = -k_row_half; i <= k_row_half; ++i)
726     {
727         input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(0, i));
728     }
729
730     execute_window_loop(win, [&](const Coordinates & id)
731     {
732         uint8x8_t rows_max = vld1_u8(input_ptrs[0] + input.offset());
733
734         // Get max of rows
735         for(unsigned int r = 1; r < mask_h; ++r)
736         {
737             const uint8x8_t data = vld1_u8(input_ptrs[r] + input.offset());
738             rows_max             = vmax_u8(rows_max, data);
739         }
740
741         // Get max of middle row
742         const uint8x16_t data = vld1q_u8(mid_ptr + input.offset());
743         uint8x8_t        out  = max_row<mask_w>(data);
744
745         // Get final max
746         out = vmax_u8(out, rows_max);
747
748         // Store result as U8
749         vst1_u8(output.ptr(), out);
750     },
751     input, output);
752 }
753
754 #ifndef DOXYGEN_SKIP_THIS /* Doxygen gets confused by the templates and can't match the implementation to the declaration */
755 namespace arm_compute
756 {
757 template <>
758 void NENonLinearFilterKernel::median_filter_disk<5, 5>(const Window &win)
759 {
760     Iterator input(_input, win);
761     Iterator output(_output, win);
762
763     const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, -2)));
764     const auto input_top_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
765     const auto input_mid_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
766     const auto input_bot_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
767     const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 2)));
768
769     execute_window_loop(win, [&](const Coordinates & id)
770     {
771         const uint8x16_t top2_data = vld1q_u8(input_top2_ptr + input.offset());
772         const uint8x16_t top_data  = vld1q_u8(input_top_ptr + input.offset());
773         const uint8x16_t mid_data  = vld1q_u8(input_mid_ptr + input.offset());
774         const uint8x16_t bot_data  = vld1q_u8(input_bot_ptr + input.offset());
775         const uint8x16_t bot2_data = vld1q_u8(input_bot2_ptr + input.offset());
776
777         uint8x8_t d[] =
778         {
779             vget_low_u8(top2_data),
780             vget_high_u8(top2_data),
781             vget_low_u8(top_data),
782             vget_high_u8(top_data),
783             vget_low_u8(mid_data),
784             vget_high_u8(mid_data),
785             vget_low_u8(bot_data),
786             vget_high_u8(bot_data),
787             vget_low_u8(bot2_data),
788             vget_high_u8(bot2_data)
789         };
790
791         uint8x8_t p[21];
792         p[0]  = d[0];
793         p[1]  = vext_u8(d[0], d[1], 1);
794         p[2]  = vext_u8(d[0], d[1], 2);
795         p[18] = d[8];
796         p[19] = vext_u8(d[8], d[9], 1);
797         p[20] = vext_u8(d[8], d[9], 2);
798
799         for(unsigned int i = 0; i < 3; ++i)
800         {
801             const unsigned int idx_d = 2 + i * 2;
802             const unsigned int idx_p = 3 + i * 5;
803
804             p[idx_p]     = d[idx_d];
805             p[idx_p + 1] = vext_u8(d[idx_d], d[idx_d + 1], 1);
806             p[idx_p + 2] = vext_u8(d[idx_d], d[idx_d + 1], 2);
807             p[idx_p + 3] = vext_u8(d[idx_d], d[idx_d + 1], 3);
808             p[idx_p + 4] = vext_u8(d[idx_d], d[idx_d + 1], 4);
809         }
810
811         sort21(p);
812
813         vst1_u8(output.ptr(), p[10]);
814     },
815     input, output);
816 }
817
818 template <>
819 void NENonLinearFilterKernel::min_filter_disk<5, 5>(const Window &win)
820 {
821     Iterator input(_input, win);
822     Iterator output(_output, win);
823
824     const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, -2)));
825     const auto input_top_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
826     const auto input_mid_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
827     const auto input_bot_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
828     const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 2)));
829
830     execute_window_loop(win, [&](const Coordinates & id)
831     {
832         const uint8x16_t top2_data = vld1q_u8(input_top2_ptr + input.offset());
833         const uint8x16_t top_data  = vld1q_u8(input_top_ptr + input.offset());
834         const uint8x16_t mid_data  = vld1q_u8(input_mid_ptr + input.offset());
835         const uint8x16_t bot_data  = vld1q_u8(input_bot_ptr + input.offset());
836         const uint8x16_t bot2_data = vld1q_u8(input_bot2_ptr + input.offset());
837
838         const uint8x16_t rows_min_3 = vminq_u8(top2_data, bot2_data);
839         uint8x16_t       rows_min_5 = vminq_u8(top_data, bot_data);
840         rows_min_5                  = vminq_u8(rows_min_5, mid_data);
841
842         const uint8x8_t out_3 = min_row<3>(rows_min_3);
843         const uint8x8_t out_5 = min_row<5>(rows_min_5);
844
845         vst1_u8(output.ptr(), vmin_u8(out_3, out_5));
846     },
847     input, output);
848 }
849
850 template <>
851 void NENonLinearFilterKernel::max_filter_disk<5, 5>(const Window &win)
852 {
853     Iterator input(_input, win);
854     Iterator output(_output, win);
855
856     const auto input_top2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, -2)));
857     const auto input_top_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, -1)));
858     const auto input_mid_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 0)));
859     const auto input_bot_ptr  = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-2, 1)));
860     const auto input_bot2_ptr = static_cast<const unsigned char *>(_input->ptr_to_element(Coordinates(-1, 2)));
861
862     execute_window_loop(win, [&](const Coordinates & id)
863     {
864         const uint8x16_t top2_data = vld1q_u8(input_top2_ptr + input.offset());
865         const uint8x16_t top_data  = vld1q_u8(input_top_ptr + input.offset());
866         const uint8x16_t mid_data  = vld1q_u8(input_mid_ptr + input.offset());
867         const uint8x16_t bot_data  = vld1q_u8(input_bot_ptr + input.offset());
868         const uint8x16_t bot2_data = vld1q_u8(input_bot2_ptr + input.offset());
869
870         const uint8x16_t rows_max_3 = vmaxq_u8(top2_data, bot2_data);
871         uint8x16_t       rows_max_5 = vmaxq_u8(top_data, bot_data);
872         rows_max_5                  = vmaxq_u8(rows_max_5, mid_data);
873
874         const uint8x8_t out_3 = max_row<3>(rows_max_3);
875         const uint8x8_t out_5 = max_row<5>(rows_max_5);
876
877         vst1_u8(output.ptr(), vmax_u8(out_3, out_5));
878     },
879     input, output);
880 }
881 } // namespace arm_compute
882 #endif
883
884 template <int mask_w, int mask_h>
885 void NENonLinearFilterKernel::non_linear_filter_generic(const Window &win)
886 {
887     Iterator input(_input, win);
888     Iterator output(_output, win);
889     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
890
891     const int     k_row_half = mask_h / 2;
892     const int     k_col_half = mask_w / 2;
893     constexpr int mask_size  = mask_w * mask_h;
894
895     // Set row pointers
896     std::array<unsigned char *, mask_h> input_ptrs{ {} };
897     for(int i = -k_row_half; i <= k_row_half; ++i)
898     {
899         input_ptrs[k_row_half + i] = _input->buffer() + _input->info()->offset_element_in_bytes(Coordinates(-k_col_half, i));
900     }
901
902     execute_window_loop(win, [&](const Coordinates & id)
903     {
904         std::array<uint8_t, mask_size> vals{ {} };
905
906         size_t v = 0;
907         size_t m = 0;
908
909         for(unsigned int r = 0; r < mask_h; ++r)
910         {
911             const auto in_ptr = static_cast<const uint8_t *>(input_ptrs[r] + input.offset());
912
913             for(unsigned int c = 0; c < mask_w; ++c, ++m)
914             {
915                 if(_mask[m] == 255)
916                 {
917                     vals[v] = in_ptr[c];
918                     ++v;
919                 }
920             }
921         }
922
923         // Only do something if there is at least one non-zero element in the
924         // mask
925         if(v > 0)
926         {
927             std::sort(vals.begin(), vals.begin() + v);
928
929             switch(_function)
930             {
931                 case NonLinearFilterFunction::MIN:
932                     *output.ptr() = vals[0];
933                     break;
934                 case NonLinearFilterFunction::MAX:
935                     *output.ptr() = vals[v - 1];
936                     break;
937                 case NonLinearFilterFunction::MEDIAN:
938                     *output.ptr() = vals[v / 2];
939                     break;
940                 default:
941                     break;
942             }
943         }
944     },
945     input, output);
946 }
947
948 void NENonLinearFilterKernel::run(const Window &window)
949 {
950     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
951     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
952
953     using NonLinearFilterFunction = void (NENonLinearFilterKernel::*)(const Window & window);
954
955     // Function table for BOX pattern
956     static const std::array<NonLinearFilterFunction, 6> func_table_box =
957     {
958         {
959             &NENonLinearFilterKernel::median_filter_box<3, 3>,
960             &NENonLinearFilterKernel::min_filter_box<3, 3>,
961             &NENonLinearFilterKernel::max_filter_box<3, 3>,
962             &NENonLinearFilterKernel::median_filter_box<5, 5>,
963             &NENonLinearFilterKernel::min_filter_box<5, 5>,
964             &NENonLinearFilterKernel::max_filter_box<5, 5>,
965         }
966     };
967
968     // Function table for CROSS pattern
969     static const std::array<NonLinearFilterFunction, 6> func_table_cross =
970     {
971         {
972             &NENonLinearFilterKernel::median_filter_cross<3, 3>,
973             &NENonLinearFilterKernel::min_filter_cross<3, 3>,
974             &NENonLinearFilterKernel::max_filter_cross<3, 3>,
975             &NENonLinearFilterKernel::median_filter_cross<5, 5>,
976             &NENonLinearFilterKernel::min_filter_cross<5, 5>,
977             &NENonLinearFilterKernel::max_filter_cross<5, 5>,
978         }
979     };
980
981     // Function table for DISK pattern
982     static const std::array<NonLinearFilterFunction, 6> func_table_disk =
983     {
984         {
985             &NENonLinearFilterKernel::median_filter_box<3, 3>,
986             &NENonLinearFilterKernel::min_filter_box<3, 3>,
987             &NENonLinearFilterKernel::max_filter_box<3, 3>,
988             &NENonLinearFilterKernel::median_filter_disk<5, 5>,
989             &NENonLinearFilterKernel::min_filter_disk<5, 5>,
990             &NENonLinearFilterKernel::max_filter_disk<5, 5>,
991         }
992     };
993
994     // Function table for OTHER pattern
995     static const std::array<NonLinearFilterFunction, 2> func_table_generic =
996     {
997         {
998             &NENonLinearFilterKernel::non_linear_filter_generic<3, 3>,
999             &NENonLinearFilterKernel::non_linear_filter_generic<5, 5>,
1000         }
1001     };
1002
1003     switch(_pattern)
1004     {
1005         case MatrixPattern::BOX:
1006             ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_box.size());
1007             (this->*func_table_box[_func_idx])(window);
1008             break;
1009         case MatrixPattern::CROSS:
1010             ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_cross.size());
1011             (this->*func_table_cross[_func_idx])(window);
1012             break;
1013         case MatrixPattern::DISK:
1014             ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_disk.size());
1015             (this->*func_table_disk[_func_idx])(window);
1016             break;
1017         case MatrixPattern::OTHER:
1018         default:
1019             ARM_COMPUTE_ERROR_ON(_func_idx >= func_table_generic.size());
1020             (this->*func_table_generic[_func_idx])(window);
1021             break;
1022     }
1023 }