6c832f7b10b46ac91bca071fa57130bd5053b45c
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Quantize.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Quantize.h"
19 #include "kernels/Utils.h"
20 #include "PALQuantize.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26
27 namespace
28 {
29
30 template <typename input_dtype> void call_requantize(const Tensor *input, Tensor *output)
31 {
32   int32_t multiplier;
33   int shift;
34
35   const double effective_output_scale = input->scale() / output->scale();
36   quantizeMultiplier(effective_output_scale, &multiplier, &shift);
37
38   const auto input_shape = getTensorShape(input);
39   const auto output_shape = getTensorShape(output);
40   const auto size = tflite::MatchingFlatSize(input_shape, output_shape);
41
42   const auto input_data = getTensorData<input_dtype>(input);
43
44   switch (output->element_type())
45   {
46     case DataType::S8:
47       luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
48                                        output->zero_point(), getTensorData<int8_t>(output));
49       break;
50     case DataType::U8:
51       luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
52                                        output->zero_point(), getTensorData<uint8_t>(output));
53       break;
54     case DataType::S16:
55       luci_interpreter_pal::Requantize(input_data, size, multiplier, shift, input->zero_point(),
56                                        output->zero_point(), getTensorData<int16_t>(output));
57       break;
58     default:
59       assert(false && "Unsupported quantized type, yet!");
60   }
61 }
62
63 } // namespace
64
65 Quantize::Quantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
66
67 void Quantize::configure()
68 {
69
70   if (input()->element_type() == DataType::S16)
71     LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);
72
73   switch (input()->element_type())
74   {
75     case DataType::FLOAT32:
76     {
77       LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8 ||
78                              output()->element_type() == DataType::S8 ||
79                              output()->element_type() == DataType::S16);
80       break;
81     }
82     case DataType::S16:
83     case DataType::S8:
84     case DataType::U8:
85     {
86       LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8 ||
87                              output()->element_type() == DataType::U8 ||
88                              output()->element_type() == DataType::S16);
89       if (output()->element_type() == DataType::S16)
90       {
91         LUCI_INTERPRETER_CHECK(output()->zero_point() == 0);
92       }
93       break;
94     }
95     default:
96       assert(false && "Unsupported type");
97   }
98   // TODO: enable it only if kernel with dynamic shapes
99   output()->resize(input()->shape());
100 }
101
102 void Quantize::execute() const
103 {
104   switch (input()->element_type())
105   {
106     case DataType::FLOAT32:
107     {
108       tflite::QuantizationParams op_params;
109       op_params.zero_point = output()->zero_point();
110       op_params.scale = output()->scale();
111       const auto input_data = getTensorData<float>(input());
112
113       switch (output()->element_type())
114       {
115         case DataType::S8:
116         {
117           luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
118                                          getTensorShape(output()), getTensorData<int8_t>(output()));
119           break;
120         }
121         case DataType::U8:
122         {
123           luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
124                                          getTensorShape(output()),
125                                          getTensorData<uint8_t>(output()));
126           break;
127         }
128         case DataType::S16:
129         {
130           luci_interpreter_pal::Quantize(op_params, getTensorShape(input()), input_data,
131                                          getTensorShape(output()),
132                                          getTensorData<int16_t>(output()));
133           break;
134         }
135         default:
136           assert(false && "Unsupported type.");
137       }
138       break;
139     }
140     case DataType::S16:
141     {
142       call_requantize<int16_t>(input(), output());
143       break;
144     }
145     case DataType::S8:
146     {
147       call_requantize<int8_t>(input(), output());
148       break;
149     }
150     case DataType::U8:
151     {
152       call_requantize<uint8_t>(input(), output());
153       break;
154     }
155     default:
156       assert(false && "Unsupported type.");
157   }
158 }
159
160 } // namespace kernels
161 } // namespace luci_interpreter