bb394619e1d4e3ef98867cd10bcbf93b00a74f0a
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / SoftMax.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_SOFTMAX_H__
19 #define __NNFW_CKER_SOFTMAX_H__
20
21 #include "cker/Shape.h"
22 #include "cker/Utils.h"
23 #include "cker/Types.h"
24 #include "cker/eigen/Utils.h"
25
26 #include <Eigen/Core>
27 #include <fixedpoint/fixedpoint.h>
28 #include <cmath>
29
30 namespace nnfw
31 {
32 namespace cker
33 {
34
35 inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
36                     const Shape &output_shape, float *output_data)
37 {
38   // Validate whether if shapes of input and output are the same
39   MatchingFlatSize(input_shape, output_shape);
40
41   const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
42   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
43   // Compute the exponential first, removing the max coefficient for numerical
44   // stability.
45   out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
46   // We are separating out the exp function so that exp can be vectorized.
47   out_mat = out_mat.array().exp();
48   // Normalize to get the activations.
49   Eigen::Array<float, 1, Eigen::Dynamic> scale = out_mat.array().colwise().sum().inverse();
50   out_mat.array().rowwise() *= scale;
51 }
52
53 inline void Softmax(const SoftmaxParams &params, const Shape &input_shape,
54                     const uint8_t *input_data, const Shape &output_shape, uint8_t *output_data)
55 {
56   const int32_t input_beta_multiplier = params.input_multiplier;
57   const int32_t input_beta_left_shift = params.input_left_shift;
58   const int diff_min = params.diff_min;
59   // The representation chosen for the input to the exp() function is Q5.26.
60   // We need to leave extra space since values that we skip might be as large as
61   // -32 before multiplying by input_beta_multiplier, and therefore as large as
62   // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
63   // accumulation, but exp(-16) definitely is.
64   static const int kScaledDiffIntegerBits = 5;
65   static const int kAccumulationIntegerBits = 12;
66   using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
67   using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
68   using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
69
70   const int trailing_dim = input_shape.DimensionsCount() - 1;
71   const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
72   const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
73
74   for (int i = 0; i < outer_size; ++i)
75   {
76     uint8_t max_in_row = 0;
77     for (int c = 0; c < depth; ++c)
78     {
79       max_in_row = std::max(max_in_row, input_data[i * depth + c]);
80     }
81
82     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
83     for (int c = 0; c < depth; ++c)
84     {
85       int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
86       if (input_diff >= diff_min)
87       {
88         const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
89             input_diff, input_beta_multiplier, input_beta_left_shift);
90         const FixedPointScaledDiff scaled_diff_f8 =
91             FixedPointScaledDiff::FromRaw(input_diff_rescaled);
92         sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
93                                         exp_on_negative_values(scaled_diff_f8));
94       }
95     }
96
97     int32_t fixed_sum_of_exps = sum_of_exps.raw();
98     int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(fixed_sum_of_exps));
99     // This is the number of bits to the left of the binary point above 1.0.
100     // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
101     // no later adjustment will be needed.
102     int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
103     int32_t shifted_sum_minus_one =
104         static_cast<int32_t>((static_cast<uint32_t>(fixed_sum_of_exps) << headroom_plus_one) -
105                              (static_cast<uint32_t>(1) << 31));
106
107     FixedPoint0 shifted_scale =
108         one_over_one_plus_x_for_x_in_0_1(FixedPoint0::FromRaw(shifted_sum_minus_one));
109
110     for (int c = 0; c < depth; ++c)
111     {
112       int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
113       if (input_diff >= diff_min)
114       {
115         const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
116             input_diff, input_beta_multiplier, input_beta_left_shift);
117         const FixedPointScaledDiff scaled_diff_f8 =
118             FixedPointScaledDiff::FromRaw(input_diff_rescaled);
119
120         FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
121         int32_t unsat_output = gemmlowp::RoundingDivideByPOT((shifted_scale * exp_in_0).raw(),
122                                                              num_bits_over_unit + 31 - 8);
123
124         output_data[i * depth + c] = static_cast<uint8_t>(
125             std::max(std::min(unsat_output, static_cast<int32_t>(255)), static_cast<int32_t>(0)));
126       }
127       else
128       {
129         output_data[i * depth + c] = 0;
130       }
131     }
132   }
133 }
134
135 } // namespace cker
136 } // namespace nnfw
137
138 #endif // __NNFW_CKER_SOFTMAX_H__