/*
 * Copyright (c) 2017 ARM Limited.
4 * SPDX-License-Identifier: MIT
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
26 #include "arm_compute/core/AccessWindowStatic.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/NEON/NEFixedPoint.h"
31 #include "arm_compute/core/NEON/NEMath.h"
32 #include "arm_compute/core/TensorInfo.h"
33 #include "arm_compute/core/Utils.h"
34 #include "arm_compute/core/Validate.h"
35 #include "arm_compute/core/Window.h"
41 using namespace arm_compute;
45 Status validate_arguments_logits_1d_max(const ITensorInfo *input, const ITensorInfo *output)
47 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
48 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
50 // Checks performed when output is configured
51 if(output->total_size() != 0)
53 // Softmax across the x dimension
54 TensorShape output_shape{ input->tensor_shape() };
55 output_shape.set(0, 1);
57 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
58 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
59 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
65 std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo *input, ITensorInfo *output)
67 // Configure kernel window
68 constexpr unsigned int num_elems_written_per_row = 1;
69 const int input_width = input->valid_region().shape.x();
71 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
72 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
73 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
74 bool window_changed = false;
76 if(output->total_size() != 0)
78 AccessWindowHorizontal output_access(output, 0, num_elems_written_per_row, 1.f / input_width);
79 window_changed = update_window_and_padding(win, input_access, output_access);
80 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
84 window_changed = update_window_and_padding(win, input_access);
87 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
88 return std::make_pair(err, win);
91 Status validate_arguments_logits_1d_shift_exp_sum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
93 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, sum, output);
94 ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->data_type()));
95 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
97 // Checks performed when output is configured
98 if(output->total_size() != 0)
100 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
101 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
102 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
105 // Checks performed when sum is configured
106 if(sum->total_size() != 0)
108 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, sum);
109 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
110 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max, sum);
116 std::pair<Status, Window> validate_and_configure_window_logits_1d_shift_exp_sum(ITensorInfo *input, ITensorInfo *max, ITensorInfo *output, ITensorInfo *sum)
118 unsigned int num_elems_processed_per_iteration = input->valid_region().shape.x();
120 // Configure kernel window
121 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
122 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
123 AccessWindowHorizontal max_access(max, 0, 1);
124 AccessWindowHorizontal sum_access(sum, 0, 1);
125 bool window_changed = false;
127 if(output->total_size() != 0)
129 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
130 window_changed = update_window_and_padding(win, input_access, max_access, output_access, sum_access);
131 output_access.set_valid_region(win, input->valid_region());
135 window_changed = update_window_and_padding(win, input_access, max_access, sum_access);
138 sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->tensor_shape()));
140 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
141 return std::make_pair(err, win);
144 Status validate_arguments_logits_1d_norm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
146 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
147 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
148 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
149 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
151 // Checks performed when output is configured
152 if(output->total_size() != 0)
154 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
155 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
156 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
162 std::pair<Status, Window> validate_and_configure_window_logits_1d_norm(ITensorInfo *input, ITensorInfo *sum, ITensorInfo *output)
164 // Configure kernel window
165 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
166 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
168 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
169 AccessWindowStatic sum_access(sum, 0, 0, 1, sum->dimension(1));
170 bool window_changed = false;
172 if(output->total_size() != 0)
174 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
176 window_changed = update_window_and_padding(win, input_access, sum_access, output_access);
178 output_access.set_valid_region(win, input->valid_region());
182 window_changed = update_window_and_padding(win, input_access, sum_access);
184 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
185 return std::make_pair(err, win);
188 void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
190 Window in_slice = window.first_slice_window_1D();
192 Window window_max(window);
193 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
194 Window max_slice = window_max.first_slice_window_1D();
198 Iterator input(in, in_slice);
199 Iterator output(out, max_slice);
201 qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
203 execute_window_loop(in_slice, [&](const Coordinates & id)
205 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
206 const qint8x16_t current_value = vld1q_qs8(in_ptr);
207 vec_max = vmaxq_qs8(vec_max, current_value);
211 qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
212 carry_max = vpmax_qs8(carry_max, carry_max);
213 carry_max = vpmax_qs8(carry_max, carry_max);
214 carry_max = vpmax_qs8(carry_max, carry_max);
216 *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
218 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
220 void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
222 Window in_slice = window.first_slice_window_1D();
224 Window window_max(window);
225 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
226 Window max_slice = window_max.first_slice_window_1D();
230 Iterator input(in, in_slice);
231 Iterator output(out, max_slice);
233 qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
235 execute_window_loop(in_slice, [&](const Coordinates & id)
237 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
238 const qint16x8_t current_value = vld1q_qs16(in_ptr);
239 vec_max = vmaxq_qs16(vec_max, current_value);
243 qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
244 carry_max = vpmax_qs16(carry_max, carry_max);
245 carry_max = vpmax_qs16(carry_max, carry_max);
247 *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
249 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Compute the per-row maximum of an F16 tensor along x and store it in @p out. */
void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
{
    Window in_slice = window.first_slice_window_1D();

    // Output window has a collapsed x dimension (one max per row)
    Window window_max(window);
    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
    Window max_slice = window_max.first_slice_window_1D();

    do
    {
        Iterator input(in, in_slice);
        Iterator output(out, max_slice);

        float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto        in_ptr        = reinterpret_cast<const float16_t *>(input.ptr());
            const float16x8_t current_value = vld1q_f16(in_ptr);
            vec_max                         = vmaxq_f16(vec_max, current_value);
        },
        input);

        // Horizontal reduction: 8 lanes -> 1 via pairwise max
        float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
        carry_max             = vpmax_f16(carry_max, carry_max);
        carry_max             = vpmax_f16(carry_max, carry_max);

        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
