5f5e4dd8ed16179d3a4be654ba7a8313753cf2f9
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / cmsisnn / PALDepthwiseConv2d.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
19 #define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
20
21 #include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
22 #include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
23 #include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
24 #include <arm_nnfunctions.h>
25
26 namespace luci_interpreter_pal
27 {
28 template <typename T>
29 static inline void
30 DepthwiseConvPerChannel(const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
31                         const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
32                         const T *input_data, const tflite::RuntimeShape &filter_shape,
33                         const T *filter_data, const tflite::RuntimeShape &bias_shape,
34                         const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
35                         T *output_data, const tflite::RuntimeShape &scratchpad_shape,
36                         T *scratchpad_data)
37 {
38   {
39     // MARK: At this moment this operation is not supported
40     assert(false && "DepthwiseConvPerChannel NYI");
41     (void)params;
42     (void)output_multiplier;
43     (void)output_shift;
44     (void)input_shape;
45     (void)output_data;
46     (void)input_data;
47     (void)filter_shape;
48     (void)filter_data;
49     (void)bias_shape;
50     (void)bias_data;
51     (void)output_shape;
52     (void)output_data;
53     (void)scratchpad_shape;
54     (void)scratchpad_data;
55   }
56 }
57
58 template <>
59 inline void DepthwiseConvPerChannel<int8_t>(
60   const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
61   const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
62   const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
63   const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
64   const tflite::RuntimeShape &output_shape, int8_t *output_data,
65   const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
66 {
67   if (scratchpad_data)
68   {
69     cmsis_nn_dw_conv_params dw_conv_params;
70     dw_conv_params.dilation.h = params.dilation_height_factor;
71     dw_conv_params.dilation.w = params.dilation_width_factor;
72     assert(dw_conv_params.dilation.h == 1);
73     assert(dw_conv_params.dilation.w == 1);
74
75     dw_conv_params.input_offset = params.input_offset;
76     dw_conv_params.output_offset = params.output_offset;
77     dw_conv_params.stride.h = params.stride_height;
78     dw_conv_params.stride.w = params.stride_width;
79     dw_conv_params.padding.h = params.padding_values.height;
80     dw_conv_params.padding.w = params.padding_values.width;
81
82     dw_conv_params.activation.min = params.quantized_activation_min;
83     dw_conv_params.activation.max = params.quantized_activation_max;
84     dw_conv_params.ch_mult = params.depth_multiplier;
85
86     cmsis_nn_per_channel_quant_params quant_params;
87     int32_t output_multiplier = params.output_multiplier;
88     int32_t output_shift = params.output_shift;
89
90     quant_params.multiplier = &output_multiplier;
91     quant_params.shift = &output_shift;
92
93     assert(dw_conv_params.activation.min <= dw_conv_params.activation.max);
94     const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
95     const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
96     if (bias_data)
97     {
98       assert(bias_shape.FlatSize() == output_depth);
99     }
100
101     cmsis_nn_dims input_dims;
102     input_dims.n = batch_size;
103     input_dims.h = input_shape.Dims(1);
104     input_dims.w = input_shape.Dims(2);
105     input_dims.c = input_shape.Dims(3);
106
107     cmsis_nn_dims filter_dims;
108     filter_dims.n = filter_shape.Dims(0);
109     filter_dims.h = filter_shape.Dims(1);
110     filter_dims.w = filter_shape.Dims(2);
111     filter_dims.c = output_depth;
112
113     cmsis_nn_dims bias_dims;
114     bias_dims.n = 1;
115     bias_dims.h = 1;
116     bias_dims.w = 1;
117     bias_dims.c = output_depth;
118
119     cmsis_nn_dims output_dims;
120     output_dims.n = batch_size;
121     output_dims.h = output_shape.Dims(1);
122     output_dims.w = output_shape.Dims(2);
123     output_dims.c = output_depth;
124
125     cmsis_nn_context ctx;
126     ctx.buf = scratchpad_data;
127     ctx.size = scratchpad_shape.Dims(0);
128
129     auto res = arm_depthwise_conv_wrapper_s8(&ctx, &dw_conv_params, &quant_params, &input_dims,
130                                              input_data, &filter_dims, filter_data, &bias_dims,
131                                              bias_data, &output_dims, output_data);
132     assert(res == ARM_MATH_SUCCESS);
133   }
134   else
135   {
136     tflite::reference_integer_ops::DepthwiseConvPerChannel(
137       params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
138       bias_shape, bias_data, output_shape, output_data);
139   }
140 }
141
142 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
143                                          const tflite::DepthwiseParams &params,
144                                          const luci_interpreter::DataType &input_data_type,
145                                          const tflite::RuntimeShape &input_shape,
146                                          const tflite::RuntimeShape &filter_shape,
147                                          const tflite::RuntimeShape &output_shape)
148 {
149   cmsis_nn_dw_conv_params dw_conv_params;
150   dw_conv_params.dilation.h = params.dilation_height_factor;
151   dw_conv_params.dilation.w = params.dilation_width_factor;
152
153   if (input_data_type == luci_interpreter::DataType::S8 && dw_conv_params.dilation.h == 1 &&
154       dw_conv_params.dilation.w == 1)
155   {
156     const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
157     const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
158
159     cmsis_nn_dims input_dims;
160     input_dims.n = batch_size;
161     input_dims.h = input_shape.Dims(1);
162     input_dims.w = input_shape.Dims(2);
163     input_dims.c = input_shape.Dims(3);
164
165     cmsis_nn_dims filter_dims;
166     filter_dims.n = filter_shape.Dims(0);
167     filter_dims.h = filter_shape.Dims(1);
168     filter_dims.w = filter_shape.Dims(2);
169     filter_dims.c = output_depth;
170
171     cmsis_nn_dims output_dims;
172     output_dims.n = batch_size;
173     output_dims.h = output_shape.Dims(1);
174     output_dims.w = output_shape.Dims(2);
175     output_dims.c = output_depth;
176
177     const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
178       &dw_conv_params, &input_dims, &filter_dims, &output_dims);
179
180     auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
181
182     luci_interpreter::Shape scratchpad_shape{buf_size * data_type_size};
183     scratchpad->resize(scratchpad_shape);
184   }
185   else
186   {
187     scratchpad->set_allocatable(false);
188   }
189 }
190
191 } // namespace luci_interpreter_pal
192
193 #endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H