5d4b17d9840fe7bcb8f0f9d9e323a0d6aaf41d35
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Dequantize.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/Dequantize.h"
19 #include "kernels/Utils.h"
20 #include "PALDequantize.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26
27 Dequantize::Dequantize(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
28
29 void Dequantize::configure()
30 {
31   LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8 ||
32                          input()->element_type() == DataType::U8 ||
33                          input()->element_type() == DataType::S16);
34
35   LUCI_INTERPRETER_CHECK(input()->scales().size() == 1);
36
37   if (input()->element_type() == DataType::S16)
38     LUCI_INTERPRETER_CHECK(input()->zero_point() == 0);
39
40   LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
41
42   // TODO: enable it only if kernel with dynamic shapes
43   output()->resize(input()->shape());
44 }
45
46 void Dequantize::execute() const
47 {
48   tflite::DequantizationParams op_params;
49   op_params.zero_point = input()->zero_point();
50   op_params.scale = input()->scale();
51
52   switch (input()->element_type())
53   {
54     case DataType::U8:
55     {
56       luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
57                                        getTensorData<uint8_t>(input()), getTensorShape(output()),
58                                        getTensorData<float>(output()));
59       break;
60     }
61     case DataType::S8:
62     {
63       luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
64                                        getTensorData<int8_t>(input()), getTensorShape(output()),
65                                        getTensorData<float>(output()));
66       break;
67     }
68     case DataType::S16:
69     {
70       luci_interpreter_pal::Dequantize(op_params, getTensorShape(input()),
71                                        getTensorData<int16_t>(input()), getTensorShape(output()),
72                                        getTensorData<float>(output()));
73       break;
74     }
75     default:
76       assert(false && "Unsupported type.");
77   }
78 }
79
80 } // namespace kernels
81 } // namespace luci_interpreter