Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / FullyConnected.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/FullyConnected.h"
18
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
22 #include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
23
24 #include <stdexcept>
25
26 namespace luci_interpreter
27 {
28
29 namespace kernels
30 {
31
32 FullyConnected::FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias,
33                                Tensor *output, const FullyConnectedParams &params)
34   : KernelWithParams<FullyConnectedParams>({input, weights, bias}, {output}, params)
35 {
36 }
37
38 void FullyConnected::configure()
39 {
40   if (weights()->element_type() == DataType::U8)
41   {
42     LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::U8);
43     LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8);
44     LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
45   }
46   else if (weights()->element_type() == DataType::FLOAT32)
47   {
48     LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
49     LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
50     LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
51   }
52   else if (weights()->element_type() == DataType::S8)
53   {
54     LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8);
55     LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
56     LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
57   }
58   else
59   {
60     throw std::runtime_error("Unsupported type.");
61   }
62
63   const Shape &input_shape = input()->shape();
64   const Shape &weights_shape = weights()->shape();
65
66   LUCI_INTERPRETER_CHECK(weights_shape.num_dims() == 2);
67   LUCI_INTERPRETER_CHECK(bias() == nullptr ||
68                          bias()->shape().num_elements() == weights_shape.dim(0));
69
70   LUCI_INTERPRETER_CHECK(input_shape.num_elements() % weights_shape.dim(1) == 0);
71   const int32_t batch_size = input_shape.num_elements() / weights_shape.dim(1);
72   const int32_t num_units = weights_shape.dim(0);
73
74   if (bias())
75     LUCI_INTERPRETER_CHECK(bias()->shape().num_elements() == weights()->shape().dim(0));
76
77   output()->resize({batch_size, num_units});
78 }
79
80 void FullyConnected::execute() const
81 {
82   switch (input()->element_type())
83   {
84     case DataType::U8:
85       evalQuantized();
86       break;
87     case DataType::S8:
88       evalQuantizedS8();
89       break;
90     case DataType::FLOAT32:
91       evalFloat();
92       break;
93     default:
94       throw std::runtime_error("Unsupported type.");
95   }
96 }
97
98 void FullyConnected::evalFloat() const
99 {
100   float activation_min{};
101   float activation_max{};
102   calculateActivationRange(_params.activation, &activation_min, &activation_max);
103
104   tflite::FullyConnectedParams params{};
105   params.float_activation_min = activation_min;
106   params.float_activation_max = activation_max;
107   params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;
108
109   tflite::reference_ops::FullyConnected(
110     params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(weights()),
111     getTensorData<float>(weights()), getTensorShape(bias()), getTensorData<float>(bias()),
112     getTensorShape(output()), getTensorData<float>(output()));
113 }
114
115 void FullyConnected::evalQuantized() const
116 {
117   double real_multiplier = 0.0;
118   int output_shift;
119   int32_t output_activation_min;
120   int32_t output_activation_max;
121   int32_t output_multiplier;
122   real_multiplier =
123     getQuantizedConvolutionMultipler(input()->scale(), weights()->scale(), output()->scale());
124   quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
125   calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
126                                     &output_activation_max);
127
128   int32_t input_offset = -input()->zero_point();
129   int32_t filter_offset = -weights()->zero_point();
130   int32_t output_offset = output()->zero_point();
131
132   tflite::FullyConnectedParams op_params{};
133   op_params.input_offset = input_offset;
134   op_params.weights_offset = filter_offset;
135   op_params.output_offset = output_offset;
136   op_params.output_multiplier = output_multiplier;
137   op_params.output_shift = output_shift;
138   op_params.quantized_activation_min = output_activation_min;
139   op_params.quantized_activation_max = output_activation_max;
140   op_params.lhs_cacheable = false;
141   op_params.rhs_cacheable = false;
142   tflite::reference_ops::FullyConnected(
143     op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(weights()),
144     getTensorData<uint8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
145     getTensorShape(output()), getTensorData<uint8_t>(output()));
146 }
147
148 void FullyConnected::evalQuantizedS8() const
149 {
150   double real_multiplier = 0.0;
151   int output_shift;
152   int32_t output_activation_min;
153   int32_t output_activation_max;
154   int32_t output_multiplier;
155   real_multiplier =
156     getQuantizedConvolutionMultipler(input()->scale(), weights()->scale(), output()->scale());
157   quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
158   calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
159                                     &output_activation_max);
160
161   int32_t input_offset = -input()->zero_point();
162   int32_t filter_offset = -weights()->zero_point();
163   int32_t output_offset = output()->zero_point();
164
165   tflite::FullyConnectedParams op_params{};
166   op_params.input_offset = input_offset;
167   op_params.weights_offset = filter_offset;
168   op_params.output_offset = output_offset;
169   op_params.output_multiplier = output_multiplier;
170   op_params.output_shift = output_shift;
171   op_params.quantized_activation_min = output_activation_min;
172   op_params.quantized_activation_max = output_activation_max;
173   op_params.lhs_cacheable = false;
174   op_params.rhs_cacheable = false;
175   tflite::reference_integer_ops::FullyConnected(
176     op_params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(weights()),
177     getTensorData<int8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
178     getTensorShape(output()), getTensorData<int8_t>(output()));
179 }
180
181 } // namespace kernels
182 } // namespace luci_interpreter