1c2658679912602faa52ef19aafe90fcb8ab664b
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / linux / PALConv2d.h
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
19 #define LUCI_INTERPRETER_PAL_CONV2D_H
20
#include <cassert>
#include <cstdint>
#include <limits>

#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
23
24 namespace luci_interpreter_pal
25 {
26 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
27                         const float *input_data, const tflite::RuntimeShape &filter_shape,
28                         const float *filter_data, const tflite::RuntimeShape &bias_shape,
29                         const float *bias_data, const tflite::RuntimeShape &output_shape,
30                         float *output_data, const tflite::RuntimeShape &scratchpad_shape,
31                         float *scratchpad_data)
32 {
33   (void)scratchpad_shape;
34   if (scratchpad_data)
35   {
36     const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
37     const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
38     const int32_t output_height = output_shape.Dims(1);
39     const int32_t output_width = output_shape.Dims(2);
40     const int32_t filter_height = filter_shape.Dims(1);
41     const int32_t filter_width = filter_shape.Dims(2);
42     tflite::RuntimeShape im2col_shape{batches, output_height, output_width,
43                                       input_depth * filter_height * filter_width};
44
45     tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
46                                 bias_shape, bias_data, output_shape, output_data, im2col_shape,
47                                 scratchpad_data);
48   }
49   else
50     tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
51                                 bias_shape, bias_data, output_shape, output_data,
52                                 tflite::RuntimeShape(), nullptr);
53 }
54
55 static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
56                         const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
57                         const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
58                         const int32 *bias_data, const tflite::RuntimeShape &output_shape,
59                         uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
60                         uint8 *scratchpad_data)
61 {
62   // TODO This should only be done once (although it takes only a few microseconds).
63   //  Also, the user should be able to adjust the number of threads.
64   auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
65   gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
66
67   tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
68                               bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
69                               scratchpad_data, gemmlowp_context.get());
70 }
71
72 static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
73                                   const int32_t *shifts, const tflite::RuntimeShape &input_shape,
74                                   const int8 *input_data, const tflite::RuntimeShape &filter_shape,
75                                   const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
76                                   const int32 *bias_data, const tflite::RuntimeShape &output_shape,
77                                   int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
78                                   int8 *scratchpad_data)
79 {
80   (void)scratchpad_shape;
81   (void)scratchpad_data;
82   // TODO enable optimized version
83   tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
84                                                 filter_shape, filter_data, bias_shape, bias_data,
85                                                 output_shape, output_data);
86 }
87
88 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
89                                          const luci_interpreter::DataType &input_data_type,
90                                          const tflite::ConvParams &params,
91                                          const tflite::RuntimeShape &input_shape,
92                                          const tflite::RuntimeShape &filter_shape,
93                                          const tflite::RuntimeShape &output_shape)
94 {
95   const int32_t filter_height = filter_shape.Dims(1);
96   const int32_t filter_width = filter_shape.Dims(2);
97
98   // Allocate tensor for scratchpad, if needed.
99   // The checks here should be aligned with the actual implementation.
100   const bool need_dilated_scratchpad =
101     params.dilation_height_factor != 1 || params.dilation_width_factor != 1;
102   const bool need_non_dilated_scratchpad = params.stride_height != 1 || params.stride_width != 1 ||
103                                            filter_height != 1 || filter_width != 1;
104   auto _need_scratchpad = input_data_type != luci_interpreter::DataType::S16 &&
105                           (need_dilated_scratchpad || need_non_dilated_scratchpad);
106
107   if (_need_scratchpad)
108   {
109     const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
110     const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
111     const int32_t output_height = output_shape.Dims(1);
112     const int32_t output_width = output_shape.Dims(2);
113
114     auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
115     int32_t scratchpad_size = batches * output_width * output_height * input_depth * filter_height *
116                               filter_width * data_type_size;
117     luci_interpreter::Shape scratchpad_shape{scratchpad_size};
118     scratchpad->resize(scratchpad_shape);
119   }
120   else
121   {
122     scratchpad->set_allocatable(false);
123   }
124 }
125
126 } // namespace luci_interpreter_pal
127
128 #endif // LUCI_INTERPRETER_PAL_CONV2D_H