2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Quantize.h"
19 #include "kernels/Utils.h"
20 #include "PALQuantize.h"
22 namespace luci_interpreter
// Requantizes `input` (its buffer read as `input_dtype`) into `output`'s quantized
// element type. The rescale factor is input->scale() / output->scale(); both
// tensors' zero points are handed to the PAL Requantize kernel.
// NOTE(review): the declarations of `multiplier` and `shift` (written through
// pointers by quantizeMultiplier below) are not visible in this listing — confirm.
30 template <typename input_dtype> void call_requantize(const Tensor *input, Tensor *output)
// Real-valued factor mapping the input's quantization scale onto the output's.
35 const double effective_output_scale = input->scale() / output->scale();
// Presumably encodes the factor as an integer multiplier plus a shift (TFLite
// convention) so the PAL kernel can rescale in integer arithmetic — verify.
36 quantizeMultiplier(effective_output_scale, &multiplier, &shift);
38 const auto input_shape = getTensorShape(input);
39 const auto output_shape = getTensorShape(output);
// Element count processed by the kernel; MatchingFlatSize is expected to require
// input and output to agree on it.
40 const auto size = tflite::MatchingFlatSize(input_shape, output_shape);
42 const auto input_data = getTensorData<input_dtype>(input);
// Dispatch on the output element type so the PAL kernel writes the matching
// buffer type (int8_t, uint8_t or int16_t below).
44 switch (output->element_type())
47 luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
48 output->zero_point(), getTensorData<int8_t>(output));
51 luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
52 output->zero_point(), getTensorData<uint8_t>(output));
55 luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
56 output->zero_point(), getTensorData<int16_t>(output));
// Any other output type. NOTE(review): assert is compiled out when NDEBUG is
// defined, in which case this path returns without writing `output`.
59 assert(false && "Unsupported quantized type, yet!");
// Constructs the Quantize kernel, registering `input` as its single input tensor
// and `output` as its single output tensor with the Kernel base.
65 Quantize::Quantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
// Validates the input/output element-type pairing and the S16 zero-point
// requirements, then sizes `output` to `input`'s shape.
67 void Quantize::configure()
// An S16 input must have a zero point of 0.
70 if (input()->element_type() == DataType::S16)
71 LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);
73 switch (input()->element_type())
75 case DataType::FLOAT32:
// Float input: the output may be U8, S8 or S16.
77 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8 ||
78 output()->element_type() == DataType::S8 ||
79 output()->element_type() == DataType::S16);
// Non-float input. The case labels are not visible in this listing; presumably
// these are the quantized types (S16/S8/U8) that execute() requantizes. The output
// may be S8, U8 or S16, and an S16 output must have a zero point of 0.
86 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8 ||
87 output()->element_type() == DataType::U8 ||
88 output()->element_type() == DataType::S16);
89 if (output()->element_type() == DataType::S16)
91 LUCI_INTERPRETER_CHECK(output()->zero_point() == 0);
// Any other input type. NOTE(review): assert is compiled out when NDEBUG is
// defined, so such an input is not rejected in that build;
// LUCI_INTERPRETER_CHECK(false) would match the validation style used above.
96 assert(false && "Unsupported type");
98 // TODO: resize here only for kernels with dynamic shapes.
// The output unconditionally takes the input's shape.
99 output()->resize(input()->shape());
// Writes `input` into `output`. Float input is quantized using the output's scale
// and zero point; other supported input types are requantized via
// call_requantize.
102 void Quantize::execute() const
104 switch (input()->element_type())
106 case DataType::FLOAT32:
// Float -> quantized: the output tensor's zero point and scale define the mapping.
108 tflite::QuantizationParams op_params;
109 op_params.zero_point = output()->zero_point();
110 op_params.scale = output()->scale();
111 const auto input_data = getTensorData<float>(input());
// Dispatch on the output element type so the PAL kernel writes the matching
// buffer type (int8_t, uint8_t or int16_t below).
113 switch (output()->element_type())
117 luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
118 getTensorShape(output()), getTensorData<int8_t>(output()));
123 luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
124 getTensorShape(output()),
125 getTensorData<uint8_t>(output()));
130 luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
131 getTensorShape(output()),
132 getTensorData<int16_t>(output()));
// Any other output type for float input. NOTE(review): compiled out when NDEBUG is
// defined; configure() checks this pairing with LUCI_INTERPRETER_CHECK.
136 assert(false && "Unsupported type.");
// Quantized -> quantized: the template argument selects how the input buffer is
// read (int16_t, int8_t, uint8_t). The selecting case labels are not visible in
// this listing; presumably S16, S8 and U8 respectively — confirm.
142 call_requantize<int16_t>(input(), output());
147 call_requantize<int8_t>(input(), output());
152 call_requantize<uint8_t>(input(), output());
// Any other input type. NOTE(review): compiled out when NDEBUG is defined, in which
// case this path returns without writing `output`.
156 assert(false && "Unsupported type.");
160 } // namespace kernels
161 } // namespace luci_interpreter