286 void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
288 Window in_slice = window.first_slice_window_1D();
290 Window window_max(window);
291 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
292 Window max_slice = window_max.first_slice_window_1D();
296 Iterator input(in, in_slice);
297 Iterator output(out, max_slice);
299 float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
301 execute_window_loop(in_slice, [&](const Coordinates & id)
303 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
304 const float32x4_t current_value = vld1q_f32(in_ptr);
305 vec_max = vmaxq_f32(vec_max, current_value);
309 float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
310 carry_max = vpmax_f32(carry_max, carry_max);
312 *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
314 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
318 NELogits1DMaxKernel::NELogits1DMaxKernel()
319 : _func(nullptr), _border_size()
323 BorderSize NELogits1DMaxKernel::border_size() const
328 void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
330 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
332 // Softmax across the x dimension
333 TensorShape output_shape{ input->info()->tensor_shape() };
334 output_shape.set(0, 1);
336 // Output auto initialization if not yet initialized
337 auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
339 // Perform validation step
340 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(input->info(), output->info()));
342 const int input_width = input->info()->valid_region().shape.x();
343 unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
345 switch(input->info()->data_type())
348 _func = &logits_1d_max_qs8;
351 _func = &logits_1d_max_qs16;
354 _func = &logits_1d_max_f32;
357 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
358 _func = &logits_1d_max_f16;
360 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
362 ARM_COMPUTE_ERROR("Unsupported data type.");
367 _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
369 // Configure kernel window
370 auto win_config = validate_and_configure_window_logits_1d_max(input->info(), output->info());
371 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
372 INEKernel::configure(win_config.second);
375 Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
377 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(input, output));
378 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(input->clone().get(), output->clone().get()).first);
383 void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
385 ARM_COMPUTE_UNUSED(info);
386 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
387 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
388 ARM_COMPUTE_ERROR_ON(_func == nullptr);
390 (*_func)(_input, _output, window);
395 void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
397 ARM_COMPUTE_UNUSED(beta);
399 Window window_max(window);
400 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
402 Window max_slice = window_max.first_slice_window_1D();
403 Window in_slice = window.first_slice_window_1D();
405 constexpr int step = 8;
406 const int long_steps = in->info()->valid_region().shape.x() / step;
407 const int small_steps = in->info()->valid_region().shape.x() % step;
408 const int fixed_point_position = in->info()->fixed_point_position();
412 Iterator input(in, in_slice);
413 Iterator exp(out, in_slice);
414 Iterator _max(max, max_slice);
415 Iterator _sum(sum, max_slice);
418 auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
419 auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
422 qint16x8_t vec_sum_value = vdupq_n_qs16(0);
425 const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
426 const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
429 for(int i = 0; i < long_steps; ++i)
431 qint8x8_t vec_elements = vld1_qs8(in_ptr);
432 vec_elements = vqsub_qs8(vec_elements, vec_max);
433 vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
435 vst1_qs8(exp_ptr, vec_elements);
436 vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
442 const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
443 const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
444 const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
445 qint16_t sum = sqadd_qs16(sum0, sum1);
447 // Run remaining elements
448 for(int i = 0; i < small_steps; ++i)
450 qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
451 exp_ptr[i] = element;
452 sum = sqadd_qs16(sum, element);
455 *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
457 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
459 void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
461 ARM_COMPUTE_UNUSED(beta);
463 Window window_max(window);
464 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
466 Window max_slice = window_max.first_slice_window_1D();
467 Window in_slice = window.first_slice_window_1D();
469 constexpr int step = 4;
470 const int long_steps = in->info()->valid_region().shape.x() / step;
471 const int small_steps = in->info()->valid_region().shape.x() % step;
472 const int fixed_point_position = in->info()->fixed_point_position();
476 Iterator input(in, in_slice);
477 Iterator exp(out, in_slice);
478 Iterator _max(max, max_slice);
479 Iterator _sum(sum, max_slice);
482 auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
483 auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
486 qint32x4_t vec_sum_value = vdupq_n_qs32(0);
489 const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
490 const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
493 for(int i = 0; i < long_steps; ++i)
495 qint16x4_t vec_elements = vld1_qs16(in_ptr);
496 vec_elements = vqsub_qs16(vec_elements, vec_max);
497 vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
499 vst1_qs16(exp_ptr, vec_elements);
500 vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
506 qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
507 qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
509 // Run remaining elements
510 for(int i = 0; i < small_steps; ++i)
512 qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
513 exp_ptr[i] = element;
514 sum = sqadd_qs32(sum, element);
517 *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
519 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** F16: compute exp((x - row_max) * beta) into @p out and the per-row sum into @p sum. */
void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
{
    Window window_max(window);
    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));

    Window max_slice = window_max.first_slice_window_1D();
    Window in_slice  = window.first_slice_window_1D();

    constexpr int step        = 8;
    const int     long_steps  = in->info()->valid_region().shape.x() / step;
    const int     small_steps = in->info()->valid_region().shape.x() % step;

    do
    {
        Iterator input(in, in_slice);
        Iterator exp(out, in_slice);
        Iterator _max(max, max_slice);
        Iterator _sum(sum, max_slice);

        // Get pointers
        auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
        auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());

        // Init sum to zero
        float16x8_t vec_sum_value = vdupq_n_f16(0);

        // Get max value of the row
        const auto        max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
        const float16x8_t vec_max = vdupq_n_f16(*max_ptr);

        // Run neon loop
        for(int i = 0; i < long_steps; ++i)
        {
            float16x8_t vec_elements = vld1q_f16(in_ptr);
            vec_elements             = vsubq_f16(vec_elements, vec_max);
            vec_elements             = vmulq_n_f16(vec_elements, beta);
            vec_elements             = vexpq_f16(vec_elements);

            vst1q_f16(exp_ptr, vec_elements);
            vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);

            in_ptr += step;
            exp_ptr += step;
        }

        // Reduce sum
        const float16x4_t sum_red        = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
        const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
        float16_t         sum            = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);

        // Run remaining elements
        for(int i = 0; i < small_steps; ++i)
        {
            const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
            exp_ptr[i]              = element;
            sum += element;
        }

        *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
