Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / pal / cmsisnn / PALSoftmax.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef LUCI_INTERPRETER_PAL_SOFTMAX_H
18 #define LUCI_INTERPRETER_PAL_SOFTMAX_H
19
20 #include <tensorflow/lite/kernels/internal/reference/softmax.h>
21 #include <arm_nnfunctions.h>
22
23 namespace luci_interpreter_pal
24 {
25 static inline void PopulateSoftmaxLookupTable(tflite::SoftmaxParams *data, float input_scale,
26                                               float beta)
27 {
28   // Do nothing for mcu
29   (void)data;
30   (void)input_scale;
31   (void)beta;
32 }
33
34 static inline void InitializeParams(tflite::SoftmaxParams *params, float input_scale, float beta)
35 {
36   int32 input_beta_multiplier;
37   int input_beta_left_shift;
38   static const int kScaledDiffIntegerBits = 5;
39   tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits,
40                                    &input_beta_multiplier, &input_beta_left_shift);
41
42   params->input_multiplier = input_beta_multiplier;
43   params->input_left_shift = input_beta_left_shift;
44   params->diff_min =
45     -tflite::CalculateInputRadius(kScaledDiffIntegerBits, params->input_left_shift);
46 }
47
48 template <typename T>
49 static inline void Softmax(const tflite::SoftmaxParams &params,
50                            const tflite::RuntimeShape &input_shape, const T *input_data,
51                            const tflite::RuntimeShape &output_shape, T *output_data)
52 {
53   // MARK: At this moment this operation doesn't support on mcu
54   assert(false && "Softmax NYI");
55   (void)params;
56   (void)input_shape;
57   (void)input_data;
58   (void)output_shape;
59   (void)output_data;
60 }
61
62 template <>
63 inline void Softmax<int8_t>(const tflite::SoftmaxParams &params,
64                             const tflite::RuntimeShape &input_shape, const int8_t *input_data,
65                             const tflite::RuntimeShape &output_shape, int8_t *output_data)
66 {
67   const int trailing_dim = input_shape.DimensionsCount() - 1;
68   const int outer_size = tflite::MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
69   const int depth = tflite::MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
70   const int32_t mult = params.input_multiplier;
71   const int32_t shift = params.input_left_shift;
72   const int32_t diff_min = params.diff_min;
73
74   arm_softmax_s8(input_data, outer_size, depth, mult, shift, diff_min, output_data);
75 }
76 } // namespace luci_interpreter_pal
77
78 #endif // LUCI_INTERPRETER_PAL_SOFTMAX_H