/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "BinaryArithmeticLayer.h"
19 #include <cker/operation/BinaryArithmeticOps.h>
33 template <nnfw::cker::BinaryArithmeticOpType arithmetic_type, typename T> struct Eval
35 nnfw::cker::Shape _lhs_shape;
36 nnfw::cker::Shape _rhs_shape;
37 nnfw::cker::Shape _output_shape;
38 nnfw::cker::BinaryArithmeticOpParam _op_params;
41 Eval(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
42 nnfw::cker::BinaryArithmeticOpParam op_params)
43 : _op_params(std::move(op_params)), _need_broadcast(false)
45 if (!output->is_dynamic())
46 updateCache(lhs, rhs, output);
49 void updateCache(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
51 _lhs_shape.ReplaceWith(getShape(lhs));
52 _rhs_shape.ReplaceWith(getShape(rhs));
53 _output_shape.ReplaceWith(getShape(output));
54 _need_broadcast = nnfw::cker::ProcessBroadcastShapes(_lhs_shape, _rhs_shape, &_op_params);
57 void operator()(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output)
59 // Assume dynamic tensors never become static and static ones never change shape since
61 if (output->is_dynamic())
62 updateCache(lhs, rhs, output);
64 assert(_lhs_shape == getShape(lhs) && _rhs_shape == getShape(rhs) &&
65 _output_shape == getShape(output));
66 auto lhs_buffer = getBuffer<T>(lhs);
67 auto rhs_buffer = getBuffer<T>(rhs);
68 auto output_buffer = getBuffer<T>(output);
71 nnfw::cker::BroadcastBinaryArithmeticOp<arithmetic_type>(
72 _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
76 nnfw::cker::BinaryArithmeticOp<arithmetic_type>(
77 _op_params, _lhs_shape, lhs_buffer, _rhs_shape, rhs_buffer, _output_shape, output_buffer);
82 template <nnfw::cker::BinaryArithmeticOpType arithmetic_type>
83 std::function<void(const IPortableTensor *, const IPortableTensor *, IPortableTensor *)>
84 generateKernelGeneric(const IPortableTensor *lhs, const IPortableTensor *rhs,
85 IPortableTensor *output, const ir::Activation activation,
86 nnfw::cker::BinaryArithmeticOpParam &op_params)
88 switch (lhs->data_type())
90 case OperandType::FLOAT32:
92 float output_activation_min = 0, output_activation_max = 0;
93 CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
94 op_params.float_activation_max = output_activation_max;
95 op_params.float_activation_min = output_activation_min;
96 return Eval<arithmetic_type, float>(lhs, rhs, output, op_params);
99 case OperandType::INT32:
101 int32_t output_activation_min = 0, output_activation_max = 0;
102 CalculateActivationRange(activation, &output_activation_min, &output_activation_max);
103 op_params.quantized_activation_max = output_activation_max;
104 op_params.quantized_activation_min = output_activation_min;
105 return Eval<arithmetic_type, int32_t>(lhs, rhs, output, op_params);
109 throw std::runtime_error{"BinaryArithmetic(generic): Unsupported data type"};
113 void setAddOrSubQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
114 IPortableTensor *output, ir::Activation activation,
115 nnfw::cker::BinaryArithmeticOpParam *params)
117 int32_t output_activation_min, output_activation_max;
118 CalculateActivationRangeQuantized(activation, output, &output_activation_min,
119 &output_activation_max);
120 nnfw::cker::BinaryArithmeticOpParam &op_params = *params;
121 op_params.quantized_activation_max = output_activation_max;
122 op_params.quantized_activation_min = output_activation_min;
123 // Parameters for scaled quantized computation
124 op_params.left_shift = 20;
125 // Zero-points of input and output tensors
126 op_params.input1_offset = -lhs->data_zero_point();
127 op_params.input2_offset = -rhs->data_zero_point();
128 op_params.output_offset = output->data_zero_point();
130 // Compute normalized scale for _lhs and _rhs values,
131 // and represent in 32-bit fixed point
132 const double norm_max_scale = 2 * std::max(lhs->data_scale(), rhs->data_scale());
133 const double real_lhs_scale = lhs->data_scale() / norm_max_scale;
134 const double real_rhs_scale = rhs->data_scale() / norm_max_scale;
135 // output scale is used to normalize final result, so we invert the scale here
136 const double real_output_scale =
137 norm_max_scale / (output->data_scale() * (1 << op_params.left_shift));
139 // Represent the scales as fixed int32_t multipliers, and int32_t shifts
140 QuantizeMultiplier(real_lhs_scale, &op_params.input1_multiplier, &op_params.input1_shift);
141 QuantizeMultiplier(real_rhs_scale, &op_params.input2_multiplier, &op_params.input2_shift);
142 QuantizeMultiplier(real_output_scale, &op_params.output_multiplier, &op_params.output_shift);
145 void setMulQuant8Params(const IPortableTensor *lhs, const IPortableTensor *rhs,
146 IPortableTensor *output, ir::Activation activation,
147 nnfw::cker::BinaryArithmeticOpParam *params)
149 int32_t output_activation_min, output_activation_max;
150 CalculateActivationRangeQuantized(activation, output, &output_activation_min,
151 &output_activation_max);
152 nnfw::cker::BinaryArithmeticOpParam &op_params = *params;
154 op_params.quantized_activation_max = output_activation_max;
155 op_params.quantized_activation_min = output_activation_min;
156 op_params.input1_offset = -lhs->data_zero_point();
157 op_params.input2_offset = -rhs->data_zero_point();
158 op_params.output_offset = output->data_zero_point();
160 double real_multiplier = lhs->data_scale() * rhs->data_scale() / output->data_scale();
161 QuantizeMultiplier(real_multiplier, &op_params.output_multiplier, &op_params.output_shift);
166 void BinaryArithmeticLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
167 IPortableTensor *output, const ir::Activation activation,
168 const ArithmeticType arithmetic_type)
170 assert(lhs != nullptr);
171 assert(rhs != nullptr);
172 assert(output != nullptr);
178 nnfw::cker::BinaryArithmeticOpParam op_params;
179 switch (arithmetic_type)
181 case ArithmeticType::kAdd:
182 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
184 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
186 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, uint8_t>(_lhs, _rhs, _output, op_params);
188 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
190 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
192 Eval<nnfw::cker::BinaryArithmeticOpType::ADD, int8_t>(_lhs, _rhs, _output, op_params);
197 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::ADD>(
198 _lhs, _rhs, _output, activation, op_params);
201 case ArithmeticType::kSub:
202 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
204 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
205 op_params.input2_multiplier *= -1;
207 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, uint8_t>(_lhs, _rhs, _output, op_params);
209 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
211 setAddOrSubQuant8Params(_lhs, _rhs, _output, activation, &op_params);
212 op_params.input2_multiplier *= -1;
214 Eval<nnfw::cker::BinaryArithmeticOpType::SUB, int8_t>(_lhs, _rhs, _output, op_params);
219 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::SUB>(
220 _lhs, _rhs, _output, activation, op_params);
223 case ArithmeticType::kMul:
224 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
226 nnfw::cker::BinaryArithmeticOpParam op_params;
227 setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
229 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, uint8_t>(_lhs, _rhs, _output, op_params);
231 else if (_lhs->data_type() == OperandType::QUANT_INT8_ASYMM)
233 nnfw::cker::BinaryArithmeticOpParam op_params;
234 setMulQuant8Params(_lhs, _rhs, _output, activation, &op_params);
236 Eval<nnfw::cker::BinaryArithmeticOpType::MUL, int8_t>(_lhs, _rhs, _output, op_params);
240 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::MUL>(
241 _lhs, _rhs, _output, activation, op_params);
244 case ArithmeticType::kDiv:
245 if (_lhs->data_type() == OperandType::QUANT_UINT8_ASYMM)
247 throw std::runtime_error{
248 "BinaryArithmetic(Div): Div operation does not support quantization"};
250 else if (_lhs->data_type() == OperandType::INT32)
252 throw std::runtime_error{"BinaryArithmetic(Div): Unsupported data type"};
256 _kernel = generateKernelGeneric<nnfw::cker::BinaryArithmeticOpType::DIV>(
257 _lhs, _rhs, _output, activation, op_params);
261 throw std::runtime_error{"BinaryArithmetic: Unsupported BinaryArithmetic type"};
265 void BinaryArithmeticLayer::run() { _kernel(_lhs, _rhs, _output); }
269 } // namespace backend