/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_PAL_SVDF_H
19 #define LUCI_INTERPRETER_PAL_SVDF_H
#include <tensorflow/lite/kernels/internal/reference/svdf.h>

#include <cassert>
#include <cstdint>
23 namespace luci_interpreter_pal
26 IntegerSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
27 const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
28 const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
29 const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
30 const int32_t *bias_data, int16_t *activation_state_data,
31 const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
32 int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
33 int scale_2_b, int32_t input_zp, int32_t output_zp)
35 tflite::reference_ops::EvalIntegerSVDF(¶ms, input_shape, input_data, weight_feature_shape,
36 weight_feature_data, weight_time_shape, weight_time_data,
37 bias_shape, bias_data, activation_state_data, output_shape,
38 output_data, scratchpad_data, output_temp_data, scale_1_a,
39 scale_1_b, scale_2_a, scale_2_b, input_zp, output_zp);
42 FloatSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
43 const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
44 const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
45 const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
46 const float *bias_data, float *scratchpad_data, float *activation_state_data,
47 const tflite::RuntimeShape &output_shape, float *output_data)
49 tflite::reference_ops::EvalFloatSVDF(¶ms, input_shape, input_data, weight_feature_shape,
50 weight_feature_data, weight_time_shape, weight_time_data,
51 bias_shape, bias_data, scratchpad_data,
52 activation_state_data, output_shape, output_data);
55 static inline void SetupScratchpadTensor(
56 const luci_interpreter::DataType &input_data_type,
57 const luci_interpreter::DataType &weight_feature_data_type,
58 luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
59 luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
60 luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
61 const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
62 const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
65 if (input_data_type == luci_interpreter::DataType::FLOAT32 &&
66 (weight_feature_data_type == luci_interpreter::DataType::S8 ||
67 weight_feature_data_type == luci_interpreter::DataType::U8))
70 (void)weight_time_shape;
76 assert(false && "Hybrid type is not currently supported for linux platform");
79 // Resize scratchpad_1 tensor
80 scratchpad_1->resize({batch_size, num_filters});
82 if (input_data_type == luci_interpreter::DataType::S8)
84 // Resize scratchpad_2 for full_integer op
85 scratchpad_2->resize({batch_size, num_units});
89 } // namespace luci_interpreter_pal
91 #endif // LUCI_INTERPRETER_PAL_SVDF_H