585 void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
587 Window window_max(window);
588 window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
590 Window max_slice = window_max.first_slice_window_1D();
591 Window in_slice = window.first_slice_window_1D();
593 constexpr int step = 4;
594 const int long_steps = in->info()->valid_region().shape.x() / step;
595 const int small_steps = in->info()->valid_region().shape.x() % step;
599 Iterator input(in, in_slice);
600 Iterator exp(out, in_slice);
601 Iterator _max(max, max_slice);
602 Iterator _sum(sum, max_slice);
605 auto in_ptr = reinterpret_cast<const float *>(input.ptr());
606 auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
609 float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
612 const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
613 const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
616 for(int i = 0; i < long_steps; ++i)
618 float32x4_t vec_elements = vld1q_f32(in_ptr);
619 vec_elements = vsubq_f32(vec_elements, vec_max);
620 vec_elements = vmulq_n_f32(vec_elements, beta);
621 vec_elements = vexpq_f32(vec_elements);
623 vst1q_f32(exp_ptr, vec_elements);
624 vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
631 float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
632 carry_addition = vpadd_f32(carry_addition, carry_addition);
633 float sum = vget_lane_f32(carry_addition, 0);
635 // Run remaining elements
636 for(int i = 0; i < small_steps; ++i)
638 float element = std::exp((in_ptr[i] - *max_ptr) * beta);
639 exp_ptr[i] = element;
643 *(reinterpret_cast<float *>(_sum.ptr())) = sum;
645 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
649 NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
650 : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
654 void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
656 ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
658 // Output auto initialization if not yet initialized
659 auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
660 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
662 // Perform validation step
663 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info(), beta));
665 switch(input->info()->data_type())
668 _func = &logits_1d_shift_exp_sum_qs8;
671 _func = &logits_1d_shift_exp_sum_qs16;
674 _func = &logits_1d_shift_exp_sum_f32;
677 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
678 _func = &logits_1d_shift_exp_sum_f16;
680 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
682 ARM_COMPUTE_ERROR("Unsupported data type.");
692 // Configure kernel window
693 auto win_config = validate_and_configure_window_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info());
694 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
695 INEKernel::configure(win_config.second);
698 Status NELogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
700 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_shift_exp_sum(input, max, output, sum, beta));
701 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_shift_exp_sum(input->clone().get(), max->clone().get(), output->clone().get(), sum->clone().get()).first);
706 void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
708 ARM_COMPUTE_UNUSED(info);
709 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
710 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
711 ARM_COMPUTE_ERROR_ON(_func == nullptr);
713 (*_func)(_input, _max, _output, _sum, window, _beta);
718 void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
720 Window window_sum(window);
721 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
722 Window sum_slice = window_sum.first_slice_window_1D();
723 Window in_slice = window.first_slice_window_1D();
725 const int fixed_point_position = in->info()->fixed_point_position();
729 Iterator input(in, in_slice);
730 Iterator _sum(sum, sum_slice);
731 Iterator output(out, in_slice);
733 const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
734 const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
736 execute_window_loop(in_slice, [&](const Coordinates & id)
738 const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
739 const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
741 const qint8x16_t vec_in = vld1q_qs8(in_ptr);
742 const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
744 vst1q_qs8(out_ptr, normalized_value);
748 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
750 void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
752 Window window_sum(window);
753 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
754 Window sum_slice = window_sum.first_slice_window_1D();
755 Window in_slice = window.first_slice_window_1D();
757 const int fixed_point_position = in->info()->fixed_point_position();
761 Iterator input(in, in_slice);
762 Iterator _sum(sum, sum_slice);
763 Iterator output(out, in_slice);
765 const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
766 const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
768 execute_window_loop(in_slice, [&](const Coordinates & id)
770 const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
771 const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
773 const qint16x8_t vec_in = vld1q_qs16(in_ptr);
774 const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
776 vst1q_qs16(out_ptr, normalized_value);
780 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** F16: divide each row of @p in by its row sum. */
void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
{
    Window window_sum(window);
    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
    Window sum_slice = window_sum.first_slice_window_1D();
    Window in_slice  = window.first_slice_window_1D();

    do
    {
        Iterator input(in, in_slice);
        Iterator _sum(sum, sum_slice);
        Iterator output(out, in_slice);

        // BUGFIX: the sum tensor is F16 here, so it must be read as float16_t.
        // Reading it through a qint16_t pointer reinterpreted the half-float bits
        // as an integer, producing a garbage reciprocal.
        const float16_t   sum_value        = *reinterpret_cast<const float16_t *>(_sum.ptr());
        const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);

        execute_window_loop(in_slice, [&](const Coordinates & id)
        {
            const auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
            const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());

            const float16x8_t vec_in           = vld1q_f16(in_ptr);
            const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);

            vst1q_f16(out_ptr, normalized_value);
        },
        input, output);
    }
    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
