ae1f9e78e906d1a678be5475d477a1a65f79139e
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / optimized / OptimizedUtils.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__
19 #define __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__
20
21 #include "cker/Types.h"
22 #include "cker/Shape.h"
23
24 #include <stdexcept>
25
26 namespace nnfw
27 {
28 namespace cker
29 {
30 namespace optimized
31 {
32
33 template <typename T>
34 inline void ExtractPatchIntoBufferColumn(const Shape &input_shape, int w, int h, int b, int kheight,
35                                          int kwidth, int stride_width, int stride_height,
36                                          int pad_width, int pad_height, int in_width, int in_height,
37                                          int in_depth, int single_buffer_length, int buffer_id,
38                                          const T *in_data, T *conv_buffer_data, uint8_t zero_byte)
39 {
40   assert(input_shape.DimensionsCount() == 4);
41   // This chunk of code reshapes all the inputs corresponding to
42   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
43   const int kwidth_times_indepth = kwidth * in_depth;
44   const int inwidth_times_indepth = in_width * in_depth;
45   const int ih_ungated_start = h * stride_height - pad_height;
46   const int ih_ungated_end = (ih_ungated_start + kheight);
47   const int ih_end = std::min(ih_ungated_end, in_height);
48   const int iw_ungated_start = w * stride_width - pad_width;
49   const int iw_ungated_end = (iw_ungated_start + kwidth);
50   const int iw_end = std::min(iw_ungated_end, in_width);
51   // If the patch is off the edge of the input image, skip writing those rows
52   // and columns from the patch into the output array.
53   const int h_offset = std::max(0, -ih_ungated_start);
54   const int w_offset = std::max(0, -iw_ungated_start);
55   const int ih_start = std::max(0, ih_ungated_start);
56   const int iw_start = std::max(0, iw_ungated_start);
57   const int single_row_num = std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
58   const int output_row_offset = (buffer_id * single_buffer_length);
59   int out_offset = output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
60   int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
61
62   // Express all of the calculations as padding around the input patch.
63   const int top_padding = h_offset;
64   const int bottom_padding = (ih_ungated_end - ih_end);
65   const int left_padding = w_offset;
66   const int right_padding = (iw_ungated_end - iw_end);
67   assert(single_row_num == ((kwidth - (left_padding + right_padding)) * in_depth));
68
69   // Write out zeroes to the elements representing the top rows of the input
70   // patch that are off the edge of the input image.
71   if (top_padding > 0)
72   {
73     const int top_row_elements = (top_padding * kwidth * in_depth);
74     memset(conv_buffer_data + output_row_offset, zero_byte, (top_row_elements * sizeof(T)));
75   }
76
77   // If the patch is on the interior of the input image horizontally, just copy
78   // over the rows sequentially, otherwise add zero padding at the start or end.
79   if ((left_padding == 0) && (right_padding == 0))
80   {
81     for (int ih = ih_start; ih < ih_end; ++ih)
82     {
83       memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
84       out_offset += kwidth_times_indepth;
85       in_offset += inwidth_times_indepth;
86     }
87   }
88   else
89   {
90     for (int ih = ih_start; ih < ih_end; ++ih)
91     {
92       if (left_padding > 0)
93       {
94         const int left_start = (out_offset - (left_padding * in_depth));
95         memset(conv_buffer_data + left_start, zero_byte, (left_padding * in_depth * sizeof(T)));
96       }
97       memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
98       if (right_padding > 0)
99       {
100         const int right_start = (out_offset + single_row_num);
101         memset(conv_buffer_data + right_start, zero_byte, (right_padding * in_depth * sizeof(T)));
102       }
103       out_offset += kwidth_times_indepth;
104       in_offset += inwidth_times_indepth;
105     }
106   }
107
108   // If the bottom of the patch falls off the input image, pad the values
109   // representing those input rows with zeroes.
110   if (bottom_padding > 0)
111   {
112     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
113     const int bottom_start =
114         output_row_offset + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
115     memset(conv_buffer_data + bottom_start, zero_byte, (bottom_row_elements * sizeof(T)));
116   }
117 }
118
119 // Supports per-batch zero_byte for per-batch asymmetric quantized inputs.
120 template <typename T>
121 void DilatedIm2col(const ConvParams &params, const Shape &input_shape, const T *input_data,
122                    const Shape &filter_shape, const Shape &output_shape, T *im2col_data,
123                    const int32_t *zero_bytes, const int zero_bytes_len)
124 {
125   const int stride_width = params.stride_width;
126   const int stride_height = params.stride_height;
127   const int dilation_width_factor = params.dilation_width_factor;
128   const int dilation_height_factor = params.dilation_height_factor;
129   const int pad_width = params.padding_values.width;
130   const int pad_height = params.padding_values.height;
131   assert(input_shape.DimensionsCount() == 4);
132   assert(filter_shape.DimensionsCount() == 4);
133   assert(output_shape.DimensionsCount() == 4);
134
135   // For dilated convolution, the input pixels are not contiguous therefore we
136   // can't use the same optimizations as Im2Col(). Though note this code would
137   // work fine for the non-dilated case too (though likely a bit slower).
138   assert(dilation_width_factor != 1 || dilation_height_factor != 1);
139   assert(im2col_data);
140   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
141   const int input_height = input_shape.Dims(1);
142   const int input_width = input_shape.Dims(2);
143   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
144   const int filter_height = filter_shape.Dims(1);
145   const int filter_width = filter_shape.Dims(2);
146   const int output_height = output_shape.Dims(1);
147   const int output_width = output_shape.Dims(2);
148   MatchingDim(output_shape, 3, filter_shape, 0);
149
150   // Construct the MxN sized im2col matrix.
151   // The rows M, are sub-ordered B x H x W
152   const Shape row_shape({1, batches, output_height, output_width});
153   // The columns, N, are sub-ordered Kh x Kw x Din
154   const Shape col_shape({1, filter_height, filter_width, input_depth});
155   // Use dimensions M and N to construct dims for indexing directly into im2col
156   const Shape im2col_shape({1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
157
158   // Loop through the output rows (B x H x W)
159   for (int batch = 0; batch < batches; ++batch)
160   {
161     const T zero_byte =
162         zero_bytes_len > 1 ? static_cast<T>(zero_bytes[batch]) : static_cast<T>(zero_bytes[0]);
163     for (int out_y = 0; out_y < output_height; ++out_y)
164     {
165       for (int out_x = 0; out_x < output_width; ++out_x)
166       {
167         // Each im2col row is an output pixel. Arrange the input data in this
168         // row in an order we can conveniently multiply with the filter data.
169         int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
170         const int in_x_origin = (out_x * stride_width) - pad_width;
171         const int in_y_origin = (out_y * stride_height) - pad_height;
172         // Loop through all the pixels of the filter (Kh x Kw)
173         for (int filter_y = 0; filter_y < filter_height; ++filter_y)
174         {
175           const int in_y = in_y_origin + dilation_height_factor * filter_y;
176           if ((in_y >= 0) && (in_y < input_height))
177           {
178             // Filter row is within the input data.
179             // Loop through all the filter pixels in this row.
180             for (int filter_x = 0; filter_x < filter_width; ++filter_x)
181             {
182               const int in_x = in_x_origin + dilation_width_factor * filter_x;
183               int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
184               T *dst = im2col_data + Offset(im2col_shape, 0, 0, row_offset, col_offset);
185               if ((in_x >= 0) && (in_x < input_width))
186               {
187                 // Filter pixel is within the input, copy the input data.
188                 T const *src = input_data + Offset(input_shape, batch, in_y, in_x, 0);
189                 memcpy(dst, src, input_depth * sizeof(T));
190               }
191               else
192               {
193                 // Filter pixel is outside the input, zero it out.
194                 memset(dst, zero_byte, input_depth * sizeof(T));
195               }
196             }
197           }
198           else
199           {
200             // Filter row is outside the input, zero out the entire filter row.
201             int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
202             T *dst = im2col_data + Offset(im2col_shape, 0, 0, row_offset, col_offset);
203             memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
204           }
205         }
206       }
207     }
208   }
209 }
210
211 template <typename T>
212 void DilatedIm2col(const ConvParams &params, uint8_t zero_byte, const Shape &input_shape,
213                    const T *input_data, const Shape &filter_shape, const Shape &output_shape,
214                    T *im2col_data)
215 {
216   const int32_t zero_point = static_cast<int32_t>(zero_byte);
217   DilatedIm2col<T>(params, input_shape, input_data, filter_shape, output_shape, im2col_data,
218                    &zero_point, 1);
219 }
220
221 template <typename T>
222 void Im2col(const ConvParams &params, int kheight, int kwidth, uint8_t zero_byte,
223             const Shape &input_shape, const T *input_data, const Shape &output_shape,
224             T *output_data)
225 {
226   const int stride_width = params.stride_width;
227   const int stride_height = params.stride_height;
228   const int pad_width = params.padding_values.width;
229   const int pad_height = params.padding_values.height;
230   assert(input_shape.DimensionsCount() == 4);
231   assert(output_shape.DimensionsCount() == 4);
232
233   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
234   const int input_depth = input_shape.Dims(3);
235   const int input_width = input_shape.Dims(2);
236   const int input_height = input_shape.Dims(1);
237   const int output_depth = output_shape.Dims(3);
238   const int output_width = output_shape.Dims(2);
239   const int output_height = output_shape.Dims(1);
240
241   int buffer_id = 0;
242   // Loop over the output nodes.
243   for (int b = 0; b < batches; ++b)
244   {
245     for (int h = 0; h < output_height; ++h)
246     {
247       for (int w = 0; w < output_width; ++w)
248       {
249         ExtractPatchIntoBufferColumn(input_shape, w, h, b, kheight, kwidth, stride_width,
250                                      stride_height, pad_width, pad_height, input_width,
251                                      input_height, input_depth, output_depth, buffer_id, input_data,
252                                      output_data, zero_byte);
253         ++buffer_id;
254       }
255     }
256   }
257 }
258
259 } // namespace optimized
260 } // namespace cker
261 } // namespace nnfw
262
263 #endif // __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__