60e6134ab9e46d8202a386666a0d3f3d8e57ea98
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / Conv2D.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/Conv2D.h"
18
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
22
23 #include <stdexcept>
24 #include <thread>
25
26 namespace luci_interpreter
27 {
28 namespace kernels
29 {
30
31 Conv2D::Conv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
32                const Conv2DParams &params)
33     : KernelWithParams<Conv2DParams>({input, filter, bias}, {output}, params)
34 {
35 }
36
37 void Conv2D::configure()
38 {
39   // TensorFlow Lite (as of v2.2.0) supports the following combinations of types:
40   //     | input filter bias  output |
41   // ----+---------------------------+
42   // (1) | float float  float float  |
43   // (2) | float int8   float float  | hybrid
44   // (3) | uint8 uint8  int32 uint8  | quantized
45   // (4) | int8  int8   int32 int8   | quantized per channel
46   //
47   // We only support (1) and (3) for now.
48   if (input()->element_type() == DataType::FLOAT32 && filter()->element_type() == DataType::FLOAT32)
49   {
50     assert(bias() == nullptr || bias()->element_type() == DataType::FLOAT32);
51   }
52   else if (input()->element_type() == DataType::U8 && filter()->element_type() == DataType::U8)
53   {
54     assert(bias() == nullptr || bias()->element_type() == DataType::S32);
55   }
56   else
57   {
58     throw std::runtime_error("Unsupported type.");
59   }
60   assert(output()->element_type() == input()->element_type());
61
62   const Shape &input_shape = input()->shape();
63   const Shape &filter_shape = filter()->shape();
64   assert(input_shape.num_dims() == 4 && filter_shape.num_dims() == 4);
65
66   const int32_t batches = input_shape.dim(0);
67   const int32_t input_height = input_shape.dim(1);
68   const int32_t input_width = input_shape.dim(2);
69   const int32_t output_depth = filter_shape.dim(0);
70   const int32_t filter_height = filter_shape.dim(1);
71   const int32_t filter_width = filter_shape.dim(2);
72   assert(filter_shape.dim(3) == input_shape.dim(3));
73
74   assert(bias() == nullptr ||
75          (bias()->shape().num_dims() == 1 && bias()->shape().dim(0) == output_depth));
76
77   const int32_t output_height =
78       computeOutputSize(_params.padding, input_height, filter_height, _params.stride_height,
79                         _params.dilation_height_factor);
80   const int32_t output_width =
81       computeOutputSize(_params.padding, input_width, filter_width, _params.stride_width,
82                         _params.dilation_width_factor);
83
84   _padding_height = computePadding(_params.stride_height, _params.dilation_height_factor,
85                                    input_height, filter_height, output_height);
86   _padding_width = computePadding(_params.stride_width, _params.dilation_width_factor, input_width,
87                                   filter_width, output_width);
88
89   output()->resize({batches, output_height, output_width, output_depth});
90
91   // Allocate tensor for Im2Col, if needed.
92   // The checks here should be aligned with the actual implementation.
93   const bool need_dilated_im2col =
94       _params.dilation_height_factor != 1 || _params.dilation_width_factor != 1;
95   const bool need_non_dilated_im2col = _params.stride_height != 1 || _params.stride_width != 1 ||
96                                        filter_height != 1 || filter_width != 1;
97   const bool need_im2col = need_dilated_im2col || need_non_dilated_im2col;
98   if (need_im2col)
99   {
100     const int input_depth = input_shape.dim(3);
101     Shape im2col_shape{batches, output_height, output_width,
102                        input_depth * filter_height * filter_width};
103     _im2col =
104         std::make_unique<Tensor>(input()->element_type(), im2col_shape, AffineQuantization{}, "");
105   }
106 }
107
108 void Conv2D::execute() const
109 {
110   switch (input()->element_type())
111   {
112     case DataType::FLOAT32:
113       if (filter()->element_type() == DataType::FLOAT32)
114       {
115         evalFloat();
116         break;
117       }
118       throw std::runtime_error("Unsupported type.");
119     case DataType::U8:
120       evalQuantized();
121       break;
122     default:
123       throw std::runtime_error("Unsupported type.");
124   }
125 }
126
127 void Conv2D::evalFloat() const
128 {
129   float activation_min{};
130   float activation_max{};
131   calculateActivationRange(_params.activation, &activation_min, &activation_max);
132
133   tflite::ConvParams params{};
134   params.padding_values.height = _padding_height;
135   params.padding_values.width = _padding_width;
136   params.stride_height = _params.stride_height;
137   params.stride_width = _params.stride_width;
138   params.dilation_height_factor = _params.dilation_height_factor;
139   params.dilation_width_factor = _params.dilation_width_factor;
140   params.float_activation_min = activation_min;
141   params.float_activation_max = activation_max;
142
143   tflite::optimized_ops::Conv(params, getTensorShape(input()), getTensorData<float>(input()),
144                               getTensorShape(filter()), getTensorData<float>(filter()),
145                               getTensorShape(bias()), getTensorData<float>(bias()),
146                               getTensorShape(output()), getTensorData<float>(output()),
147                               getTensorShape(_im2col.get()), getTensorData<float>(_im2col.get()));
148 }
149
150 void Conv2D::evalQuantized() const
151 {
152   const auto input_scale = static_cast<double>(input()->scale());
153   const auto filter_scale = static_cast<double>(filter()->scale());
154   const auto output_scale = static_cast<double>(output()->scale());
155
156   const double real_multiplier = input_scale * filter_scale / output_scale;
157   int32_t output_multiplier{};
158   int output_shift{};
159   quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
160
161   int32_t activation_min{};
162   int32_t activation_max{};
163   calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
164
165   tflite::ConvParams params{};
166   params.padding_values.height = _padding_height;
167   params.padding_values.width = _padding_width;
168   params.stride_height = _params.stride_height;
169   params.stride_width = _params.stride_width;
170   params.dilation_height_factor = _params.dilation_height_factor;
171   params.dilation_width_factor = _params.dilation_width_factor;
172   // The kernel expects input and filter zero points to be negated.
173   params.input_offset = -input()->zero_point();    // Note the '-'.
174   params.weights_offset = -filter()->zero_point(); // Note the '-'.
175   params.output_offset = output()->zero_point();
176   params.output_multiplier = output_multiplier;
177   params.output_shift = output_shift;
178   params.quantized_activation_min = activation_min;
179   params.quantized_activation_max = activation_max;
180
181   // TODO This should only be done once (although it takes only a few microseconds).
182   //  Also, the user should be able to adjust the number of threads.
183   auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
184   gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
185
186   tflite::optimized_ops::Conv(
187       params, getTensorShape(input()), getTensorData<uint8_t>(input()), getTensorShape(filter()),
188       getTensorData<uint8_t>(filter()), getTensorShape(bias()), getTensorData<int32_t>(bias()),
189       getTensorShape(output()), getTensorData<uint8_t>(output()), getTensorShape(_im2col.get()),
190       getTensorData<uint8_t>(_im2col.get()), gemmlowp_context.get());
191 }
192
193 } // namespace kernels
194 } // namespace luci_interpreter