815 void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
817 Window window_sum(window);
818 window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
819 Window sum_slice = window_sum.first_slice_window_1D();
820 Window in_slice = window.first_slice_window_1D();
824 Iterator input(in, in_slice);
825 Iterator _sum(sum, sum_slice);
826 Iterator output(out, in_slice);
828 const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
829 const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
831 execute_window_loop(in_slice, [&](const Coordinates & id)
833 const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
834 const auto out_ptr = reinterpret_cast<float *>(output.ptr());
836 const float32x4_t vec_in = vld1q_f32(in_ptr);
837 const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
839 vst1q_f32(out_ptr, normalized_value);
843 while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
847 NELogits1DNormKernel::NELogits1DNormKernel()
848 : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
852 void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
854 ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
856 // Output auto initialization if not yet initialized
857 auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
859 // Perform validation step
860 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_norm(input->info(), sum->info(), output->info()));
866 switch(input->info()->data_type())
869 _func = &logits_1d_norm_qs8;
872 _func = &logits_1d_norm_qs16;
875 _func = &logits_1d_norm_f32;
878 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
879 _func = &logits_1d_norm_f16;
881 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
883 ARM_COMPUTE_ERROR("Unsupported data type.");
887 // Configure kernel window
888 auto win_config = validate_and_configure_window_logits_1d_norm(input->info(), sum->info(), output->info());
889 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
890 INEKernel::configure(win_config.second);
893 Status NELogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
895 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_norm(input, sum, output));
896 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_norm(input->clone().get(), sum->clone().get(), output->clone().get()).first);
901 void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
903 ARM_COMPUTE_UNUSED(info);
904 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
905 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
906 ARM_COMPUTE_ERROR_ON(_func == nullptr);
908 (*_func)(_input, _sum, _output, window);