2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__
19 #define __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__
21 #include "cker/Types.h"
22 #include "cker/Shape.h"
34 inline void ExtractPatchIntoBufferColumn(const Shape &input_shape, int w, int h, int b, int kheight,
35 int kwidth, int stride_width, int stride_height,
36 int pad_width, int pad_height, int in_width, int in_height,
37 int in_depth, int single_buffer_length, int buffer_id,
38 const T *in_data, T *conv_buffer_data, uint8_t zero_byte)
40 assert(input_shape.DimensionsCount() == 4);
41 // This chunk of code reshapes all the inputs corresponding to
42 // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
43 const int kwidth_times_indepth = kwidth * in_depth;
44 const int inwidth_times_indepth = in_width * in_depth;
45 const int ih_ungated_start = h * stride_height - pad_height;
46 const int ih_ungated_end = (ih_ungated_start + kheight);
47 const int ih_end = std::min(ih_ungated_end, in_height);
48 const int iw_ungated_start = w * stride_width - pad_width;
49 const int iw_ungated_end = (iw_ungated_start + kwidth);
50 const int iw_end = std::min(iw_ungated_end, in_width);
51 // If the patch is off the edge of the input image, skip writing those rows
52 // and columns from the patch into the output array.
53 const int h_offset = std::max(0, -ih_ungated_start);
54 const int w_offset = std::max(0, -iw_ungated_start);
55 const int ih_start = std::max(0, ih_ungated_start);
56 const int iw_start = std::max(0, iw_ungated_start);
57 const int single_row_num = std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
58 const int output_row_offset = (buffer_id * single_buffer_length);
59 int out_offset = output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
60 int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
62 // Express all of the calculations as padding around the input patch.
63 const int top_padding = h_offset;
64 const int bottom_padding = (ih_ungated_end - ih_end);
65 const int left_padding = w_offset;
66 const int right_padding = (iw_ungated_end - iw_end);
67 assert(single_row_num == ((kwidth - (left_padding + right_padding)) * in_depth));
69 // Write out zeroes to the elements representing the top rows of the input
70 // patch that are off the edge of the input image.
73 const int top_row_elements = (top_padding * kwidth * in_depth);
74 memset(conv_buffer_data + output_row_offset, zero_byte, (top_row_elements * sizeof(T)));
77 // If the patch is on the interior of the input image horizontally, just copy
78 // over the rows sequentially, otherwise add zero padding at the start or end.
79 if ((left_padding == 0) && (right_padding == 0))
81 for (int ih = ih_start; ih < ih_end; ++ih)
83 memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
84 out_offset += kwidth_times_indepth;
85 in_offset += inwidth_times_indepth;
90 for (int ih = ih_start; ih < ih_end; ++ih)
94 const int left_start = (out_offset - (left_padding * in_depth));
95 memset(conv_buffer_data + left_start, zero_byte, (left_padding * in_depth * sizeof(T)));
97 memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T));
98 if (right_padding > 0)
100 const int right_start = (out_offset + single_row_num);
101 memset(conv_buffer_data + right_start, zero_byte, (right_padding * in_depth * sizeof(T)));
103 out_offset += kwidth_times_indepth;
104 in_offset += inwidth_times_indepth;
108 // If the bottom of the patch falls off the input image, pad the values
109 // representing those input rows with zeroes.
110 if (bottom_padding > 0)
112 const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
113 const int bottom_start =
114 output_row_offset + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
115 memset(conv_buffer_data + bottom_start, zero_byte, (bottom_row_elements * sizeof(T)));
119 template <typename T>
120 void DilatedIm2col(const ConvParams ¶ms, uint8_t zero_byte, const Shape &input_shape,
121 const T *input_data, const Shape &filter_shape, const Shape &output_shape,
131 throw std::runtime_error{"NYI: cker DilatedIm2col"};
134 template <typename T>
135 void Im2col(const ConvParams ¶ms, int kheight, int kwidth, uint8_t zero_byte,
136 const Shape &input_shape, const T *input_data, const Shape &output_shape,
139 const int stride_width = params.stride_width;
140 const int stride_height = params.stride_height;
141 const int pad_width = params.padding_values.width;
142 const int pad_height = params.padding_values.height;
143 assert(input_shape.DimensionsCount() == 4);
144 assert(output_shape.DimensionsCount() == 4);
146 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
147 const int input_depth = input_shape.Dims(3);
148 const int input_width = input_shape.Dims(2);
149 const int input_height = input_shape.Dims(1);
150 const int output_depth = output_shape.Dims(3);
151 const int output_width = output_shape.Dims(2);
152 const int output_height = output_shape.Dims(1);
155 // Loop over the output nodes.
156 for (int b = 0; b < batches; ++b)
158 for (int h = 0; h < output_height; ++h)
160 for (int w = 0; w < output_width; ++w)
162 ExtractPatchIntoBufferColumn(input_shape, w, h, b, kheight, kwidth, stride_width,
163 stride_height, pad_width, pad_height, input_width,
164 input_height, input_depth, output_depth, buffer_id, input_data,
165 output_data, zero_byte);
172 } // namespace optimized
176 #endif // __NNFW_CKER_OPTIMIZED_OPTIMIZED_UTILS_H__