2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "kernels/Conv2D.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>

#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <thread>
26 namespace luci_interpreter
31 Conv2D::Conv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
32 const Conv2DParams ¶ms)
33 : KernelWithParams<Conv2DParams>({input, filter, bias}, {output}, params)
37 void Conv2D::configure()
39 // TensorFlow Lite (as of v2.2.0) supports the following combinations of types:
40 // | input filter bias output |
41 // ----+---------------------------+
42 // (1) | float float float float |
43 // (2) | float int8 float float | hybrid
44 // (3) | uint8 uint8 int32 uint8 | quantized
45 // (4) | int8 int8 int32 int8 | quantized per channel
47 // We only support (1) and (3) for now.
48 if (input()->element_type() == DataType::FLOAT32 && filter()->element_type() == DataType::FLOAT32)
50 assert(bias() == nullptr || bias()->element_type() == DataType::FLOAT32);
52 else if (input()->element_type() == DataType::U8 && filter()->element_type() == DataType::U8)
54 assert(bias() == nullptr || bias()->element_type() == DataType::S32);
58 throw std::runtime_error("Unsupported type.");
60 assert(output()->element_type() == input()->element_type());
62 const Shape &input_shape = input()->shape();
63 const Shape &filter_shape = filter()->shape();
64 assert(input_shape.num_dims() == 4 && filter_shape.num_dims() == 4);
66 const int32_t batches = input_shape.dim(0);
67 const int32_t input_height = input_shape.dim(1);
68 const int32_t input_width = input_shape.dim(2);
69 const int32_t output_depth = filter_shape.dim(0);
70 const int32_t filter_height = filter_shape.dim(1);
71 const int32_t filter_width = filter_shape.dim(2);
72 assert(filter_shape.dim(3) == input_shape.dim(3));
74 assert(bias() == nullptr ||
75 (bias()->shape().num_dims() == 1 && bias()->shape().dim(0) == output_depth));
77 const int32_t output_height =
78 computeOutputSize(_params.padding, input_height, filter_height, _params.stride_height,
79 _params.dilation_height_factor);
80 const int32_t output_width =
81 computeOutputSize(_params.padding, input_width, filter_width, _params.stride_width,
82 _params.dilation_width_factor);
84 _padding_height = computePadding(_params.stride_height, _params.dilation_height_factor,
85 input_height, filter_height, output_height);
86 _padding_width = computePadding(_params.stride_width, _params.dilation_width_factor, input_width,
87 filter_width, output_width);
89 output()->resize({batches, output_height, output_width, output_depth});
91 // Allocate tensor for Im2Col, if needed.
92 // The checks here should be aligned with the actual implementation.
93 const bool need_dilated_im2col =
94 _params.dilation_height_factor != 1 || _params.dilation_width_factor != 1;
95 const bool need_non_dilated_im2col = _params.stride_height != 1 || _params.stride_width != 1 ||
96 filter_height != 1 || filter_width != 1;
97 const bool need_im2col = need_dilated_im2col || need_non_dilated_im2col;
100 const int input_depth = input_shape.dim(3);
101 Shape im2col_shape{batches, output_height, output_width,
102 input_depth * filter_height * filter_width};
104 std::make_unique<Tensor>(input()->element_type(), im2col_shape, AffineQuantization{}, "");
108 void Conv2D::execute() const
110 switch (input()->element_type())
112 case DataType::FLOAT32:
113 if (filter()->element_type() == DataType::FLOAT32)
118 throw std::runtime_error("Unsupported type.");
123 throw std::runtime_error("Unsupported type.");
127 void Conv2D::evalFloat() const
129 float activation_min{};
130 float activation_max{};
131 calculateActivationRange(_params.activation, &activation_min, &activation_max);
133 tflite::ConvParams params{};
134 params.padding_values.height = _padding_height;
135 params.padding_values.width = _padding_width;
136 params.stride_height = _params.stride_height;
137 params.stride_width = _params.stride_width;
138 params.dilation_height_factor = _params.dilation_height_factor;
139 params.dilation_width_factor = _params.dilation_width_factor;
140 params.float_activation_min = activation_min;
141 params.float_activation_max = activation_max;
143 tflite::optimized_ops::Conv(params, getTensorShape(input()), getTensorData<float>(input()),
144 getTensorShape(filter()), getTensorData<float>(filter()),
145 getTensorShape(bias()), getTensorData<float>(bias()),
146 getTensorShape(output()), getTensorData<float>(output()),
147 getTensorShape(_im2col.get()), getTensorData<float>(_im2col.get()));
150 void Conv2D::evalQuantized() const
152 const auto input_scale = static_cast<double>(input()->scale());
153 const auto filter_scale = static_cast<double>(filter()->scale());
154 const auto output_scale = static_cast<double>(output()->scale());
156 const double real_multiplier = input_scale * filter_scale / output_scale;
157 int32_t output_multiplier{};
159 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
161 int32_t activation_min{};
162 int32_t activation_max{};
163 calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
165 tflite::ConvParams params{};
166 params.padding_values.height = _padding_height;
167 params.padding_values.width = _padding_width;
168 params.stride_height = _params.stride_height;
169 params.stride_width = _params.stride_width;
170 params.dilation_height_factor = _params.dilation_height_factor;
171 params.dilation_width_factor = _params.dilation_width_factor;
172 // The kernel expects input and filter zero points to be negated.
173 params.input_offset = -input()->zero_point(); // Note the '-'.
174 params.weights_offset = -filter()->zero_point(); // Note the '-'.
175 params.output_offset = output()->zero_point();
176 params.output_multiplier = output_multiplier;
177 params.output_shift = output_shift;
178 params.quantized_activation_min = activation_min;
179 params.quantized_activation_max = activation_max;
181 // TODO This should only be done once (although it takes only a few microseconds).
182 // Also, the user should be able to adjust the number of threads.
183 auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
184 gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
186 tflite::optimized_ops::Conv(
187 params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(filter()),
188 getTensorData<uint8_t>(filter()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
189 getTensorShape(output()), getTensorData<uint8_t>(output()), getTensorShape(_im2col.get()),
190 getTensorData<uint8_t>(_im2col.get()), gemmlowp_context.get());
193 } // namespace kernels
194 } // namespace luci_interpreter