Imported Upstream version 1.22.1
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / linux / PALSVDF.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_SVDF_H
19 #define LUCI_INTERPRETER_PAL_SVDF_H
20
21 #include <tensorflow/lite/kernels/internal/reference/svdf.h>
22
23 namespace luci_interpreter_pal
24 {
25 static inline void
26 IntegerSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
27             const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
28             const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
29             const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
30             const int32_t *bias_data, int16_t *activation_state_data,
31             const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
32             int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
33             int scale_2_b, int32_t input_zp, int32_t output_zp)
34 {
35   tflite::reference_ops::EvalIntegerSVDF(&params, input_shape, input_data, weight_feature_shape,
36                                          weight_feature_data, weight_time_shape, weight_time_data,
37                                          bias_shape, bias_data, activation_state_data, output_shape,
38                                          output_data, scratchpad_data, output_temp_data, scale_1_a,
39                                          scale_1_b, scale_2_a, scale_2_b, input_zp, output_zp);
40 }
41 static inline void
42 FloatSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
43           const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
44           const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
45           const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
46           const float *bias_data, float *scratchpad_data, float *activation_state_data,
47           const tflite::RuntimeShape &output_shape, float *output_data)
48 {
49   tflite::reference_ops::EvalFloatSVDF(&params, input_shape, input_data, weight_feature_shape,
50                                        weight_feature_data, weight_time_shape, weight_time_data,
51                                        bias_shape, bias_data, scratchpad_data,
52                                        activation_state_data, output_shape, output_data);
53 }
54
55 static inline void SetupScratchpadTensor(
56   const luci_interpreter::DataType &input_data_type,
57   const luci_interpreter::DataType &weight_feature_data_type,
58   luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
59   luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
60   luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
61   const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
62   const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
63 {
64
65   if (input_data_type == luci_interpreter::DataType::FLOAT32 &&
66       (weight_feature_data_type == luci_interpreter::DataType::S8 ||
67        weight_feature_data_type == luci_interpreter::DataType::U8))
68   {
69     (void)input_shape;
70     (void)weight_time_shape;
71     (void)scratchpad_3;
72     (void)scratchpad_4;
73     (void)scratchpad_5;
74     (void)scratchpad_6;
75
76     assert(false && "Hybrid type is not currently supported for linux platform");
77   }
78
79   // Resize scratchpad_1 tensor
80   scratchpad_1->resize({batch_size, num_filters});
81
82   if (input_data_type == luci_interpreter::DataType::S8)
83   {
84     // Resize scratchpad_2 for full_integer op
85     scratchpad_2->resize({batch_size, num_units});
86   }
87 }
88
89 } // namespace luci_interpreter_pal
90
91 #endif // LUCI_INTERPRETER_PAL_SVDF_H