/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_PAL_CONV2D_H
19 #define LUCI_INTERPRETER_PAL_CONV2D_H
21 #include <tensorflow/lite/kernels/internal/reference/conv.h>
22 #include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
24 namespace luci_interpreter_pal
26 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
27 const float *input_data, const tflite::RuntimeShape &filter_shape,
28 const float *filter_data, const tflite::RuntimeShape &bias_shape,
29 const float *bias_data, const tflite::RuntimeShape &output_shape,
30 float *output_data, const tflite::RuntimeShape &scratchpad_shape,
31 float *scratchpad_data)
33 (void)scratchpad_shape;
34 (void)scratchpad_data;
35 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
36 bias_shape, bias_data, output_shape, output_data,
37 tflite::RuntimeShape(), nullptr);
40 static inline void Conv(const tflite::ConvParams ¶ms, const tflite::RuntimeShape &input_shape,
41 const uint8_t *input_data, const tflite::RuntimeShape &filter_shape,
42 const uint8_t *filter_data, const tflite::RuntimeShape &bias_shape,
43 const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
44 uint8_t *output_data, const tflite::RuntimeShape &scratchpad_shape,
45 uint8_t *scratchpad_data)
47 (void)scratchpad_shape;
48 (void)scratchpad_data;
49 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
50 bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
51 scratchpad_data, nullptr);
55 ConvPerChannel(const tflite::ConvParams ¶ms, const int32_t *mult, const int32_t *shifts,
56 const tflite::RuntimeShape &input_shape, const int8_t *input_data,
57 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
58 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
59 const tflite::RuntimeShape &output_shape, int8_t *output_data,
60 const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
62 (void)scratchpad_shape;
63 (void)scratchpad_data;
64 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
65 filter_shape, filter_data, bias_shape, bias_data,
66 output_shape, output_data);
69 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
70 const luci_interpreter::DataType &input_data_type,
71 const tflite::ConvParams ¶ms,
72 const tflite::RuntimeShape &input_shape,
73 const tflite::RuntimeShape &filter_shape,
74 const tflite::RuntimeShape &output_shape)
76 (void)input_data_type;
84 } // namespace luci_interpreter_pal
86 #endif // LUCI_INTERPRETER_PAL_CONV2D_H