2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
19 #define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
21 #include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
22 #include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
23 #include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
24 #include <arm_nnfunctions.h>
26 namespace luci_interpreter_pal
30 DepthwiseConvPerChannel(const tflite::DepthwiseParams ¶ms, const int32_t *output_multiplier,
31 const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
32 const T *input_data, const tflite::RuntimeShape &filter_shape,
33 const T *filter_data, const tflite::RuntimeShape &bias_shape,
34 const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
35 T *output_data, const tflite::RuntimeShape &scratchpad_shape,
39 // MARK: At this moment this operation is not supported
40 assert(false && "DepthwiseConvPerChannel NYI");
42 (void)output_multiplier;
53 (void)scratchpad_shape;
54 (void)scratchpad_data;
59 inline void DepthwiseConvPerChannel<int8_t>(
60 const tflite::DepthwiseParams ¶ms, const int32_t *output_multiplier,
61 const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
62 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
63 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
64 const tflite::RuntimeShape &output_shape, int8_t *output_data,
65 const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
69 cmsis_nn_dw_conv_params dw_conv_params;
70 dw_conv_params.dilation.h = params.dilation_height_factor;
71 dw_conv_params.dilation.w = params.dilation_width_factor;
72 assert(dw_conv_params.dilation.h == 1);
73 assert(dw_conv_params.dilation.w == 1);
75 dw_conv_params.input_offset = params.input_offset;
76 dw_conv_params.output_offset = params.output_offset;
77 dw_conv_params.stride.h = params.stride_height;
78 dw_conv_params.stride.w = params.stride_width;
79 dw_conv_params.padding.h = params.padding_values.height;
80 dw_conv_params.padding.w = params.padding_values.width;
82 dw_conv_params.activation.min = params.quantized_activation_min;
83 dw_conv_params.activation.max = params.quantized_activation_max;
84 dw_conv_params.ch_mult = params.depth_multiplier;
86 cmsis_nn_per_channel_quant_params quant_params;
87 int32_t output_multiplier = params.output_multiplier;
88 int32_t output_shift = params.output_shift;
90 quant_params.multiplier = &output_multiplier;
91 quant_params.shift = &output_shift;
93 assert(dw_conv_params.activation.min <= dw_conv_params.activation.max);
94 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
95 const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
98 assert(bias_shape.FlatSize() == output_depth);
101 cmsis_nn_dims input_dims;
102 input_dims.n = batch_size;
103 input_dims.h = input_shape.Dims(1);
104 input_dims.w = input_shape.Dims(2);
105 input_dims.c = input_shape.Dims(3);
107 cmsis_nn_dims filter_dims;
108 filter_dims.n = filter_shape.Dims(0);
109 filter_dims.h = filter_shape.Dims(1);
110 filter_dims.w = filter_shape.Dims(2);
111 filter_dims.c = output_depth;
113 cmsis_nn_dims bias_dims;
117 bias_dims.c = output_depth;
119 cmsis_nn_dims output_dims;
120 output_dims.n = batch_size;
121 output_dims.h = output_shape.Dims(1);
122 output_dims.w = output_shape.Dims(2);
123 output_dims.c = output_depth;
125 cmsis_nn_context ctx;
126 ctx.buf = scratchpad_data;
127 ctx.size = scratchpad_shape.Dims(0);
129 auto res = arm_depthwise_conv_wrapper_s8(&ctx, &dw_conv_params, &quant_params, &input_dims,
130 input_data, &filter_dims, filter_data, &bias_dims,
131 bias_data, &output_dims, output_data);
132 assert(res == ARM_MATH_SUCCESS);
136 tflite::reference_integer_ops::DepthwiseConvPerChannel(
137 params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
138 bias_shape, bias_data, output_shape, output_data);
142 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
143 const tflite::DepthwiseParams ¶ms,
144 const luci_interpreter::DataType &input_data_type,
145 const tflite::RuntimeShape &input_shape,
146 const tflite::RuntimeShape &filter_shape,
147 const tflite::RuntimeShape &output_shape)
149 cmsis_nn_dw_conv_params dw_conv_params;
150 dw_conv_params.dilation.h = params.dilation_height_factor;
151 dw_conv_params.dilation.w = params.dilation_width_factor;
153 if (input_data_type == luci_interpreter::DataType::S8 && dw_conv_params.dilation.h == 1 &&
154 dw_conv_params.dilation.w == 1)
156 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
157 const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
159 cmsis_nn_dims input_dims;
160 input_dims.n = batch_size;
161 input_dims.h = input_shape.Dims(1);
162 input_dims.w = input_shape.Dims(2);
163 input_dims.c = input_shape.Dims(3);
165 cmsis_nn_dims filter_dims;
166 filter_dims.n = filter_shape.Dims(0);
167 filter_dims.h = filter_shape.Dims(1);
168 filter_dims.w = filter_shape.Dims(2);
169 filter_dims.c = output_depth;
171 cmsis_nn_dims output_dims;
172 output_dims.n = batch_size;
173 output_dims.h = output_shape.Dims(1);
174 output_dims.w = output_shape.Dims(2);
175 output_dims.c = output_depth;
177 const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
178 &dw_conv_params, &input_dims, &filter_dims, &output_dims);
180 auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
182 luci_interpreter::Shape scratchpad_shape{buf_size * data_type_size};
183 scratchpad->resize(scratchpad_shape);
187 scratchpad->set_allocatable(false);
191 } // namespace luci_interpreter_pal
193 #endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H