2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_SVDF_H
19 #define LUCI_INTERPRETER_PAL_SVDF_H
21 #include <arm_nn_types.h>
22 #include <arm_nnfunctions.h>
24 namespace luci_interpreter_pal
27 IntegerSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
28 const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
29 const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
30 const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
31 const int32_t *bias_data, int16_t *activation_state_data,
32 const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
33 int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
34 int scale_2_b, int32_t input_zp, int32_t output_zp)
36 const int32_t rank = params.rank;
37 const int32_t batch_size = input_shape.Dims(0);
38 const int32_t num_filters = weight_feature_shape.Dims(0);
39 const int32_t memory_size = weight_time_shape.Dims(1);
41 cmsis_nn_dims input_dims;
42 input_dims.n = input_shape.Dims(0);
43 input_dims.h = input_shape.Dims(1);
45 cmsis_nn_dims weights_feature_dims;
46 weights_feature_dims.n = weight_feature_shape.Dims(0);
47 weights_feature_dims.h = weight_feature_shape.Dims(1);
49 cmsis_nn_dims weights_time_dims;
50 weights_time_dims.n = weight_time_shape.Dims(0);
51 weights_time_dims.h = weight_time_shape.Dims(1);
53 cmsis_nn_dims bias_dims;
54 bias_dims.n = bias_shape.Dims(0);
56 cmsis_nn_dims state_dims;
57 state_dims.n = batch_size;
58 state_dims.h = memory_size * num_filters;
60 cmsis_nn_dims output_dims;
61 output_dims.n = output_shape.Dims(0);
62 output_dims.h = output_shape.Dims(1);
64 cmsis_nn_svdf_params svdf_params;
65 svdf_params.rank = params.rank;
66 svdf_params.input_offset = input_zp;
67 svdf_params.output_offset = output_zp;
69 svdf_params.input_activation.min = INT16_MIN;
70 svdf_params.input_activation.max = INT16_MAX;
72 svdf_params.output_activation.min = INT8_MIN;
73 svdf_params.output_activation.max = INT8_MAX;
75 cmsis_nn_per_tensor_quant_params in_quant_params;
76 in_quant_params.multiplier = scale_1_a;
77 in_quant_params.shift = scale_1_b;
79 cmsis_nn_per_tensor_quant_params out_quant_params;
80 out_quant_params.multiplier = scale_2_a;
81 out_quant_params.shift = scale_2_b;
83 cmsis_nn_context scratch_ctx;
84 scratch_ctx.buf = scratchpad_data;
86 cmsis_nn_context scratch_output_ctx;
87 scratch_output_ctx.buf = output_temp_data;
89 arm_svdf_s8(&scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params, &out_quant_params,
90 &input_dims, input_data, &state_dims, activation_state_data, &weights_feature_dims,
91 weight_feature_data, &weights_time_dims, weight_time_data, &bias_dims, bias_data,
92 &output_dims, output_data);
95 FloatSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
96 const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
97 const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
98 const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
99 const float *bias_data, float *scratchpad_data, float *activation_state_data,
100 const tflite::RuntimeShape &output_shape, float *output_data)
102 const int32_t rank = params.rank;
103 const int32_t batch_size = input_shape.Dims(0);
104 const int32_t input_size = input_shape.Dims(1);
105 const int32_t num_filters = weight_feature_shape.Dims(0);
106 const int32_t num_units = num_filters / rank;
107 const int32_t memory_size = weight_time_shape.Dims(1);
109 // Left shift the activation_state.
111 float *new_state_start = activation_state_data;
112 const float *old_state_start = activation_state_data + 1;
113 const float *old_state_end = activation_state_data + batch_size * num_filters * memory_size;
114 while (old_state_start != old_state_end)
116 *new_state_start++ = *old_state_start++;
120 // Note: no need to clear the latest activation, matmul is not accumulative.
122 // Compute conv1d(inputs, weights_feature).
123 // The activation_state's rightmost column is used to save current cycle
124 // activation. This is achieved by starting at state_ptr[memory_size - 1] and
125 // having the stride equal to memory_size.
127 // Perform batched matrix vector multiply operation:
129 const float *matrix = weight_feature_data;
130 const float *vector = input_data;
131 float *result = &activation_state_data[memory_size - 1];
132 float *result_in_batch = result;
133 for (int i = 0; i < batch_size; ++i)
135 const float *matrix_ptr = matrix;
136 for (int j = 0; j < num_filters; ++j)
138 float dot_prod = 0.0f;
139 const float *vector_in_batch = vector + i * input_size;
140 for (int k = 0; k < input_size; ++k)
142 dot_prod += *matrix_ptr++ * *vector_in_batch++;
144 *result_in_batch = dot_prod;
145 result_in_batch += memory_size;
150 tflite::reference_ops::ApplyTimeWeightsBiasAndActivation(
151 batch_size, memory_size, num_filters, num_units, rank, weight_time_data, bias_data,
152 params.activation, activation_state_data, scratchpad_data, output_data);
155 static inline void SetupScratchpadTensor(
156 const luci_interpreter::DataType &input_data_type,
157 const luci_interpreter::DataType &weight_feature_data_type,
158 luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
159 luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
160 luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
161 const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
162 const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
164 if (input_data_type == luci_interpreter::DataType::FLOAT32 &&
165 (weight_feature_data_type == luci_interpreter::DataType::S8 ||
166 weight_feature_data_type == luci_interpreter::DataType::U8))
169 (void)weight_time_shape;
175 assert(false && "Hybrid type is not supported for cmsisnn");
178 // Resize scratchpad_1 tensor
179 scratchpad_1->resize({batch_size, num_filters});
181 if (input_data_type == luci_interpreter::DataType::S8)
183 // Resize scratchpad_2 for full_integer op
184 scratchpad_2->resize({batch_size, num_units});
188 } // namespace luci_interpreter_pal
190 #endif // LUCI_INTERPRETER_PAL_SVDF_H