7ec64a76848a0ee143aa76a476e1f9d54e771a54
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / cmsisnn / PALConv2d.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
19 #define LUCI_INTERPRETER_PAL_CONV2D_H
20
#include <cassert>

#include <tensorflow/lite/kernels/internal/reference/conv.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
#include <arm_nn_types.h>
#include <arm_nnfunctions.h>
25
26 namespace luci_interpreter_pal
27 {
28 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
29                         const float *input_data, const tflite::RuntimeShape &filter_shape,
30                         const float *filter_data, const tflite::RuntimeShape &bias_shape,
31                         const float *bias_data, const tflite::RuntimeShape &output_shape,
32                         float *output_data, const tflite::RuntimeShape &scratchpad_shape,
33                         float *scratchpad_data)
34 {
35   (void)scratchpad_shape;
36   (void)scratchpad_data;
37   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
38                               bias_shape, bias_data, output_shape, output_data,
39                               tflite::RuntimeShape(), nullptr);
40 }
41
42 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
43                         const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
44                         const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
45                         const int32 *bias_data, const tflite::RuntimeShape &output_shape,
46                         uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
47                         uint8 *scratchpad_data)
48 {
49   (void)scratchpad_shape;
50   (void)scratchpad_data;
51   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
52                               bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
53                               scratchpad_data, nullptr);
54 }
55
56 static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
57                                   const int32_t *shifts, const tflite::RuntimeShape &input_shape,
58                                   const int8 *input_data, const tflite::RuntimeShape &filter_shape,
59                                   const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
60                                   const int32 *bias_data, const tflite::RuntimeShape &output_shape,
61                                   int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
62                                   int8 *scratchpad_data)
63 {
64   if (scratchpad_data)
65   {
66     cmsis_nn_conv_params conv_params;
67     conv_params.dilation.h = params.dilation_height_factor;
68     conv_params.dilation.w = params.dilation_width_factor;
69
70     assert(conv_params.dilation.h == 1);
71     assert(conv_params.dilation.w == 1);
72
73     conv_params.input_offset = params.input_offset;
74     conv_params.output_offset = params.output_offset;
75     conv_params.stride.h = params.stride_height;
76     conv_params.stride.w = params.stride_width;
77     conv_params.padding.h = params.padding_values.height;
78     conv_params.padding.w = params.padding_values.width;
79     conv_params.activation.min = params.quantized_activation_min;
80     conv_params.activation.max = params.quantized_activation_max;
81
82     cmsis_nn_per_channel_quant_params quant_params;
83     quant_params.multiplier = const_cast<int32_t *>(mult);
84     quant_params.shift = const_cast<int32_t *>(shifts);
85
86     assert(conv_params.activation.min <= conv_params.activation.max);
87     assert(input_shape.DimensionsCount() == 4);
88     assert(filter_shape.DimensionsCount() == 4);
89     assert(output_shape.DimensionsCount() == 4);
90     const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
91     const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
92     const int output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
93     if (bias_data)
94     {
95       assert(bias_shape.FlatSize() == output_depth);
96     }
97
98     cmsis_nn_dims input_dims;
99     input_dims.n = batch_size;
100     input_dims.h = input_shape.Dims(1);
101     input_dims.w = input_shape.Dims(2);
102     input_dims.c = input_depth;
103
104     cmsis_nn_dims filter_dims;
105     filter_dims.n = output_depth;
106     filter_dims.h = filter_shape.Dims(1);
107     filter_dims.w = filter_shape.Dims(2);
108     filter_dims.c = input_depth;
109
110     cmsis_nn_dims bias_dims;
111     bias_dims.n = 1;
112     bias_dims.h = 1;
113     bias_dims.w = 1;
114     bias_dims.c = output_depth;
115
116     cmsis_nn_dims output_dims;
117     output_dims.n = batch_size;
118     output_dims.h = output_shape.Dims(1);
119     output_dims.w = output_shape.Dims(2);
120     output_dims.c = output_depth;
121
122     cmsis_nn_context ctx;
123     ctx.buf = scratchpad_data;
124     ctx.size = scratchpad_shape.Dims(0);
125
126     auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
127                                        &filter_dims, filter_data, &bias_dims, bias_data,
128                                        &output_dims, output_data);
129     assert(res == ARM_MATH_SUCCESS);
130   }
131   else
132   {
133     tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
134                                                   filter_shape, filter_data, bias_shape, bias_data,
135                                                   output_shape, output_data);
136   }
137 }
138
139 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
140                                          const luci_interpreter::DataType &input_data_type,
141                                          const tflite::ConvParams &params,
142                                          const tflite::RuntimeShape &input_shape,
143                                          const tflite::RuntimeShape &filter_shape,
144                                          const tflite::RuntimeShape &output_shape)
145 {
146   cmsis_nn_conv_params conv_params;
147   conv_params.dilation.h = params.dilation_height_factor;
148   conv_params.dilation.w = params.dilation_width_factor;
149
150   if (input_data_type == luci_interpreter::DataType::S8 && conv_params.dilation.h == 1 &&
151       conv_params.dilation.w == 1)
152   {
153     const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
154     const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
155     const int32_t output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
156     const int32_t filter_height = filter_shape.Dims(1);
157     const int32_t filter_width = filter_shape.Dims(2);
158     const int32_t output_height = output_shape.Dims(1);
159     const int32_t output_width = output_shape.Dims(2);
160
161     conv_params.input_offset = params.input_offset;
162     conv_params.output_offset = params.output_offset;
163     conv_params.stride.h = params.stride_height;
164     conv_params.stride.w = params.stride_width;
165     conv_params.padding.h = params.padding_values.height;
166     conv_params.padding.w = params.padding_values.width;
167
168     cmsis_nn_dims input_dims;
169     input_dims.n = batches;
170     input_dims.h = input_shape.Dims(1);
171     input_dims.w = input_shape.Dims(2);
172     input_dims.c = input_depth;
173
174     cmsis_nn_dims filter_dims;
175     filter_dims.n = output_depth;
176     filter_dims.h = filter_height;
177     filter_dims.w = filter_width;
178     filter_dims.c = input_depth;
179
180     cmsis_nn_dims output_dims;
181     output_dims.n = batches;
182     output_dims.h = output_height;
183     output_dims.w = output_width;
184     output_dims.c = output_depth;
185
186     const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
187                                                                      &filter_dims, &output_dims);
188
189     luci_interpreter::Shape scratchpad_shape{buf_size};
190     scratchpad->resize(scratchpad_shape);
191   }
192   else
193   {
194     scratchpad->set_allocatable(false);
195   }
196 }
197
198 } // namespace luci_interpreter_pal
199
200 #endif // LUCI_INTERPRETER_PAL_CONV2D_H