2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "kernels/PRelu.h"
19 #include "kernels/BinaryOpCommon.h"
20 #include "kernels/Utils.h"
22 #include <tensorflow/lite/kernels/internal/reference/binary_function.h>
23 #include <tensorflow/lite/kernels/internal/reference/prelu.h>
27 namespace luci_interpreter
33 PRelu::PRelu(const Tensor *input, const Tensor *alpha, Tensor *output)
34 : Kernel({input, alpha}, {output})
40 // Destructor declared to delete vector of alpha quantized data properly
43 void PRelu::configure()
45 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
46 LUCI_INTERPRETER_CHECK(alpha()->element_type() == output()->element_type());
47 LUCI_INTERPRETER_CHECK(input()->scales().size() <= 1);
48 LUCI_INTERPRETER_CHECK(output()->scales().size() <= 1);
50 if (input()->element_type() == DataType::U8)
52 LUCI_INTERPRETER_CHECK(alpha()->scales().size() <= 1); // remove when CWQ kernel arrives
53 _alpha_multipliers.resize(1);
54 double alpha_multiplier = input()->scale() * alpha()->scale() / output()->scale();
55 quantizeMultiplier(alpha_multiplier, &_alpha_multipliers[0].multiplier,
56 &_alpha_multipliers[0].shift);
57 double identity_multiplier = input()->scale() / output()->scale();
58 quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
60 else if (input()->element_type() == DataType::S16)
62 // Common check for correctness of quant params
63 LUCI_INTERPRETER_CHECK(input()->zero_point() == 0 && output()->zero_point() == 0);
64 for (size_t channel = 0; channel < alpha()->zero_points().size(); ++channel)
66 LUCI_INTERPRETER_CHECK(alpha()->zero_points()[channel] == 0);
68 // PRelu specific checks for CWQ
69 LUCI_INTERPRETER_CHECK(alpha()->quantized_dimension() == alpha()->shape().num_dims() - 1);
70 LUCI_INTERPRETER_CHECK(static_cast<int32_t>(alpha()->scales().size()) ==
71 alpha()->shape().dim(alpha()->quantized_dimension()));
72 LUCI_INTERPRETER_CHECK(alpha()->shape().num_elements() ==
73 input()->shape().dim(input()->shape().num_dims() - 1));
75 // all dimension of alpha except last one should be size 1
76 for (int dim = 0; dim < alpha()->shape().num_dims() - 1; ++dim)
78 LUCI_INTERPRETER_CHECK(alpha()->shape().dim(dim) == 1);
81 std::vector<double> real_multipliers =
82 getQuantizedConvolutionMultiplers(input()->scale(), alpha()->scales(), output()->scale());
84 _alpha_multipliers = quantizeMultipliers(real_multipliers);
86 double identity_multiplier = input()->scale() / output()->scale();
87 quantizeMultiplier(identity_multiplier, &_output_multiplier_identity, &_output_shift_identity);
89 output()->resize(calculateShapeForBroadcast(input()->shape(), alpha()->shape()));
92 void PRelu::execute() const
94 switch (input()->element_type())
96 case DataType::FLOAT32:
106 throw std::runtime_error("Unsupported type.");
110 void PRelu::evalFloat() const
112 const auto input_data = getTensorData<float>(input());
113 const auto alpha_data = getTensorData<float>(alpha());
114 const auto size = getTensorShape(input()).FlatSize();
115 auto output_data = getTensorData<float>(output());
117 auto PReluFunc = [](float input, float alpha) { return input >= 0.0 ? input : input * alpha; };
119 if (input()->shape() != alpha()->shape())
121 tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
122 getTensorShape(input()), getTensorData<float>(input()), getTensorShape(alpha()),
123 getTensorData<float>(alpha()), getTensorShape(output()), getTensorData<float>(output()),
128 for (auto i = decltype(size){0}; i < size; ++i)
130 if (input_data[i] >= 0)
131 output_data[i] = input_data[i];
133 output_data[i] = input_data[i] * alpha_data[i];
138 void PRelu::evalQuantized() const
140 tflite::PreluParams op_params{};
142 op_params.input_offset = -input()->zero_point(); // Note the '-'.
143 op_params.alpha_offset = -alpha()->zero_point(); // Note the '-'.
144 op_params.output_offset = output()->zero_point();
145 op_params.output_shift_1 = _output_shift_identity;
146 op_params.output_multiplier_1 = _output_multiplier_identity;
147 op_params.output_shift_2 = _alpha_multipliers[0].shift;
148 op_params.output_multiplier_2 = _alpha_multipliers[0].multiplier;
150 if (input()->shape() != alpha()->shape())
152 tflite::reference_ops::BroadcastPrelu4DSlow(
153 op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(alpha()),
154 getTensorData<uint8_t>(alpha()), getTensorShape(output()), getTensorData<uint8_t>(output()));
158 tflite::reference_ops::Prelu<uint8_t>(
159 op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(alpha()),
160 getTensorData<uint8_t>(alpha()), getTensorShape(output()), getTensorData<uint8_t>(output()));
164 static inline int16_t evalElemS16PRelu(int16_t input_val, int16_t alpha_val,
165 const ChannelQuantMultipliers &identity_mult,
166 const ChannelQuantMultipliers &alpha_mult)
168 constexpr int32_t quantized_min = std::numeric_limits<int16_t>::min();
169 constexpr int32_t quantized_max = std::numeric_limits<int16_t>::max();
171 const int32_t output_val =
173 ? tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val),
174 identity_mult.multiplier, identity_mult.shift)
175 : tflite::MultiplyByQuantizedMultiplier(static_cast<int32_t>(input_val * alpha_val),
176 alpha_mult.multiplier, alpha_mult.shift);
177 const int32_t clamped_output = std::min(quantized_max, std::max(quantized_min, output_val));
178 return clamped_output;
181 void PRelu::evalQuantizedS16() const
183 // Note that this kernel assumes alpha is CWQ
184 tflite::RuntimeShape input_shape = getTensorShape(input());
185 const int16_t *input_data = input()->data<int16_t>();
186 const int16_t *alpha_data = alpha()->data<int16_t>();
187 int16_t *output_data = output()->data<int16_t>();
189 const ChannelQuantMultipliers pos_mult{_output_shift_identity, _output_multiplier_identity};
191 const int last_dim = input()->shape().num_dims() - 1;
193 int32_t outer_dims_size = 1;
194 for (int i = 0; i < last_dim; ++i)
195 outer_dims_size *= input_shape.Dims(i);
196 int32_t quant_dim_size = input_shape.Dims(last_dim);
198 for (int32_t outer_dims = 0; outer_dims < outer_dims_size; ++outer_dims)
199 for (int32_t quant_channel = 0; quant_channel < quant_dim_size; ++quant_channel)
201 const ChannelQuantMultipliers &neg_mult = _alpha_multipliers[quant_channel];
202 size_t offset = static_cast<size_t>(outer_dims) * static_cast<size_t>(quant_dim_size);
203 offset += quant_channel;
205 output_data[offset] =
206 evalElemS16PRelu(input_data[offset], alpha_data[quant_channel], pos_mult, neg_mult);
210 } // namespace kernels
211 } // namespace luci_interpreter