Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / pal / linux / PALConv2d.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18 #define LUCI_INTERPRETER_PAL_CONV2D_H
19
20 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
21 #include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
22
23 namespace luci_interpreter_pal
24 {
25 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
26                         const float *input_data, const tflite::RuntimeShape &filter_shape,
27                         const float *filter_data, const tflite::RuntimeShape &bias_shape,
28                         const float *bias_data, const tflite::RuntimeShape &output_shape,
29                         float *output_data, const tflite::RuntimeShape &im2col_shape,
30                         float *im2col_data)
31 {
32   if (im2col_data)
33   {
34     tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
35                                 bias_shape, bias_data, output_shape, output_data, im2col_shape,
36                                 im2col_data);
37   }
38   else
39     tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
40                                 bias_shape, bias_data, output_shape, output_data,
41                                 tflite::RuntimeShape(), nullptr);
42 }
43
44 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
45                         const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
46                         const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
47                         const int32 *bias_data, const tflite::RuntimeShape &output_shape,
48                         uint8 *output_data, const tflite::RuntimeShape &im2col_shape,
49                         uint8 *im2col_data)
50 {
51   // TODO This should only be done once (although it takes only a few microseconds).
52   //  Also, the user should be able to adjust the number of threads.
53   auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
54   gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
55
56   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
57                               bias_shape, bias_data, output_shape, output_data, im2col_shape,
58                               im2col_data, gemmlowp_context.get());
59 }
60
61 static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
62                                   const int32_t *shifts, const tflite::RuntimeShape &input_shape,
63                                   const int8 *input_data, const tflite::RuntimeShape &filter_shape,
64                                   const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
65                                   const int32 *bias_data, const tflite::RuntimeShape &output_shape,
66                                   int8 *output_data, const tflite::RuntimeShape &im2col_shape,
67                                   int8 *im2col_data)
68 {
69   (void)im2col_shape;
70   (void)im2col_data;
71   // TODO enable optimized version
72   tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
73                                                 filter_shape, filter_data, bias_shape, bias_data,
74                                                 output_shape, output_data);
75 }
76
77 } // namespace luci_interpreter_pal
78
79 #endif // LUCI_INTERPRETER_PAL_CONV2D_H