1bf191bf8b72644187522d204ab603e7abac623d
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Conv.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_CONV_H__
19 #define __NNFW_CKER_CONV_H__
20
#include "cker/Types.h"
#include "cker/Shape.h"
#include "cker/Utils.h"
#include "cker/operation/reference/Conv.h"
#include "cker/operation/optimized/Conv.h"
#include <cstdint>
#include <thread>
#include <vector>
27
28 namespace nnfw
29 {
30 namespace cker
31 {
32
33 namespace
34 {
35 // Naive implementation of transpose for floats. Could be optimized to be more
36 // cache friendly, but for now it's a one-time cost on first run, and we would
37 // prefer to remove the need to do this at all eventually.
38 inline void TransposeFloatTensor(const float *input_data, const nnfw::cker::Shape &output_shape,
39                                  float *output_data)
40 {
41   const int rows = output_shape.Dims(1);
42   const int cols = output_shape.Dims(0);
43   for (int i = 0; i < rows; ++i)
44   {
45     for (int j = 0; j < cols; ++j)
46     {
47       const float in_value = input_data[i * cols + j];
48       output_data[j * rows + i] = in_value;
49     }
50   }
51 }
52 } // namespace
53
54 class Conv
55 {
56 public:
57   Conv()
58       : _modified_filter_data(), _im2col_data(), _im2col_shape(4), _need_im2col(false),
59         _prepared(false)
60   {
61   }
62
63   void prepare(const Shape &filter_shape, const float *filter_data, PaddingType padding_type,
64                bool &is_replaced_weights)
65   {
66     if (!_prepared)
67     {
68       if (usableMultiThreaded(padding_type))
69       {
70         transposeFilter(filter_shape, filter_data, is_replaced_weights);
71       }
72       _prepared = true;
73     }
74   }
75
76   void prepareQuant(const Shape &input_shape, const Shape &kernel_shape, const Shape &output_shape,
77                     uint32_t stride_width, uint32_t stride_height)
78   {
79     if (!_prepared)
80     {
81       IsRequiredIm2col(input_shape, kernel_shape, output_shape, stride_width, stride_height);
82       _prepared = true;
83     }
84   }
85
86   void operator()(const ConvParams &params, const Shape &input_shape, const float *input_data,
87                   const Shape &filter_shape, const float *filter_data, const Shape &bias_shape,
88                   const float *bias_data, const Shape &output_shape, float *output_data)
89   {
90     if (usableMultiThreaded(params.padding_type))
91     {
92       bool transposed_in_execution = false;
93       if (!_prepared)
94       {
95         // This means that filter is not constant
96         // TODO Apply optimized kernel if multithreaded kernel is slower than optimized kernel by
97         // transposing filter data
98         transposeFilter(filter_shape, filter_data, transposed_in_execution);
99       }
100       multithreaded::Conv(params, input_shape, input_data, filter_shape, &_modified_filter_data[0],
101                           bias_shape, bias_data, output_shape, output_data);
102     }
103     else
104     {
105       // TODO Support optimized kernel
106       reference::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
107                       bias_data, output_shape, output_data);
108     }
109   }
110
111   void operator()(const ConvParams &params, const Shape &input_shape, const uint8_t *input_data,
112                   const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape,
113                   const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data)
114   {
115     if (!_prepared)
116     {
117       // This means that input or output are dynamic or filter is not constant
118       IsRequiredIm2col(input_shape, filter_shape, output_shape, params.stride_width,
119                        params.stride_height);
120     }
121
122     uint8_t *im2col_raw_data = _im2col_data.data();
123     optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
124                     bias_data, output_shape, output_data, _im2col_shape, im2col_raw_data);
125   }
126
127 private:
128   bool usableMultiThreaded(PaddingType padding_type)
129   {
130     return padding_type != PaddingType::kNone && std::thread::hardware_concurrency() > 1;
131   }
132
133   void transposeFilter(const Shape &filter_shape, const float *filter_data,
134                        bool &is_replaced_weights)
135   {
136     const auto output_depth = filter_shape.Dims(0);
137     const Shape hwcn_filter_shape{filter_shape.FlatSize() / output_depth, output_depth};
138     _modified_filter_data.resize(hwcn_filter_shape.FlatSize());
139     TransposeFloatTensor(filter_data, hwcn_filter_shape, &_modified_filter_data[0]);
140     is_replaced_weights = true;
141   }
142
143   void IsRequiredIm2col(const Shape &input_shape, const Shape &kernel_shape,
144                         const Shape &output_shape, uint32_t stride_width, uint32_t stride_height)
145   {
146     _need_im2col = stride_width != 1 || stride_height != 1 || kernel_shape.Dims(1) != 1 ||
147                    kernel_shape.Dims(2) != 1;
148     if (_need_im2col)
149     {
150       _im2col_shape.SetDim(0, output_shape.Dims(0));
151       _im2col_shape.SetDim(1, output_shape.Dims(1));
152       _im2col_shape.SetDim(2, output_shape.Dims(2));
153       _im2col_shape.SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));
154       _im2col_data.resize(_im2col_shape.FlatSize());
155     }
156   }
157
158 private:
159   std::vector<float> _modified_filter_data;
160   std::vector<uint8_t> _im2col_data;
161   Shape _im2col_shape;
162   bool _need_im2col;
163   bool _prepared;
164 };
165 } // namespace cker
166 } // namespace nnfw
167
#endif // __NNFW_CKER_CONV_H__