1fb18624828196b78d5212e2f2abc6b00da161fb
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / cmsisnn / PALSoftmax.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_SOFTMAX_H
19 #define LUCI_INTERPRETER_PAL_SOFTMAX_H
20
21 #include <tensorflow/lite/kernels/internal/reference/softmax.h>
22 #include <arm_nnfunctions.h>
23
24 namespace luci_interpreter_pal
25 {
26 static inline void PopulateSoftmaxLookupTable(tflite::SoftmaxParams *data, float input_scale,
27                                               float beta)
28 {
29   // Do nothing for mcu
30   (void)data;
31   (void)input_scale;
32   (void)beta;
33 }
34
35 static inline void InitializeParams(tflite::SoftmaxParams *params, float input_scale, float beta)
36 {
37   int32 input_beta_multiplier;
38   int input_beta_left_shift;
39   static const int kScaledDiffIntegerBits = 5;
40   tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits,
41                                    &input_beta_multiplier, &input_beta_left_shift);
42
43   params->input_multiplier = input_beta_multiplier;
44   params->input_left_shift = input_beta_left_shift;
45   params->diff_min =
46     -tflite::CalculateInputRadius(kScaledDiffIntegerBits, params->input_left_shift);
47 }
48
49 template <typename T>
50 static inline void Softmax(const tflite::SoftmaxParams &params,
51                            const tflite::RuntimeShape &input_shape, const T *input_data,
52                            const tflite::RuntimeShape &output_shape, T *output_data)
53 {
54   // MARK: At this moment this operation doesn't support on mcu
55   assert(false && "Softmax NYI");
56   (void)params;
57   (void)input_shape;
58   (void)input_data;
59   (void)output_shape;
60   (void)output_data;
61 }
62
63 template <>
64 inline void Softmax<int8_t>(const tflite::SoftmaxParams &params,
65                             const tflite::RuntimeShape &input_shape, const int8_t *input_data,
66                             const tflite::RuntimeShape &output_shape, int8_t *output_data)
67 {
68   const int trailing_dim = input_shape.DimensionsCount() - 1;
69   const int outer_size = tflite::MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
70   const int depth = tflite::MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
71   const int32_t mult = params.input_multiplier;
72   const int32_t shift = params.input_left_shift;
73   const int32_t diff_min = params.diff_min;
74
75   arm_softmax_s8(input_data, outer_size, depth, mult, shift, diff_min, output_data);
76 }
77 } // namespace luci_interpreter_pal
78
79 #endif // LUCI_INTERPRETER_PAL_SOFTMAX_H