2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
19 #define LUCI_INTERPRETER_PAL_CONV2D_H
21 #include <tensorflow/lite/kernels/internal/reference/conv.h>
22 #include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
23 #include <arm_nn_types.h>
24 #include <arm_nnfunctions.h>
26 namespace luci_interpreter_pal
28 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
29 const float *input_data, const tflite::RuntimeShape &filter_shape,
30 const float *filter_data, const tflite::RuntimeShape &bias_shape,
31 const float *bias_data, const tflite::RuntimeShape &output_shape,
32 float *output_data, const tflite::RuntimeShape &scratchpad_shape,
33 float *scratchpad_data)
35 (void)scratchpad_shape;
36 (void)scratchpad_data;
37 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
38 bias_shape, bias_data, output_shape, output_data,
39 tflite::RuntimeShape(), nullptr);
42 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
43 const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
44 const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
45 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
46 uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
47 uint8 *scratchpad_data)
49 (void)scratchpad_shape;
50 (void)scratchpad_data;
51 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
52 bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
53 scratchpad_data, nullptr);
56 static inline void ConvPerChannel(const tflite::ConvParams ¶ms, const int32_t *mult,
57 const int32_t *shifts, const tflite::RuntimeShape &input_shape,
58 const int8 *input_data, const tflite::RuntimeShape &filter_shape,
59 const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
60 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
61 int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
62 int8 *scratchpad_data)
66 cmsis_nn_conv_params conv_params;
67 conv_params.dilation.h = params.dilation_height_factor;
68 conv_params.dilation.w = params.dilation_width_factor;
70 assert(conv_params.dilation.h == 1);
71 assert(conv_params.dilation.w == 1);
73 conv_params.input_offset = params.input_offset;
74 conv_params.output_offset = params.output_offset;
75 conv_params.stride.h = params.stride_height;
76 conv_params.stride.w = params.stride_width;
77 conv_params.padding.h = params.padding_values.height;
78 conv_params.padding.w = params.padding_values.width;
79 conv_params.activation.min = params.quantized_activation_min;
80 conv_params.activation.max = params.quantized_activation_max;
82 cmsis_nn_per_channel_quant_params quant_params;
83 quant_params.multiplier = const_cast<int32_t *>(mult);
84 quant_params.shift = const_cast<int32_t *>(shifts);
86 assert(conv_params.activation.min <= conv_params.activation.max);
87 assert(input_shape.DimensionsCount() == 4);
88 assert(filter_shape.DimensionsCount() == 4);
89 assert(output_shape.DimensionsCount() == 4);
90 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
91 const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
92 const int output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
95 assert(bias_shape.FlatSize() == output_depth);
98 cmsis_nn_dims input_dims;
99 input_dims.n = batch_size;
100 input_dims.h = input_shape.Dims(1);
101 input_dims.w = input_shape.Dims(2);
102 input_dims.c = input_depth;
104 cmsis_nn_dims filter_dims;
105 filter_dims.n = output_depth;
106 filter_dims.h = filter_shape.Dims(1);
107 filter_dims.w = filter_shape.Dims(2);
108 filter_dims.c = input_depth;
110 cmsis_nn_dims bias_dims;
114 bias_dims.c = output_depth;
116 cmsis_nn_dims output_dims;
117 output_dims.n = batch_size;
118 output_dims.h = output_shape.Dims(1);
119 output_dims.w = output_shape.Dims(2);
120 output_dims.c = output_depth;
122 cmsis_nn_context ctx;
123 ctx.buf = scratchpad_data;
124 ctx.size = scratchpad_shape.Dims(0);
126 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
127 &filter_dims, filter_data, &bias_dims, bias_data,
128 &output_dims, output_data);
129 assert(res == ARM_MATH_SUCCESS);
133 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
134 filter_shape, filter_data, bias_shape, bias_data,
135 output_shape, output_data);
139 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
140 const luci_interpreter::DataType &input_data_type,
141 const tflite::ConvParams ¶ms,
142 const tflite::RuntimeShape &input_shape,
143 const tflite::RuntimeShape &filter_shape,
144 const tflite::RuntimeShape &output_shape)
146 cmsis_nn_conv_params conv_params;
147 conv_params.dilation.h = params.dilation_height_factor;
148 conv_params.dilation.w = params.dilation_width_factor;
150 if (input_data_type == luci_interpreter::DataType::S8 && conv_params.dilation.h == 1 &&
151 conv_params.dilation.w == 1)
153 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
154 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
155 const int32_t output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
156 const int32_t filter_height = filter_shape.Dims(1);
157 const int32_t filter_width = filter_shape.Dims(2);
158 const int32_t output_height = output_shape.Dims(1);
159 const int32_t output_width = output_shape.Dims(2);
161 conv_params.input_offset = params.input_offset;
162 conv_params.output_offset = params.output_offset;
163 conv_params.stride.h = params.stride_height;
164 conv_params.stride.w = params.stride_width;
165 conv_params.padding.h = params.padding_values.height;
166 conv_params.padding.w = params.padding_values.width;
168 cmsis_nn_dims input_dims;
169 input_dims.n = batches;
170 input_dims.h = input_shape.Dims(1);
171 input_dims.w = input_shape.Dims(2);
172 input_dims.c = input_depth;
174 cmsis_nn_dims filter_dims;
175 filter_dims.n = output_depth;
176 filter_dims.h = filter_height;
177 filter_dims.w = filter_width;
178 filter_dims.c = input_depth;
180 cmsis_nn_dims output_dims;
181 output_dims.n = batches;
182 output_dims.h = output_height;
183 output_dims.w = output_width;
184 output_dims.c = output_depth;
186 const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
187 &filter_dims, &output_dims);
189 luci_interpreter::Shape scratchpad_shape{buf_size};
190 scratchpad->resize(scratchpad_shape);
194 scratchpad->set_allocatable(false);
198 } // namespace luci_interpreter_pal
200 #endif // LUCI_INTERPRETER_PAL_CONV2D_H