/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18 #define LUCI_INTERPRETER_PAL_CONV2D_H
20 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
21 #include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
23 namespace luci_interpreter_pal
25 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
26 const float *input_data, const tflite::RuntimeShape &filter_shape,
27 const float *filter_data, const tflite::RuntimeShape &bias_shape,
28 const float *bias_data, const tflite::RuntimeShape &output_shape,
29 float *output_data, const tflite::RuntimeShape &im2col_shape,
34 tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
35 bias_shape, bias_data, output_shape, output_data, im2col_shape,
39 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
40 bias_shape, bias_data, output_shape, output_data,
41 tflite::RuntimeShape(), nullptr);
44 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
45 const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
46 const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
47 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
48 uint8 *output_data, const tflite::RuntimeShape &im2col_shape,
51 // TODO This should only be done once (although it takes only a few microseconds).
52 // Also, the user should be able to adjust the number of threads.
53 auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
54 gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
56 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
57 bias_shape, bias_data, output_shape, output_data, im2col_shape,
58 im2col_data, gemmlowp_context.get());
61 static inline void ConvPerChannel(const tflite::ConvParams ¶ms, const int32_t *mult,
62 const int32_t *shifts, const tflite::RuntimeShape &input_shape,
63 const int8 *input_data, const tflite::RuntimeShape &filter_shape,
64 const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
65 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
66 int8 *output_data, const tflite::RuntimeShape &im2col_shape,
71 // TODO enable optimized version
72 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
73 filter_shape, filter_data, bias_shape, bias_data,
74 output_shape, output_data);
77 } // namespace luci_interpreter_pal
79 #endif // LUCI_INTERPRETER_PAL_CONV2D_H