fix warning hit with Android clang version 5.0.300080 (#348)
[platform/upstream/armcl.git] / src / core / NEON / kernels / NESoftmaxLayerKernel.cpp
1 /*
2  * Copyright (c) 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
25
26 #include "arm_compute/core/AccessWindowStatic.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/NEON/NEFixedPoint.h"
31 #include "arm_compute/core/NEON/NEMath.h"
32 #include "arm_compute/core/TensorInfo.h"
33 #include "arm_compute/core/Utils.h"
34 #include "arm_compute/core/Validate.h"
35 #include "arm_compute/core/Window.h"
36
37 #include <algorithm>
38 #include <arm_neon.h>
39 #include <cfloat>
40
41 using namespace arm_compute;
42
43 namespace
44 {
45 Status validate_arguments_logits_1d_max(const ITensorInfo *input, const ITensorInfo *output)
46 {
47     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
48     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
49
50     // Checks performed when output is configured
51     if(output->total_size() != 0)
52     {
53         // Softmax across the x dimension
54         TensorShape output_shape{ input->tensor_shape() };
55         output_shape.set(0, 1);
56
57         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
58         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
59         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
60     }
61
62     return Status{};
63 }
64
65 std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo *input, ITensorInfo *output)
66 {
67     // Configure kernel window
68     constexpr unsigned int num_elems_written_per_row = 1;
69     const int              input_width               = input->valid_region().shape.x();
70
71     unsigned int           num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
72     Window                 win                               = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
73     AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
74     bool                   window_changed = false;
75
76     if(output->total_size() != 0)
77     {
78         AccessWindowHorizontal output_access(output, 0, num_elems_written_per_row, 1.f / input_width);
79         window_changed = update_window_and_padding(win, input_access, output_access);
80         output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
81     }
82     else
83     {
84         window_changed = update_window_and_padding(win, input_access);
85     }
86
87     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
88     return std::make_pair(err, win);
89 }
90
91 Status validate_arguments_logits_1d_shift_exp_sum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
92 {
93     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, sum, output);
94     ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->data_type()));
95     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
96
97     // Checks performed when output is configured
98     if(output->total_size() != 0)
99     {
100         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
101         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
102         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
103     }
104
105     // Checks performed when sum is configured
106     if(sum->total_size() != 0)
107     {
108         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, sum);
109         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
110         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max, sum);
111     }
112
113     return Status{};
114 }
115
116 std::pair<Status, Window> validate_and_configure_window_logits_1d_shift_exp_sum(ITensorInfo *input, ITensorInfo *max, ITensorInfo *output, ITensorInfo *sum)
117 {
118     unsigned int num_elems_processed_per_iteration = input->valid_region().shape.x();
119
120     // Configure kernel window
121     Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
122     AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
123     AccessWindowHorizontal max_access(max, 0, 1);
124     AccessWindowHorizontal sum_access(sum, 0, 1);
125     bool                   window_changed = false;
126
127     if(output->total_size() != 0)
128     {
129         AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
130         window_changed = update_window_and_padding(win, input_access, max_access, output_access, sum_access);
131         output_access.set_valid_region(win, input->valid_region());
132     }
133     else
134     {
135         window_changed = update_window_and_padding(win, input_access, max_access, sum_access);
136     }
137
138     sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->tensor_shape()));
139
140     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
141     return std::make_pair(err, win);
142 }
143
144 Status validate_arguments_logits_1d_norm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
145 {
146     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
147     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
148     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
149     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
150
151     // Checks performed when output is configured
152     if(output->total_size() != 0)
153     {
154         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
155         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
156         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
157     }
158
159     return Status{};
160 }
161
162 std::pair<Status, Window> validate_and_configure_window_logits_1d_norm(ITensorInfo *input, ITensorInfo *sum, ITensorInfo *output)
163 {
164     // Configure kernel window
165     unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
166     Window       win                               = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
167
168     AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
169     AccessWindowStatic     sum_access(sum, 0, 0, 1, sum->dimension(1));
170     bool                   window_changed = false;
171
172     if(output->total_size() != 0)
173     {
174         AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
175
176         window_changed = update_window_and_padding(win, input_access, sum_access, output_access);
177
178         output_access.set_valid_region(win, input->valid_region());
179     }
180     else
181     {
182         window_changed = update_window_and_padding(win, input_access, sum_access);
183     }
184     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
185     return std::make_pair(err, win);
186 }
187
188 void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
189 {
190     Window in_slice = window.first_slice_window_1D();
191
192     Window window_max(window);
193     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
194     Window max_slice = window_max.first_slice_window_1D();
195
196     do
197     {
198         Iterator input(in, in_slice);
199         Iterator output(out, max_slice);
200
201         qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
202
203         execute_window_loop(in_slice, [&](const Coordinates & id)
204         {
205             const auto       in_ptr        = reinterpret_cast<const qint8_t *>(input.ptr());
206             const qint8x16_t current_value = vld1q_qs8(in_ptr);
207             vec_max                        = vmaxq_qs8(vec_max, current_value);
208         },
209         input);
210
211         qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
212         carry_max           = vpmax_qs8(carry_max, carry_max);
213         carry_max           = vpmax_qs8(carry_max, carry_max);
214         carry_max           = vpmax_qs8(carry_max, carry_max);
215
216         *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
217     }
218     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
219 }
220 void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
221 {
222     Window in_slice = window.first_slice_window_1D();
223
224     Window window_max(window);
225     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
226     Window max_slice = window_max.first_slice_window_1D();
227
228     do
229     {
230         Iterator input(in, in_slice);
231         Iterator output(out, max_slice);
232
233         qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
234
235         execute_window_loop(in_slice, [&](const Coordinates & id)
236         {
237             const auto       in_ptr        = reinterpret_cast<const qint16_t *>(input.ptr());
238             const qint16x8_t current_value = vld1q_qs16(in_ptr);
239             vec_max                        = vmaxq_qs16(vec_max, current_value);
240         },
241         input);
242
243         qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
244         carry_max            = vpmax_qs16(carry_max, carry_max);
245         carry_max            = vpmax_qs16(carry_max, carry_max);
246
247         *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
248     }
249     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
250 }
251
252 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
253 void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
254 {
255     Window in_slice = window.first_slice_window_1D();
256
257     Window window_max(window);
258     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
259     Window max_slice = window_max.first_slice_window_1D();
260
261     do
262     {
263         Iterator input(in, in_slice);
264         Iterator output(out, max_slice);
265
266         float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
267
268         execute_window_loop(in_slice, [&](const Coordinates & id)
269         {
270             const auto        in_ptr        = reinterpret_cast<const float16_t *>(input.ptr());
271             const float16x8_t current_value = vld1q_f16(in_ptr);
272             vec_max                         = vmaxq_f16(vec_max, current_value);
273         },
274         input);
275
276         float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
277         carry_max             = vpmax_f16(carry_max, carry_max);
278         carry_max             = vpmax_f16(carry_max, carry_max);
279
280         *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
281     }
282     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
283 }
284 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
285
286 void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
287 {
288     Window in_slice = window.first_slice_window_1D();
289
290     Window window_max(window);
291     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
292     Window max_slice = window_max.first_slice_window_1D();
293
294     do
295     {
296         Iterator input(in, in_slice);
297         Iterator output(out, max_slice);
298
299         float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
300
301         execute_window_loop(in_slice, [&](const Coordinates & id)
302         {
303             const auto        in_ptr        = reinterpret_cast<const float *>(input.ptr());
304             const float32x4_t current_value = vld1q_f32(in_ptr);
305             vec_max                         = vmaxq_f32(vec_max, current_value);
306         },
307         input);
308
309         float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
310         carry_max             = vpmax_f32(carry_max, carry_max);
311
312         *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
313     }
314     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
315 }
316 } // namespace
317
318 NELogits1DMaxKernel::NELogits1DMaxKernel()
319     : _func(nullptr), _border_size()
320 {
321 }
322
323 BorderSize NELogits1DMaxKernel::border_size() const
324 {
325     return _border_size;
326 }
327
328 void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
329 {
330     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
331
332     // Softmax across the x dimension
333     TensorShape output_shape{ input->info()->tensor_shape() };
334     output_shape.set(0, 1);
335
336     // Output auto initialization if not yet initialized
337     auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
338
339     // Perform validation step
340     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(input->info(), output->info()));
341
342     const int    input_width                       = input->info()->valid_region().shape.x();
343     unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
344
345     switch(input->info()->data_type())
346     {
347         case DataType::QS8:
348             _func = &logits_1d_max_qs8;
349             break;
350         case DataType::QS16:
351             _func = &logits_1d_max_qs16;
352             break;
353         case DataType::F32:
354             _func = &logits_1d_max_f32;
355             break;
356         case DataType::F16:
357 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
358             _func = &logits_1d_max_f16;
359             break;
360 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
361         default:
362             ARM_COMPUTE_ERROR("Unsupported data type.");
363     }
364
365     _input       = input;
366     _output      = output;
367     _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
368
369     // Configure kernel window
370     auto win_config = validate_and_configure_window_logits_1d_max(input->info(), output->info());
371     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
372     INEKernel::configure(win_config.second);
373 }
374
375 Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
376 {
377     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(input, output));
378     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(input->clone().get(), output->clone().get()).first);
379
380     return Status{};
381 }
382
383 void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
384 {
385     ARM_COMPUTE_UNUSED(info);
386     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
387     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
388     ARM_COMPUTE_ERROR_ON(_func == nullptr);
389
390     (*_func)(_input, _output, window);
391 }
392
393 namespace
394 {
395 void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
396 {
397     ARM_COMPUTE_UNUSED(beta);
398
399     Window window_max(window);
400     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
401
402     Window max_slice = window_max.first_slice_window_1D();
403     Window in_slice  = window.first_slice_window_1D();
404
405     constexpr int step                 = 8;
406     const int     long_steps           = in->info()->valid_region().shape.x() / step;
407     const int     small_steps          = in->info()->valid_region().shape.x() % step;
408     const int     fixed_point_position = in->info()->fixed_point_position();
409
410     do
411     {
412         Iterator input(in, in_slice);
413         Iterator exp(out, in_slice);
414         Iterator _max(max, max_slice);
415         Iterator _sum(sum, max_slice);
416
417         // Get pointers
418         auto in_ptr  = reinterpret_cast<const qint8_t *>(input.ptr());
419         auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
420
421         // Init sum to zero
422         qint16x8_t vec_sum_value = vdupq_n_qs16(0);
423
424         // Get max value
425         const auto      max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
426         const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
427
428         // Run neon loop
429         for(int i = 0; i < long_steps; ++i)
430         {
431             qint8x8_t vec_elements = vld1_qs8(in_ptr);
432             vec_elements           = vqsub_qs8(vec_elements, vec_max);
433             vec_elements           = vqexp_qs8(vec_elements, fixed_point_position);
434
435             vst1_qs8(exp_ptr, vec_elements);
436             vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
437
438             in_ptr += step;
439             exp_ptr += step;
440         }
441         // Reduce sum
442         const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
443         const qint16_t   sum0    = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
444         const qint16_t   sum1    = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
445         qint16_t         sum     = sqadd_qs16(sum0, sum1);
446
447         // Run remaining elements
448         for(int i = 0; i < small_steps; ++i)
449         {
450             qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
451             exp_ptr[i]      = element;
452             sum             = sqadd_qs16(sum, element);
453         }
454
455         *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
456     }
457     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
458 }
459 void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
460 {
461     ARM_COMPUTE_UNUSED(beta);
462
463     Window window_max(window);
464     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
465
466     Window max_slice = window_max.first_slice_window_1D();
467     Window in_slice  = window.first_slice_window_1D();
468
469     constexpr int step                 = 4;
470     const int     long_steps           = in->info()->valid_region().shape.x() / step;
471     const int     small_steps          = in->info()->valid_region().shape.x() % step;
472     const int     fixed_point_position = in->info()->fixed_point_position();
473
474     do
475     {
476         Iterator input(in, in_slice);
477         Iterator exp(out, in_slice);
478         Iterator _max(max, max_slice);
479         Iterator _sum(sum, max_slice);
480
481         // Get pointers
482         auto in_ptr  = reinterpret_cast<const qint16_t *>(input.ptr());
483         auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
484
485         // Init sum to zero
486         qint32x4_t vec_sum_value = vdupq_n_qs32(0);
487
488         // Get max value
489         const auto       max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
490         const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
491
492         // Run neon loop
493         for(int i = 0; i < long_steps; ++i)
494         {
495             qint16x4_t vec_elements = vld1_qs16(in_ptr);
496             vec_elements            = vqsub_qs16(vec_elements, vec_max);
497             vec_elements            = vqexp_qs16(vec_elements, fixed_point_position);
498
499             vst1_qs16(exp_ptr, vec_elements);
500             vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
501
502             in_ptr += step;
503             exp_ptr += step;
504         }
505         // Reduce sum
506         qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
507         qint32_t   sum            = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
508
509         // Run remaining elements
510         for(int i = 0; i < small_steps; ++i)
511         {
512             qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
513             exp_ptr[i]       = element;
514             sum              = sqadd_qs32(sum, element);
515         }
516
517         *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
518     }
519     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
520 }
521
522 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
523 void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
524 {
525     Window window_max(window);
526     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
527
528     Window max_slice = window_max.first_slice_window_1D();
529     Window in_slice  = window.first_slice_window_1D();
530
531     constexpr int step        = 8;
532     const int     long_steps  = in->info()->valid_region().shape.x() / step;
533     const int     small_steps = in->info()->valid_region().shape.x() % step;
534
535     do
536     {
537         Iterator input(in, in_slice);
538         Iterator exp(out, in_slice);
539         Iterator _max(max, max_slice);
540         Iterator _sum(sum, max_slice);
541
542         // Get pointers
543         auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
544         auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
545
546         // Init sum to zero
547         float16x8_t vec_sum_value = vdupq_n_f16(0);
548
549         // Get max value
550         const auto        max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
551         const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
552
553         // Run neon loop
554         for(int i = 0; i < long_steps; ++i)
555         {
556             float16x8_t vec_elements = vld1q_f16(in_ptr);
557             vec_elements             = vsubq_f16(vec_elements, vec_max);
558             vec_elements             = vmulq_n_f16(vec_elements, beta);
559             vec_elements             = vexpq_f16(vec_elements);
560
561             vst1q_f16(exp_ptr, vec_elements);
562             vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
563
564             in_ptr += step;
565             exp_ptr += step;
566         }
567         // Reduce sum
568         const float16x4_t sum_red        = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
569         const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
570         float16_t         sum            = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
571
572         // Run remaining elements
573         for(int i = 0; i < small_steps; ++i)
574         {
575             const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
576             exp_ptr[i]              = element;
577             sum += element;
578         }
579         *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
580     }
581     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
582 }
583 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
584
585 void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
586 {
587     Window window_max(window);
588     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
589
590     Window max_slice = window_max.first_slice_window_1D();
591     Window in_slice  = window.first_slice_window_1D();
592
593     constexpr int step        = 4;
594     const int     long_steps  = in->info()->valid_region().shape.x() / step;
595     const int     small_steps = in->info()->valid_region().shape.x() % step;
596
597     do
598     {
599         Iterator input(in, in_slice);
600         Iterator exp(out, in_slice);
601         Iterator _max(max, max_slice);
602         Iterator _sum(sum, max_slice);
603
604         // Get pointers
605         auto in_ptr  = reinterpret_cast<const float *>(input.ptr());
606         auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
607
608         // Init sum to zero
609         float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
610
611         // Get max value
612         const auto        max_ptr = reinterpret_cast<const float *>(_max.ptr());
613         const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
614
615         // Run neon loop
616         for(int i = 0; i < long_steps; ++i)
617         {
618             float32x4_t vec_elements = vld1q_f32(in_ptr);
619             vec_elements             = vsubq_f32(vec_elements, vec_max);
620             vec_elements             = vmulq_n_f32(vec_elements, beta);
621             vec_elements             = vexpq_f32(vec_elements);
622
623             vst1q_f32(exp_ptr, vec_elements);
624             vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
625
626             in_ptr += step;
627             exp_ptr += step;
628         }
629
630         // Reduce sum
631         float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
632         carry_addition             = vpadd_f32(carry_addition, carry_addition);
633         float sum                  = vget_lane_f32(carry_addition, 0);
634
635         // Run remaining elements
636         for(int i = 0; i < small_steps; ++i)
637         {
638             float element = std::exp((in_ptr[i] - *max_ptr) * beta);
639             exp_ptr[i]    = element;
640             sum += element;
641         }
642
643         *(reinterpret_cast<float *>(_sum.ptr())) = sum;
644     }
645     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
646 }
647 } //namespace
648
649 NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
650     : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
651 {
652 }
653
654 void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
655 {
656     ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
657
658     // Output auto initialization if not yet initialized
659     auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
660     auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
661
662     // Perform validation step
663     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info(), beta));
664
665     switch(input->info()->data_type())
666     {
667         case DataType::QS8:
668             _func = &logits_1d_shift_exp_sum_qs8;
669             break;
670         case DataType::QS16:
671             _func = &logits_1d_shift_exp_sum_qs16;
672             break;
673         case DataType::F32:
674             _func = &logits_1d_shift_exp_sum_f32;
675             break;
676         case DataType::F16:
677 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
678             _func = &logits_1d_shift_exp_sum_f16;
679             break;
680 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
681         default:
682             ARM_COMPUTE_ERROR("Unsupported data type.");
683             break;
684     }
685
686     _input  = input;
687     _max    = max;
688     _output = output;
689     _sum    = sum;
690     _beta   = beta;
691
692     // Configure kernel window
693     auto win_config = validate_and_configure_window_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info());
694     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
695     INEKernel::configure(win_config.second);
696 }
697
698 Status NELogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
699 {
700     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_shift_exp_sum(input, max, output, sum, beta));
701     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_shift_exp_sum(input->clone().get(), max->clone().get(), output->clone().get(), sum->clone().get()).first);
702
703     return Status{};
704 }
705
706 void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
707 {
708     ARM_COMPUTE_UNUSED(info);
709     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
710     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
711     ARM_COMPUTE_ERROR_ON(_func == nullptr);
712
713     (*_func)(_input, _max, _output, _sum, window, _beta);
714 }
715
716 namespace
717 {
718 void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
719 {
720     Window window_sum(window);
721     window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
722     Window sum_slice = window_sum.first_slice_window_1D();
723     Window in_slice  = window.first_slice_window_1D();
724
725     const int fixed_point_position = in->info()->fixed_point_position();
726
727     do
728     {
729         Iterator input(in, in_slice);
730         Iterator _sum(sum, sum_slice);
731         Iterator output(out, in_slice);
732
733         const int8_t     sum_value        = *reinterpret_cast<const qint8_t *>(_sum.ptr());
734         const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
735
736         execute_window_loop(in_slice, [&](const Coordinates & id)
737         {
738             const auto in_ptr  = reinterpret_cast<const qint8_t *>(input.ptr());
739             const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
740
741             const qint8x16_t vec_in           = vld1q_qs8(in_ptr);
742             const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
743
744             vst1q_qs8(out_ptr, normalized_value);
745         },
746         input, output);
747     }
748     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
749 }
750 void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
751 {
752     Window window_sum(window);
753     window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
754     Window sum_slice = window_sum.first_slice_window_1D();
755     Window in_slice  = window.first_slice_window_1D();
756
757     const int fixed_point_position = in->info()->fixed_point_position();
758
759     do
760     {
761         Iterator input(in, in_slice);
762         Iterator _sum(sum, sum_slice);
763         Iterator output(out, in_slice);
764
765         const int16_t    sum_value        = *reinterpret_cast<const qint16_t *>(_sum.ptr());
766         const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
767
768         execute_window_loop(in_slice, [&](const Coordinates & id)
769         {
770             const auto in_ptr  = reinterpret_cast<const qint16_t *>(input.ptr());
771             const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
772
773             const qint16x8_t vec_in           = vld1q_qs16(in_ptr);
774             const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
775
776             vst1q_qs16(out_ptr, normalized_value);
777         },
778         input, output);
779     }
780     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
781 }
782 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
783 void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
784 {
785     Window window_sum(window);
786     window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
787     Window sum_slice = window_sum.first_slice_window_1D();
788     Window in_slice  = window.first_slice_window_1D();
789
790     do
791     {
792         Iterator input(in, in_slice);
793         Iterator _sum(sum, sum_slice);
794         Iterator output(out, in_slice);
795
796         const float16_t   sum_value        = *reinterpret_cast<const qint16_t *>(_sum.ptr());
797         const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
798
799         execute_window_loop(in_slice, [&](const Coordinates & id)
800         {
801             const auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
802             const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
803
804             const float16x8_t vec_in           = vld1q_f16(in_ptr);
805             const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
806
807             vst1q_f16(out_ptr, normalized_value);
808         },
809         input, output);
810     }
811     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
812 }
813 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
814
815 void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
816 {
817     Window window_sum(window);
818     window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
819     Window sum_slice = window_sum.first_slice_window_1D();
820     Window in_slice  = window.first_slice_window_1D();
821
822     do
823     {
824         Iterator input(in, in_slice);
825         Iterator _sum(sum, sum_slice);
826         Iterator output(out, in_slice);
827
828         const float       sum_value        = *reinterpret_cast<const float *>(_sum.ptr());
829         const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
830
831         execute_window_loop(in_slice, [&](const Coordinates & id)
832         {
833             const auto in_ptr  = reinterpret_cast<const float *>(input.ptr());
834             const auto out_ptr = reinterpret_cast<float *>(output.ptr());
835
836             const float32x4_t vec_in           = vld1q_f32(in_ptr);
837             const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
838
839             vst1q_f32(out_ptr, normalized_value);
840         },
841         input, output);
842     }
843     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
844 }
845 } // namespace
846
847 NELogits1DNormKernel::NELogits1DNormKernel()
848     : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
849 {
850 }
851
852 void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
853 {
854     ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
855
856     // Output auto initialization if not yet initialized
857     auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
858
859     // Perform validation step
860     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_norm(input->info(), sum->info(), output->info()));
861
862     _input  = input;
863     _sum    = sum;
864     _output = output;
865
866     switch(input->info()->data_type())
867     {
868         case DataType::QS8:
869             _func = &logits_1d_norm_qs8;
870             break;
871         case DataType::QS16:
872             _func = &logits_1d_norm_qs16;
873             break;
874         case DataType::F32:
875             _func = &logits_1d_norm_f32;
876             break;
877         case DataType::F16:
878 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
879             _func = &logits_1d_norm_f16;
880             break;
881 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
882         default:
883             ARM_COMPUTE_ERROR("Unsupported data type.");
884             break;
885     }
886
887     // Configure kernel window
888     auto win_config = validate_and_configure_window_logits_1d_norm(input->info(), sum->info(), output->info());
889     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
890     INEKernel::configure(win_config.second);
891 }
892
893 Status NELogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
894 {
895     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_norm(input, sum, output));
896     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_norm(input->clone().get(), sum->clone().get(), output->clone().get()).first);
897
898     return Status{};
899 }
900
901 void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
902 {
903     ARM_COMPUTE_UNUSED(info);
904     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
905     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
906     ARM_COMPUTE_ERROR_ON(_func == nullptr);
907
908     (*_func)(_input, _sum, _output, window);
909 }