Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ruy / include / ruy / Utils.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_RUY_UTILS_H__
19 #define __NNFW_RUY_UTILS_H__
20
21 #include "Types.h"
22 #include "Shape.h"
23
24 #include <stdexcept>
25
26 namespace nnfw
27 {
28 namespace ruy
29 {
30 template <typename T>
31 inline void ExtractPatchIntoBufferColumn(const Shape &input_shape, int w, int h, int b, int kheight,
32                                          int kwidth, int stride_width, int stride_height,
33                                          int pad_width, int pad_height, int in_width, int in_height,
34                                          int in_depth, int single_buffer_length, int buffer_id,
35                                          const T *in_data, T *conv_buffer_data, uint8_t zero_byte)
36 {
37   assert(input_shape.DimensionsCount() == 4);
38   // This chunk of code reshapes all the inputs corresponding to
39   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
40   const int kwidth_times_indepth = kwidth * in_depth;
41   const int inwidth_times_indepth = in_width * in_depth;
42   const int ih_ungated_start = h * stride_height - pad_height;
43   const int ih_ungated_end = (ih_ungated_start + kheight);
44   const int ih_end = std::min(ih_ungated_end, in_height);
45   const int iw_ungated_start = w * stride_width - pad_width;
46   const int iw_ungated_end = (iw_ungated_start + kwidth);
47   const int iw_end = std::min(iw_ungated_end, in_width);
48   // If the patch is off the edge of the input image, skip writing those rows
49   // and columns from the patch into the output array.
50   const int h_offset = std::max(0, -ih_ungated_start);
51   const int w_offset = std::max(0, -iw_ungated_start);
52   const int ih_start = std::max(0, ih_ungated_start);
53   const int iw_start = std::max(0, iw_ungated_start);
54   const int single_row_num = std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
55   const int output_row_offset = (buffer_id * single_buffer_length);
56   int out_offset = output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
57   int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
58
59   // Express all of the calculations as padding around the input patch.
60   const int top_padding = h_offset;
61   const int bottom_padding = (ih_ungated_end - ih_end);
62   const int left_padding = w_offset;
63   const int right_padding = (iw_ungated_end - iw_end);
64   assert(single_row_num == ((kwidth - (left_padding + right_padding)) * in_depth));
65
66   // Write out zeroes to the elements representing the top rows of the input
67   // patch that are off the edge of the input image.
68   if (top_padding > 0)
69   {
70     const int top_row_elements = (top_padding * kwidth * in_depth);
71     memset(conv_buffer_data + output_row_offset, zero_byte, (top_row_elements * sizeof(T)));
72   }
73
74   // If the patch is on the interior of the input image horizontally, just copy
75   // over the rows sequentially, otherwise add zero padding at the start or end.
76   if ((left_padding == 0) && (right_padding == 0))
77   {
78     for (int ih = ih_start; ih < ih_end; ++ih)
79     {
80       memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
81       out_offset += kwidth_times_indepth;
82       in_offset += inwidth_times_indepth;
83     }
84   }
85   else
86   {
87     for (int ih = ih_start; ih < ih_end; ++ih)
88     {
89       if (left_padding > 0)
90       {
91         const int left_start = (out_offset - (left_padding * in_depth));
92         memset(conv_buffer_data + left_start, zero_byte, (left_padding * in_depth * sizeof(T)));
93       }
94       memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
95       if (right_padding > 0)
96       {
97         const int right_start = (out_offset + single_row_num);
98         memset(conv_buffer_data + right_start, zero_byte, (right_padding * in_depth * sizeof(T)));
99       }
100       out_offset += kwidth_times_indepth;
101       in_offset += inwidth_times_indepth;
102     }
103   }
104
105   // If the bottom of the patch falls off the input image, pad the values
106   // representing those input rows with zeroes.
107   if (bottom_padding > 0)
108   {
109     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
110     const int bottom_start =
111       output_row_offset + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
112     memset(conv_buffer_data + bottom_start, zero_byte, (bottom_row_elements * sizeof(T)));
113   }
114 }
115
116 // Supports per-batch zero_byte for per-batch asymmetric quantized inputs.
117 template <typename T>
118 void DilatedIm2col(const ConvParams &params, const Shape &input_shape, const T *input_data,
119                    const Shape &filter_shape, const Shape &output_shape, T *im2col_data,
120                    const int32_t *zero_bytes, const int zero_bytes_len)
121 {
122   const int stride_width = params.stride_width;
123   const int stride_height = params.stride_height;
124   const int dilation_width_factor = params.dilation_width_factor;
125   const int dilation_height_factor = params.dilation_height_factor;
126   const int pad_width = params.padding_values.width;
127   const int pad_height = params.padding_values.height;
128   assert(input_shape.DimensionsCount() == 4);
129   assert(filter_shape.DimensionsCount() == 4);
130   assert(output_shape.DimensionsCount() == 4);
131
132   // For dilated convolution, the input pixels are not contiguous therefore we
133   // can't use the same optimizations as Im2Col(). Though note this code would
134   // work fine for the non-dilated case too (though likely a bit slower).
135   assert(dilation_width_factor != 1 || dilation_height_factor != 1);
136   assert(im2col_data);
137   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
138   const int input_height = input_shape.Dims(1);
139   const int input_width = input_shape.Dims(2);
140   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
141   const int filter_height = filter_shape.Dims(1);
142   const int filter_width = filter_shape.Dims(2);
143   const int output_height = output_shape.Dims(1);
144   const int output_width = output_shape.Dims(2);
145   MatchingDim(output_shape, 3, filter_shape, 0);
146
147   // Construct the MxN sized im2col matrix.
148   // The rows M, are sub-ordered B x H x W
149   const Shape row_shape({1, batches, output_height, output_width});
150   // The columns, N, are sub-ordered Kh x Kw x Din
151   const Shape col_shape({1, filter_height, filter_width, input_depth});
152   // Use dimensions M and N to construct dims for indexing directly into im2col
153   const Shape im2col_shape({1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
154
155   // Loop through the output rows (B x H x W)
156   for (int batch = 0; batch < batches; ++batch)
157   {
158     const T zero_byte =
159       zero_bytes_len > 1 ? static_cast<T>(zero_bytes[batch]) : static_cast<T>(zero_bytes[0]);
160     for (int out_y = 0; out_y < output_height; ++out_y)
161     {
162       for (int out_x = 0; out_x < output_width; ++out_x)
163       {
164         // Each im2col row is an output pixel. Arrange the input data in this
165         // row in an order we can conveniently multiply with the filter data.
166         int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
167         const int in_x_origin = (out_x * stride_width) - pad_width;
168         const int in_y_origin = (out_y * stride_height) - pad_height;
169         // Loop through all the pixels of the filter (Kh x Kw)
170         for (int filter_y = 0; filter_y < filter_height; ++filter_y)
171         {
172           const int in_y = in_y_origin + dilation_height_factor * filter_y;
173           if ((in_y >= 0) && (in_y < input_height))
174           {
175             // Filter row is within the input data.
176             // Loop through all the filter pixels in this row.
177             for (int filter_x = 0; filter_x < filter_width; ++filter_x)
178             {
179               const int in_x = in_x_origin + dilation_width_factor * filter_x;
180               int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
181               T *dst = im2col_data + Offset(im2col_shape, 0, 0, row_offset, col_offset);
182               if ((in_x >= 0) && (in_x < input_width))
183               {
184                 // Filter pixel is within the input, copy the input data.
185                 T const *src = input_data + Offset(input_shape, batch, in_y, in_x, 0);
186                 memcpy(dst, src, input_depth * sizeof(T));
187               }
188               else
189               {
190                 // Filter pixel is outside the input, zero it out.
191                 memset(dst, zero_byte, input_depth * sizeof(T));
192               }
193             }
194           }
195           else
196           {
197             // Filter row is outside the input, zero out the entire filter row.
198             int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
199             T *dst = im2col_data + Offset(im2col_shape, 0, 0, row_offset, col_offset);
200             memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
201           }
202         }
203       }
204     }
205   }
206 }
207
208 template <typename T>
209 void DilatedIm2col(const ConvParams &params, uint8_t zero_byte, const Shape &input_shape,
210                    const T *input_data, const Shape &filter_shape, const Shape &output_shape,
211                    T *im2col_data)
212 {
213   const int32_t zero_point = static_cast<int32_t>(zero_byte);
214   DilatedIm2col<T>(params, input_shape, input_data, filter_shape, output_shape, im2col_data,
215                    &zero_point, 1);
216 }
217
218 template <typename T>
219 void Im2col(const ConvParams &params, int kheight, int kwidth, uint8_t zero_byte,
220             const Shape &input_shape, const T *input_data, const Shape &output_shape,
221             T *output_data)
222 {
223   const int stride_width = params.stride_width;
224   const int stride_height = params.stride_height;
225   const int pad_width = params.padding_values.width;
226   const int pad_height = params.padding_values.height;
227   assert(input_shape.DimensionsCount() == 4);
228   assert(output_shape.DimensionsCount() == 4);
229
230   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
231   const int input_depth = input_shape.Dims(3);
232   const int input_width = input_shape.Dims(2);
233   const int input_height = input_shape.Dims(1);
234   const int output_depth = output_shape.Dims(3);
235   const int output_width = output_shape.Dims(2);
236   const int output_height = output_shape.Dims(1);
237
238   int buffer_id = 0;
239   // Loop over the output nodes.
240   for (int b = 0; b < batches; ++b)
241   {
242     for (int h = 0; h < output_height; ++h)
243     {
244       for (int w = 0; w < output_width; ++w)
245       {
246         ExtractPatchIntoBufferColumn(input_shape, w, h, b, kheight, kwidth, stride_width,
247                                      stride_height, pad_width, pad_height, input_width,
248                                      input_height, input_depth, output_depth, buffer_id, input_data,
249                                      output_data, zero_byte);
250         ++buffer_id;
251       }
252     }
253   }
254 }
255
256 } // namespace ruy
257 } // namespace nnfw
258
259 #endif // __NNFW_RUY_UTILS_H__