/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_PAL_SVDF_H
19 #define LUCI_INTERPRETER_PAL_SVDF_H
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

#include <tensorflow/lite/kernels/internal/reference/svdf.h>
23 namespace luci_interpreter_pal
26 IntegerSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
27 const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
28 const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
29 const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
30 const int32_t *bias_data, int16_t *activation_state_data,
31 const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
32 int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
33 int scale_2_b, int32_t input_zp, int32_t output_zp)
35 const int n_rank = params.rank;
36 const int n_batch = input_shape.Dims(0);
37 const int n_input = input_shape.Dims(1);
38 const int n_filter = weight_feature_shape.Dims(0);
39 const int n_unit = n_filter / n_rank;
40 const int n_memory = weight_time_shape.Dims(1);
42 // Left shift the activation_state.
44 int16_t *new_state_start = activation_state_data;
45 const int16_t *old_state_start = activation_state_data + 1;
46 const int16_t *old_state_end = activation_state_data + n_batch * n_filter * n_memory;
47 while (old_state_start != old_state_end)
49 *new_state_start++ = *old_state_start++;
53 // Note: no need to clear the latest activation, matmul is not accumulative.
57 const int32_t output_max = std::numeric_limits<int16_t>::max();
58 const int32_t output_min = std::numeric_limits<int16_t>::min();
59 int16_t *result_in_batch = activation_state_data + (n_memory - 1);
60 for (int b = 0; b < n_batch; b++)
62 const int8_t *matrix_ptr = weight_feature_data;
63 for (int r = 0; r < n_filter; r++)
66 const int8_t *vector_in_batch = input_data + b * n_input;
67 for (int c = 0; c < n_input; c++)
69 dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
71 dot_prod = tflite::MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
72 dot_prod = std::min(std::max(output_min, dot_prod), output_max);
73 // This assumes state is symmetrically quantized. Otherwise last bit of
74 // state should be initialized to its zero point and accumulate the
76 // Equivalent as the following:
77 // result_in_batch = zero point, which happens to be zero.
78 // result_in_batch += dot_prod_56.
79 *result_in_batch = dot_prod;
80 result_in_batch += n_memory;
87 for (int b = 0; b < n_batch; ++b)
89 int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
91 // Perform batched vector dot product:
92 const int16_t *vector1_ptr = weight_time_data;
93 const int16_t *vector2_ptr = activation_state_data + b * n_memory * n_filter;
95 for (int i = 0; i < n_filter; i++)
97 *scratch_ptr_batch = 0;
98 for (int j = 0; j < n_memory; j++)
100 *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
107 // Reduce, add bias, rescale, activation.
112 // Vector batch assign:
113 for (int i = 0; i < n_batch; ++i)
115 int32_t *output_ptr = output_temp_data + i * n_unit;
116 const int32_t *bias_ptr = bias_data;
117 for (int j = 0; j < n_unit; ++j)
119 *output_ptr++ = *bias_ptr++;
125 int32_t *output_ptr = output_temp_data;
126 for (int i = 0; i < n_batch * n_unit; ++i)
133 for (int b = 0; b < n_batch; ++b)
135 int32_t *output_temp_ptr = output_temp_data + b * n_unit;
136 int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
138 // Reduction sum vector
139 for (int i = 0; i < n_unit; ++i)
141 for (int j = 0; j < n_rank; ++j)
143 output_temp_ptr[i] += *scratch_ptr_batch++;
149 const int32_t output_max = std::numeric_limits<int8_t>::max();
150 const int32_t output_min = std::numeric_limits<int8_t>::min();
151 for (int i = 0; i < n_batch * n_unit; ++i)
153 int32_t x1 = output_temp_data[i];
154 int32_t x2 = tflite::MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
155 int32_t x3 = x2 + output_zp;
156 int32_t x4 = std::min(std::max(output_min, x3), output_max);
157 output_data[i] = static_cast<int8_t>(x4);
162 FloatSVDF(const TfLiteSVDFParams ¶ms, const tflite::RuntimeShape &input_shape,
163 const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
164 const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
165 const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
166 const float *bias_data, float *scratchpad_data, float *activation_state_data,
167 const tflite::RuntimeShape &output_shape, float *output_data)
169 const int32_t rank = params.rank;
170 const int32_t batch_size = input_shape.Dims(0);
171 const int32_t input_size = input_shape.Dims(1);
172 const int32_t num_filters = weight_feature_shape.Dims(0);
173 const int32_t num_units = num_filters / rank;
174 const int32_t memory_size = weight_time_shape.Dims(1);
176 // Left shift the activation_state.
178 float *new_state_start = activation_state_data;
179 const float *old_state_start = activation_state_data + 1;
180 const float *old_state_end = activation_state_data + batch_size * num_filters * memory_size;
181 while (old_state_start != old_state_end)
183 *new_state_start++ = *old_state_start++;
187 // Note: no need to clear the latest activation, matmul is not accumulative.
189 // Compute conv1d(inputs, weights_feature).
190 // The activation_state's rightmost column is used to save current cycle
191 // activation. This is achieved by starting at state_ptr[memory_size - 1] and
192 // having the stride equal to memory_size.
194 // Perform batched matrix vector multiply operation:
196 const float *matrix = weight_feature_data;
197 const float *vector = input_data;
198 float *result = &activation_state_data[memory_size - 1];
199 float *result_in_batch = result;
200 for (int i = 0; i < batch_size; ++i)
202 const float *matrix_ptr = matrix;
203 for (int j = 0; j < num_filters; ++j)
205 float dot_prod = 0.0f;
206 const float *vector_in_batch = vector + i * input_size;
207 for (int k = 0; k < input_size; ++k)
209 dot_prod += *matrix_ptr++ * *vector_in_batch++;
211 *result_in_batch = dot_prod;
212 result_in_batch += memory_size;
217 tflite::reference_ops::ApplyTimeWeightsBiasAndActivation(
218 batch_size, memory_size, num_filters, num_units, rank, weight_time_data, bias_data,
219 params.activation, activation_state_data, scratchpad_data, output_data);
222 static inline void SetupScratchpadTensor(
223 const luci_interpreter::DataType &input_data_type,
224 const luci_interpreter::DataType &weight_feature_data_type,
225 luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
226 luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
227 luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
228 const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
229 const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
232 if (input_data_type == luci_interpreter::DataType::FLOAT32 &&
233 (weight_feature_data_type == luci_interpreter::DataType::S8 ||
234 weight_feature_data_type == luci_interpreter::DataType::U8))
237 (void)weight_time_shape;
243 assert(false && "Hybrid type is not currently supported for mcu platform");
246 // Resize scratchpad_1 tensor
247 scratchpad_1->resize({batch_size, num_filters});
249 if (input_data_type == luci_interpreter::DataType::S8)
251 // Resize scratchpad_2 for full_integer op
252 scratchpad_2->resize({batch_size, num_units});
256 } // namespace luci_interpreter_pal
258 #endif // LUCI_INTERPRETER_PAL_SVDF_H