c3c79be7d5099b21b9c97ed9669cd3ad00d606cd
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / PRelu.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/PRelu.h"
19
20 #include "kernels/BinaryOpCommon.h"
21 #include "kernels/Utils.h"
22
23 #include <tensorflow/lite/kernels/internal/reference/binary_function.h>
24 #include <tensorflow/lite/kernels/internal/reference/prelu.h>
25
26 namespace luci_interpreter
27 {
28
29 namespace kernels
30 {
31
// Kernel with two inputs ({input, alpha}) and one output tensor.
PRelu::PRelu(const Tensor *input, const Tensor *alpha, Tensor *output)
  : Kernel({input, alpha}, {output})
{
}
36
PRelu::~PRelu()
{
  // Out-of-line destructor declared so the vector of quantized alpha
  // multipliers (_alpha_multipliers) is destroyed properly.
}
41
// Validates element types and quantization parameters, and precomputes the
// quantized multipliers used by the U8 and S16 eval paths.
void PRelu::configure()
{
  // Input, alpha and output must all share one element type.
  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
  LUCI_INTERPRETER_CHECK(alpha()->element_type() == output()->element_type());
  // Input and output must be per-tensor quantized (at most one scale each).
  LUCI_INTERPRETER_CHECK(input()->scales().size() <= 1);
  LUCI_INTERPRETER_CHECK(output()->scales().size() <= 1);

  if (input()->element_type() == DataType::U8)
  {
    LUCI_INTERPRETER_CHECK(alpha()->scales().size() <= 1); // remove when CWQ kernel arrives
    // Single rescale multiplier for the negative (alpha) branch:
    // input_scale * alpha_scale / output_scale.
    _alpha_multipliers.resize(1);
    double alpha_multiplier = input()->scale() * alpha()->scale() / output()->scale();
    quantizeMultiplier(alpha_multiplier, &_alpha_multipliers[0].multiplier,
                       &_alpha_multipliers[0].shift);
    // Identity rescale for the non-negative branch: input_scale / output_scale.
    double identity_multiplier = input()->scale() / output()->scale();
    quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
  }
  else if (input()->element_type() == DataType::S16)
  {
    // Common check for correctness of quant params: S16 path requires
    // symmetric quantization, i.e. all zero points must be 0.
    LUCI_INTERPRETER_CHECK(input()->zero_point() == 0 && output()->zero_point() == 0);
    for (size_t channel = 0; channel < alpha()->zero_points().size(); ++channel)
    {
      LUCI_INTERPRETER_CHECK(alpha()->zero_points()[channel] == 0);
    }
    // PRelu specific checks for CWQ: alpha is channel-wise quantized over its
    // last dimension, with exactly one scale per channel.
    LUCI_INTERPRETER_CHECK(alpha()->quantized_dimension() == alpha()->shape().num_dims() - 1);
    LUCI_INTERPRETER_CHECK(static_cast<int32_t>(alpha()->scales().size()) ==
                           alpha()->shape().dim(alpha()->quantized_dimension()));
    // Alpha must carry one value per channel of the input's last dimension.
    LUCI_INTERPRETER_CHECK(alpha()->shape().num_elements() ==
                           input()->shape().dim(input()->shape().num_dims() - 1));

    // all dimension of alpha except last one should be size 1
    for (int dim = 0; dim < alpha()->shape().num_dims() - 1; ++dim)
    {
      LUCI_INTERPRETER_CHECK(alpha()->shape().dim(dim) == 1);
    }

    // Per-channel rescale multipliers for the negative (alpha) branch.
    std::vector<double> real_multipliers =
      getQuantizedConvolutionMultiplers(input()->scale(), alpha()->scales(), output()->scale());

    _alpha_multipliers = quantizeMultipliers(real_multipliers);

    // Identity rescale for the non-negative branch.
    double identity_multiplier = input()->scale() / output()->scale();
    quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
  }
  // TODO: enable it only if kernel with dynamic shapes
  output()->resize(calculateShapeForBroadcast(input()->shape(), alpha()->shape()));
}
91
92 void PRelu::execute() const
93 {
94   switch (input()->element_type())
95   {
96     case DataType::FLOAT32:
97       evalFloat();
98       break;
99     case DataType::U8:
100       evalQuantized();
101       break;
102     case DataType::S16:
103       evalQuantizedS16();
104       break;
105     default:
106       assert(false && "Unsupported type.");
107   }
108 }
109
110 void PRelu::evalFloat() const
111 {
112   const auto input_data = getTensorData<float>(input());
113   const auto alpha_data = getTensorData<float>(alpha());
114   const auto size = getTensorShape(input()).FlatSize();
115   auto output_data = getTensorData<float>(output());
116
117   auto PReluFunc = [](float input, float alpha) { return input >= 0.0 ? input : input * alpha; };
118
119   if (input()->shape() != alpha()->shape())
120   {
121     tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
122       getTensorShape(input()), getTensorData<float>(input()), getTensorShape(alpha()),
123       getTensorData<float>(alpha()), getTensorShape(output()), getTensorData<float>(output()),
124       PReluFunc);
125   }
126   else
127   {
128     for (auto i = decltype(size){0}; i < size; ++i)
129     {
130       if (input_data[i] >= 0)
131         output_data[i] = input_data[i];
132       else
133         output_data[i] = input_data[i] * alpha_data[i];
134     }
135   }
136 }
137
138 void PRelu::evalQuantized() const
139 {
140   tflite::PreluParams op_params{};
141
142   op_params.input_offset = -input()->zero_point(); // Note the '-'.
143   op_params.alpha_offset = -alpha()->zero_point(); // Note the '-'.
144   op_params.output_offset = output()->zero_point();
145   op_params.output_shift_1 = _output_shift_identity;
146   op_params.output_multiplier_1 = _output_multiplier_identity;
147   op_params.output_shift_2 = _alpha_multipliers[0].shift;
148   op_params.output_multiplier_2 = _alpha_multipliers[0].multiplier;
149
150   if (input()->shape() != alpha()->shape())
151   {
152     tflite::reference_ops::BroadcastPrelu4DSlow(
153       op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(alpha()),
154       getTensorData<uint8_t>(alpha()), getTensorShape(output()), getTensorData<uint8_t>(output()));
155   }
156   else
157   {
158     tflite::reference_ops::Prelu<uint8_t>(
159       op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(alpha()),
160       getTensorData<uint8_t>(alpha()), getTensorShape(output()), getTensorData<uint8_t>(output()));
161   }
162 }
163
164 static inline int16_t evalElemS16PRelu(int16_t input_val, int16_t alpha_val,
165                                        const ChannelQuantMultipliers &identity_mult,
166                                        const ChannelQuantMultipliers &alpha_mult)
167 {
168   constexpr int32_t quantized_min = std::numeric_limits<int16_t>::min();
169   constexpr int32_t quantized_max = std::numeric_limits<int16_t>::max();
170
171   const int32_t output_val =
172     input_val >= 0
173       ? tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val),
174                                               identity_mult.multiplier, identity_mult.shift)
175       : tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val * alpha_val),
176                                               alpha_mult.multiplier, alpha_mult.shift);
177   const int32_t clamped_output = std::min(quantized_max, std::max(quantized_min, output_val));
178   return clamped_output;
179 }
180
181 void PRelu::evalQuantizedS16() const
182 {
183   // Note that this kernel assumes alpha is CWQ
184   tflite::RuntimeShape input_shape = getTensorShape(input());
185   const int16_t *input_data = input()->data<int16_t>();
186   const int16_t *alpha_data = alpha()->data<int16_t>();
187   int16_t *output_data = output()->data<int16_t>();
188
189   const ChannelQuantMultipliers pos_mult{_output_shift_identity, _output_multiplier_identity};
190
191   const int last_dim = input()->shape().num_dims() - 1;
192
193   int32_t outer_dims_size = 1;
194   for (int i = 0; i < last_dim; ++i)
195     outer_dims_size *= input_shape.Dims(i);
196   int32_t quant_dim_size = input_shape.Dims(last_dim);
197
198   for (int32_t outer_dims = 0; outer_dims < outer_dims_size; ++outer_dims)
199     for (int32_t quant_channel = 0; quant_channel < quant_dim_size; ++quant_channel)
200     {
201       const ChannelQuantMultipliers &neg_mult = _alpha_multipliers[quant_channel];
202       size_t offset = static_cast<size_t>(outer_dims) * static_cast<size_t>(quant_dim_size);
203       offset += quant_channel;
204
205       output_data[offset] =
206         evalElemS16PRelu(input_data[offset], alpha_data[quant_channel], pos_mult, neg_mult);
207     }
208 }
209
210 } // namespace kernels
211 } // namespace luci_interpreter