/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/Mul.h"

#include "kernels/BinaryOpCommon.h"
#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h>

#include <algorithm>
#include <stdexcept>
29 namespace luci_interpreter
34 Mul::Mul(const Tensor *input1, const Tensor *input2, Tensor *output, const MulParams ¶ms)
35 : KernelWithParams<MulParams>({input1, input2}, {output}, params)
41 LUCI_INTERPRETER_CHECK(input1()->element_type() == input2()->element_type());
42 LUCI_INTERPRETER_CHECK(output()->element_type() == input1()->element_type());
43 if (input1()->element_type() == DataType::S16)
45 LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 &&
46 output()->zero_point() == 0);
49 output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
52 void Mul::execute() const
54 switch (input1()->element_type())
56 case DataType::FLOAT32:
63 throw std::runtime_error("Unsupported type.");
67 void Mul::evalFloat() const
69 float activation_min{};
70 float activation_max{};
71 calculateActivationRange(_params.activation, &activation_min, &activation_max);
73 tflite::ArithmeticParams params{};
74 params.float_activation_min = activation_min;
75 params.float_activation_max = activation_max;
77 const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
78 getTensorShape(input1()), getTensorShape(input2()), ¶ms);
82 luci_interpreter_pal::BroadcastMul4DSlow(
83 params, getTensorShape(input1()), getTensorData<float>(input1()), getTensorShape(input2()),
84 getTensorData<float>(input2()), getTensorShape(output()), getTensorData<float>(output()));
88 luci_interpreter_pal::Mul(params, getTensorShape(input1()), getTensorData<float>(input1()),
89 getTensorShape(input2()), getTensorData<float>(input2()),
90 getTensorShape(output()), getTensorData<float>(output()));
94 void Mul::evalQuantizedS16() const
96 const auto input1_scale = static_cast<double>(input1()->scale());
97 const auto input2_scale = static_cast<double>(input2()->scale());
98 const auto output_scale = static_cast<double>(output()->scale());
100 const double real_multiplier = input1_scale * input2_scale / output_scale;
102 int32_t output_multiplier;
104 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
106 int32_t activation_min{};
107 int32_t activation_max{};
108 calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
110 auto fn = [output_multiplier, output_shift, activation_min, activation_max](int16_t input1_val,
111 int16_t input2_val) {
112 int32_t output = static_cast<int32_t>(input1_val) * static_cast<int32_t>(input2_val);
113 output = tflite::MultiplyByQuantizedMultiplier(output, output_multiplier, output_shift);
114 output = std::max(output, activation_min);
115 output = std::min(output, activation_max);
116 return static_cast<int16_t>(output);
119 BinaryOpBroadcastSlow(getTensorShape(input1()), getTensorData<int16_t>(input1()),
120 getTensorShape(input2()), getTensorData<int16_t>(input2()),
121 getTensorShape(output()), getTensorData<int16_t>(output()), fn);
124 } // namespace kernels
125 } // namespace luci_interpreter