Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / cmsisnn / PALConv2d.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18 #define LUCI_INTERPRETER_PAL_CONV2D_H
19
20 #include <tensorflow/lite/kernels/internal/reference/conv.h>
21 #include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
22 #include <arm_nn_types.h>
23 #include <arm_nnfunctions.h>
24
25 namespace luci_interpreter_pal
26 {
27 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
28                         const float *input_data, const tflite::RuntimeShape &filter_shape,
29                         const float *filter_data, const tflite::RuntimeShape &bias_shape,
30                         const float *bias_data, const tflite::RuntimeShape &output_shape,
31                         float *output_data, const tflite::RuntimeShape &scratchpad_shape,
32                         float *scratchpad_data)
33 {
34   (void)scratchpad_shape;
35   (void)scratchpad_data;
36   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
37                               bias_shape, bias_data, output_shape, output_data,
38                               tflite::RuntimeShape(), nullptr);
39 }
40
41 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
42                         const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
43                         const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
44                         const int32 *bias_data, const tflite::RuntimeShape &output_shape,
45                         uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
46                         uint8 *scratchpad_data)
47 {
48   (void)scratchpad_shape;
49   (void)scratchpad_data;
50   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
51                               bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
52                               scratchpad_data, nullptr);
53 }
54
55 static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
56                                   const int32_t *shifts, const tflite::RuntimeShape &input_shape,
57                                   const int8 *input_data, const tflite::RuntimeShape &filter_shape,
58                                   const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
59                                   const int32 *bias_data, const tflite::RuntimeShape &output_shape,
60                                   int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
61                                   int8 *scratchpad_data)
62 {
63   if (scratchpad_data)
64   {
65     cmsis_nn_conv_params conv_params;
66     conv_params.dilation.h = params.dilation_height_factor;
67     conv_params.dilation.w = params.dilation_width_factor;
68
69     assert(conv_params.dilation.h == 1);
70     assert(conv_params.dilation.w == 1);
71
72     conv_params.input_offset = params.input_offset;
73     conv_params.output_offset = params.output_offset;
74     conv_params.stride.h = params.stride_height;
75     conv_params.stride.w = params.stride_width;
76     conv_params.padding.h = params.padding_values.height;
77     conv_params.padding.w = params.padding_values.width;
78     conv_params.activation.min = params.quantized_activation_min;
79     conv_params.activation.max = params.quantized_activation_max;
80
81     cmsis_nn_per_channel_quant_params quant_params;
82     quant_params.multiplier = const_cast<int32_t *>(mult);
83     quant_params.shift = const_cast<int32_t *>(shifts);
84
85     assert(conv_params.activation.min <= conv_params.activation.max);
86     assert(input_shape.DimensionsCount() == 4);
87     assert(filter_shape.DimensionsCount() == 4);
88     assert(output_shape.DimensionsCount() == 4);
89     const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
90     const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
91     const int output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
92     if (bias_data)
93     {
94       assert(bias_shape.FlatSize() == output_depth);
95     }
96
97     cmsis_nn_dims input_dims;
98     input_dims.n = batch_size;
99     input_dims.h = input_shape.Dims(1);
100     input_dims.w = input_shape.Dims(2);
101     input_dims.c = input_depth;
102
103     cmsis_nn_dims filter_dims;
104     filter_dims.n = output_depth;
105     filter_dims.h = filter_shape.Dims(1);
106     filter_dims.w = filter_shape.Dims(2);
107     filter_dims.c = input_depth;
108
109     cmsis_nn_dims bias_dims;
110     bias_dims.n = 1;
111     bias_dims.h = 1;
112     bias_dims.w = 1;
113     bias_dims.c = output_depth;
114
115     cmsis_nn_dims output_dims;
116     output_dims.n = batch_size;
117     output_dims.h = output_shape.Dims(1);
118     output_dims.w = output_shape.Dims(2);
119     output_dims.c = output_depth;
120
121     cmsis_nn_context ctx;
122     ctx.buf = scratchpad_data;
123     ctx.size = scratchpad_shape.Dims(0);
124
125     auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
126                                        &filter_dims, filter_data, &bias_dims, bias_data,
127                                        &output_dims, output_data);
128     assert(res == ARM_MATH_SUCCESS);
129   }
130   else
131   {
132     tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
133                                                   filter_shape, filter_data, bias_shape, bias_data,
134                                                   output_shape, output_data);
135   }
136 }
137
138 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
139                                          const luci_interpreter::DataType &input_data_type,
140                                          const tflite::ConvParams &params,
141                                          const tflite::RuntimeShape &input_shape,
142                                          const tflite::RuntimeShape &filter_shape,
143                                          const tflite::RuntimeShape &output_shape)
144 {
145   cmsis_nn_conv_params conv_params;
146   conv_params.dilation.h = params.dilation_height_factor;
147   conv_params.dilation.w = params.dilation_width_factor;
148
149   if (input_data_type == luci_interpreter::DataType::S8 && conv_params.dilation.h == 1 &&
150       conv_params.dilation.w == 1)
151   {
152     const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
153     const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
154     const int32_t output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
155     const int32_t filter_height = filter_shape.Dims(1);
156     const int32_t filter_width = filter_shape.Dims(2);
157     const int32_t output_height = output_shape.Dims(1);
158     const int32_t output_width = output_shape.Dims(2);
159
160     conv_params.input_offset = params.input_offset;
161     conv_params.output_offset = params.output_offset;
162     conv_params.stride.h = params.stride_height;
163     conv_params.stride.w = params.stride_width;
164     conv_params.padding.h = params.padding_values.height;
165     conv_params.padding.w = params.padding_values.width;
166
167     cmsis_nn_dims input_dims;
168     input_dims.n = batches;
169     input_dims.h = input_shape.Dims(1);
170     input_dims.w = input_shape.Dims(2);
171     input_dims.c = input_depth;
172
173     cmsis_nn_dims filter_dims;
174     filter_dims.n = output_depth;
175     filter_dims.h = filter_height;
176     filter_dims.w = filter_width;
177     filter_dims.c = input_depth;
178
179     cmsis_nn_dims output_dims;
180     output_dims.n = batches;
181     output_dims.h = output_height;
182     output_dims.w = output_width;
183     output_dims.c = output_depth;
184
185     const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
186                                                                      &filter_dims, &output_dims);
187
188     luci_interpreter::Shape scratchpad_shape{buf_size};
189     scratchpad->resize(scratchpad_shape);
190   }
191   else
192   {
193     scratchpad->set_allocatable(false);
194   }
195 }
196
197 } // namespace luci_interpreter_pal
198
199 #endif // LUCI_INTERPRETER_PAL_CONV2D_H