/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernels/FullyConnected.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>

#include <stdexcept>
26 namespace luci_interpreter
32 FullyConnected::FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias,
33 Tensor *output, const FullyConnectedParams ¶ms)
34 : KernelWithParams<FullyConnectedParams>({input, weights, bias}, {output}, params)
38 void FullyConnected::configure()
40 if (weights()->element_type() == DataType::U8)
42 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::U8);
43 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8);
44 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
46 else if (weights()->element_type() == DataType::FLOAT32)
48 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
49 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
50 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
52 else if (weights()->element_type() == DataType::S8)
54 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8);
55 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
56 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
60 throw std::runtime_error("Unsupported type.");
63 const Shape &input_shape = input()->shape();
64 const Shape &weights_shape = weights()->shape();
66 LUCI_INTERPRETER_CHECK(weights_shape.num_dims() == 2);
67 LUCI_INTERPRETER_CHECK(bias() == nullptr ||
68 bias()->shape().num_elements() == weights_shape.dim(0));
70 LUCI_INTERPRETER_CHECK(input_shape.num_elements() % weights_shape.dim(1) == 0);
71 const int32_t batch_size = input_shape.num_elements() / weights_shape.dim(1);
72 const int32_t num_units = weights_shape.dim(0);
75 LUCI_INTERPRETER_CHECK(bias()->shape().num_elements() == weights()->shape().dim(0));
77 output()->resize({batch_size, num_units});
80 void FullyConnected::execute() const
82 switch (input()->element_type())
90 case DataType::FLOAT32:
94 throw std::runtime_error("Unsupported type.");
98 void FullyConnected::evalFloat() const
100 float activation_min{};
101 float activation_max{};
102 calculateActivationRange(_params.activation, &activation_min, &activation_max);
104 tflite::FullyConnectedParams params{};
105 params.float_activation_min = activation_min;
106 params.float_activation_max = activation_max;
107 params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;
109 tflite::reference_ops::FullyConnected(
110 params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(weights()),
111 getTensorData<float>(weights()), getTensorShape(bias()), getTensorData<float>(bias()),
112 getTensorShape(output()), getTensorData<float>(output()));
115 void FullyConnected::evalQuantized() const
117 double real_multiplier = 0.0;
119 int32_t output_activation_min;
120 int32_t output_activation_max;
121 int32_t output_multiplier;
123 getQuantizedConvolutionMultipler(input()->scale(), weights()->scale(), output()->scale());
124 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
125 calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
126 &output_activation_max);
128 int32_t input_offset = -input()->zero_point();
129 int32_t filter_offset = -weights()->zero_point();
130 int32_t output_offset = output()->zero_point();
132 tflite::FullyConnectedParams op_params{};
133 op_params.input_offset = input_offset;
134 op_params.weights_offset = filter_offset;
135 op_params.output_offset = output_offset;
136 op_params.output_multiplier = output_multiplier;
137 op_params.output_shift = output_shift;
138 op_params.quantized_activation_min = output_activation_min;
139 op_params.quantized_activation_max = output_activation_max;
140 op_params.lhs_cacheable = false;
141 op_params.rhs_cacheable = false;
142 tflite::reference_ops::FullyConnected(
143 op_params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(weights()),
144 getTensorData<uint8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
145 getTensorShape(output()), getTensorData<uint8_t>(output()));
148 void FullyConnected::evalQuantizedS8() const
150 double real_multiplier = 0.0;
152 int32_t output_activation_min;
153 int32_t output_activation_max;
154 int32_t output_multiplier;
156 getQuantizedConvolutionMultipler(input()->scale(), weights()->scale(), output()->scale());
157 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
158 calculateActivationRangeQuantized(params().activation, output(), &output_activation_min,
159 &output_activation_max);
161 int32_t input_offset = -input()->zero_point();
162 int32_t filter_offset = -weights()->zero_point();
163 int32_t output_offset = output()->zero_point();
165 tflite::FullyConnectedParams op_params{};
166 op_params.input_offset = input_offset;
167 op_params.weights_offset = filter_offset;
168 op_params.output_offset = output_offset;
169 op_params.output_multiplier = output_multiplier;
170 op_params.output_shift = output_shift;
171 op_params.quantized_activation_min = output_activation_min;
172 op_params.quantized_activation_max = output_activation_max;
173 op_params.lhs_cacheable = false;
174 op_params.rhs_cacheable = false;
175 tflite::reference_integer_ops::FullyConnected(
176 op_params, getTensorShape(input()), getTensorData<int8_t>(input()), getTensorShape(weights()),
177 getTensorData<int8_t>(weights()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
178 getTensorShape(output()), getTensorData<int8_t>(output()));
181 } // namespace kernels
182 } // namespace luci_interpreter