Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Conv.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_CONV_H__
19 #define __NNFW_CKER_CONV_H__
20
#include "cker/Types.h"
#include "cker/Shape.h"
#include "cker/Utils.h"
#include "cker/operation/reference/Conv.h"
#include "cker/operation/optimized/Conv.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>
28
29 namespace nnfw
30 {
31 namespace cker
32 {
33
34 namespace
35 {
36 // Naive implementation of transpose for floats. Could be optimized to be more
37 // cache friendly, but for now it's a one-time cost on first run, and we would
38 // prefer to remove the need to do this at all eventually.
39 inline void TransposeFloatTensor(const float *input_data, const nnfw::cker::Shape &output_shape,
40                                  float *output_data)
41 {
42   const int rows = output_shape.Dims(1);
43   const int cols = output_shape.Dims(0);
44   for (int i = 0; i < rows; ++i)
45   {
46     for (int j = 0; j < cols; ++j)
47     {
48       const float in_value = input_data[i * cols + j];
49       output_data[j * rows + i] = in_value;
50     }
51   }
52 }
53 } // namespace
54
55 class Conv
56 {
57 public:
58   Conv() : _modified_filter_data(), _im2col_shape(4), _need_im2col(false), _prepared(false) {}
59
60   void prepare(const Shape &filter_shape, const float *filter_data, PaddingType padding_type,
61                bool &is_replaced_weights, uint32_t dilationWidthFactor,
62                uint32_t dilationHeightFactor)
63   {
64     if (!_prepared)
65     {
66       if (usableMultiThreaded(padding_type, dilationWidthFactor, dilationHeightFactor))
67       {
68         transposeFilter(filter_shape, filter_data, is_replaced_weights);
69       }
70       _prepared = true;
71     }
72   }
73
74   void prepareQuant(const Shape &input_shape, const Shape &kernel_shape, const Shape &output_shape,
75                     uint32_t stride_width, uint32_t stride_height)
76   {
77     if (!_prepared)
78     {
79       IsRequiredIm2col(input_shape, kernel_shape, output_shape, stride_width, stride_height);
80       _prepared = true;
81     }
82   }
83
84   void operator()(const ConvParams &params, const Shape &input_shape, const float *input_data,
85                   const Shape &filter_shape, const float *filter_data, const Shape &bias_shape,
86                   const float *bias_data, const Shape &output_shape, float *output_data)
87   {
88     if (usableMultiThreaded(params.padding_type, params.dilation_width_factor,
89                             params.dilation_height_factor))
90     {
91       bool transposed_in_execution = false;
92       if (!_prepared)
93       {
94         // This means that filter is not constant
95         // TODO Apply optimized kernel if multithreaded kernel is slower than optimized kernel by
96         // transposing filter data
97         transposeFilter(filter_shape, filter_data, transposed_in_execution);
98       }
99       multithreaded::Conv(params, input_shape, input_data, filter_shape, &_modified_filter_data[0],
100                           bias_shape, bias_data, output_shape, output_data);
101     }
102     else
103     {
104       // TODO Support optimized kernel
105       reference::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
106                       bias_data, output_shape, output_data);
107     }
108   }
109
110   void operator()(const ConvParams &params, const Shape &input_shape, const uint8_t *input_data,
111                   const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape,
112                   const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data)
113   {
114     if (!_prepared)
115     {
116       // This means that input or output are dynamic or filter is not constant
117       IsRequiredIm2col(input_shape, filter_shape, output_shape, params.stride_width,
118                        params.stride_height);
119     }
120
121     int im2col_size = _need_im2col ? _im2col_shape.FlatSize() : 1;
122
123     // Use heap if size is larger than 8MB
124     if (im2col_size > 8 * 1024 * 1024)
125     {
126       std::unique_ptr<uint8_t[]> im2col_data = std::make_unique<uint8_t[]>(im2col_size);
127       optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
128                       bias_data, output_shape, output_data, _im2col_shape, im2col_data.get());
129     }
130     else
131     {
132       uint8_t im2col_data[im2col_size];
133       optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
134                       bias_data, output_shape, output_data, _im2col_shape, im2col_data);
135     }
136   }
137
138 private:
139   bool usableMultiThreaded(PaddingType padding_type, uint32_t dilation_width_factor,
140                            int32_t dilation_height_factor)
141   {
142     return padding_type != PaddingType::kNone && std::thread::hardware_concurrency() > 1 &&
143            dilation_width_factor == 1 && dilation_height_factor == 1;
144   }
145
146   void transposeFilter(const Shape &filter_shape, const float *filter_data,
147                        bool &is_replaced_weights)
148   {
149     const auto output_depth = filter_shape.Dims(0);
150     const Shape hwcn_filter_shape{filter_shape.FlatSize() / output_depth, output_depth};
151     _modified_filter_data.resize(hwcn_filter_shape.FlatSize());
152     TransposeFloatTensor(filter_data, hwcn_filter_shape, &_modified_filter_data[0]);
153     is_replaced_weights = true;
154   }
155
156   void IsRequiredIm2col(const Shape &input_shape, const Shape &kernel_shape,
157                         const Shape &output_shape, uint32_t stride_width, uint32_t stride_height)
158   {
159     _need_im2col = stride_width != 1 || stride_height != 1 || kernel_shape.Dims(1) != 1 ||
160                    kernel_shape.Dims(2) != 1;
161     if (_need_im2col)
162     {
163       _im2col_shape.SetDim(0, output_shape.Dims(0));
164       _im2col_shape.SetDim(1, output_shape.Dims(1));
165       _im2col_shape.SetDim(2, output_shape.Dims(2));
166       _im2col_shape.SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));
167     }
168   }
169
170 private:
171   std::vector<float> _modified_filter_data;
172   Shape _im2col_shape;
173   bool _need_im2col;
174   bool _prepared;
175 };
176 } // namespace cker
177 } // namespace nnfw
178
#endif // __NNFW_CKER_CONV_H__