Imported Upstream version 1.22.1
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / mcu / PALSVDF.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_SVDF_H
19 #define LUCI_INTERPRETER_PAL_SVDF_H
20
21 #include <tensorflow/lite/kernels/internal/reference/svdf.h>
22
23 namespace luci_interpreter_pal
24 {
25 static inline void
26 IntegerSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
27             const int8_t *input_data, const tflite::RuntimeShape &weight_feature_shape,
28             const int8_t *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
29             const int16_t *weight_time_data, const tflite::RuntimeShape &bias_shape,
30             const int32_t *bias_data, int16_t *activation_state_data,
31             const tflite::RuntimeShape &output_shape, int8_t *output_data, int32_t *scratchpad_data,
32             int32_t *output_temp_data, int32_t scale_1_a, int scale_1_b, int32_t scale_2_a,
33             int scale_2_b, int32_t input_zp, int32_t output_zp)
34 {
35   const int n_rank = params.rank;
36   const int n_batch = input_shape.Dims(0);
37   const int n_input = input_shape.Dims(1);
38   const int n_filter = weight_feature_shape.Dims(0);
39   const int n_unit = n_filter / n_rank;
40   const int n_memory = weight_time_shape.Dims(1);
41
42   // Left shift the activation_state.
43   {
44     int16_t *new_state_start = activation_state_data;
45     const int16_t *old_state_start = activation_state_data + 1;
46     const int16_t *old_state_end = activation_state_data + n_batch * n_filter * n_memory;
47     while (old_state_start != old_state_end)
48     {
49       *new_state_start++ = *old_state_start++;
50     }
51   }
52
53   // Note: no need to clear the latest activation, matmul is not accumulative.
54
55   // Feature matmul.
56   {
57     const int32_t output_max = std::numeric_limits<int16_t>::max();
58     const int32_t output_min = std::numeric_limits<int16_t>::min();
59     int16_t *result_in_batch = activation_state_data + (n_memory - 1);
60     for (int b = 0; b < n_batch; b++)
61     {
62       const int8_t *matrix_ptr = weight_feature_data;
63       for (int r = 0; r < n_filter; r++)
64       {
65         int32_t dot_prod = 0;
66         const int8_t *vector_in_batch = input_data + b * n_input;
67         for (int c = 0; c < n_input; c++)
68         {
69           dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp);
70         }
71         dot_prod = tflite::MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b);
72         dot_prod = std::min(std::max(output_min, dot_prod), output_max);
73         // This assumes state is symmetrically quantized. Otherwise last bit of
74         // state should be initialized to its zero point and accumulate the
75         // dot_prod.
76         // Equivalent as the following:
77         //     result_in_batch = zero point, which happens to be zero.
78         //     result_in_batch += dot_prod_56.
79         *result_in_batch = dot_prod;
80         result_in_batch += n_memory;
81       }
82     }
83   }
84
85   // Time.
86   {
87     for (int b = 0; b < n_batch; ++b)
88     {
89       int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
90
91       // Perform batched vector dot product:
92       const int16_t *vector1_ptr = weight_time_data;
93       const int16_t *vector2_ptr = activation_state_data + b * n_memory * n_filter;
94
95       for (int i = 0; i < n_filter; i++)
96       {
97         *scratch_ptr_batch = 0;
98         for (int j = 0; j < n_memory; j++)
99         {
100           *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++;
101         }
102         scratch_ptr_batch++;
103       }
104     }
105   }
106
107   // Reduce, add bias, rescale, activation.
108   {
109     // Add bias.
110     if (bias_data)
111     {
112       // Vector batch assign:
113       for (int i = 0; i < n_batch; ++i)
114       {
115         int32_t *output_ptr = output_temp_data + i * n_unit;
116         const int32_t *bias_ptr = bias_data;
117         for (int j = 0; j < n_unit; ++j)
118         {
119           *output_ptr++ = *bias_ptr++;
120         }
121       }
122     }
123     else
124     {
125       int32_t *output_ptr = output_temp_data;
126       for (int i = 0; i < n_batch * n_unit; ++i)
127       {
128         *output_ptr++ = 0;
129       }
130     }
131
132     // Reduce.
133     for (int b = 0; b < n_batch; ++b)
134     {
135       int32_t *output_temp_ptr = output_temp_data + b * n_unit;
136       int32_t *scratch_ptr_batch = scratchpad_data + b * n_filter;
137
138       // Reduction sum vector
139       for (int i = 0; i < n_unit; ++i)
140       {
141         for (int j = 0; j < n_rank; ++j)
142         {
143           output_temp_ptr[i] += *scratch_ptr_batch++;
144         }
145       }
146     }
147
148     // Rescale.
149     const int32_t output_max = std::numeric_limits<int8_t>::max();
150     const int32_t output_min = std::numeric_limits<int8_t>::min();
151     for (int i = 0; i < n_batch * n_unit; ++i)
152     {
153       int32_t x1 = output_temp_data[i];
154       int32_t x2 = tflite::MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b);
155       int32_t x3 = x2 + output_zp;
156       int32_t x4 = std::min(std::max(output_min, x3), output_max);
157       output_data[i] = static_cast<int8_t>(x4);
158     }
159   }
160 }
161 static inline void
162 FloatSVDF(const TfLiteSVDFParams &params, const tflite::RuntimeShape &input_shape,
163           const float *input_data, const tflite::RuntimeShape &weight_feature_shape,
164           const float *weight_feature_data, const tflite::RuntimeShape &weight_time_shape,
165           const float *weight_time_data, const tflite::RuntimeShape &bias_shape,
166           const float *bias_data, float *scratchpad_data, float *activation_state_data,
167           const tflite::RuntimeShape &output_shape, float *output_data)
168 {
169   const int32_t rank = params.rank;
170   const int32_t batch_size = input_shape.Dims(0);
171   const int32_t input_size = input_shape.Dims(1);
172   const int32_t num_filters = weight_feature_shape.Dims(0);
173   const int32_t num_units = num_filters / rank;
174   const int32_t memory_size = weight_time_shape.Dims(1);
175
176   // Left shift the activation_state.
177   {
178     float *new_state_start = activation_state_data;
179     const float *old_state_start = activation_state_data + 1;
180     const float *old_state_end = activation_state_data + batch_size * num_filters * memory_size;
181     while (old_state_start != old_state_end)
182     {
183       *new_state_start++ = *old_state_start++;
184     }
185   }
186
187   // Note: no need to clear the latest activation, matmul is not accumulative.
188
189   // Compute conv1d(inputs, weights_feature).
190   // The activation_state's rightmost column is used to save current cycle
191   // activation. This is achieved by starting at state_ptr[memory_size - 1] and
192   // having the stride equal to memory_size.
193
194   // Perform batched matrix vector multiply operation:
195   {
196     const float *matrix = weight_feature_data;
197     const float *vector = input_data;
198     float *result = &activation_state_data[memory_size - 1];
199     float *result_in_batch = result;
200     for (int i = 0; i < batch_size; ++i)
201     {
202       const float *matrix_ptr = matrix;
203       for (int j = 0; j < num_filters; ++j)
204       {
205         float dot_prod = 0.0f;
206         const float *vector_in_batch = vector + i * input_size;
207         for (int k = 0; k < input_size; ++k)
208         {
209           dot_prod += *matrix_ptr++ * *vector_in_batch++;
210         }
211         *result_in_batch = dot_prod;
212         result_in_batch += memory_size;
213       }
214     }
215   }
216
217   tflite::reference_ops::ApplyTimeWeightsBiasAndActivation(
218     batch_size, memory_size, num_filters, num_units, rank, weight_time_data, bias_data,
219     params.activation, activation_state_data, scratchpad_data, output_data);
220 }
221
222 static inline void SetupScratchpadTensor(
223   const luci_interpreter::DataType &input_data_type,
224   const luci_interpreter::DataType &weight_feature_data_type,
225   luci_interpreter::Tensor *scratchpad_1, luci_interpreter::Tensor *scratchpad_2,
226   luci_interpreter::Tensor *scratchpad_3, luci_interpreter::Tensor *scratchpad_4,
227   luci_interpreter::Tensor *scratchpad_5, luci_interpreter::Tensor *scratchpad_6,
228   const luci_interpreter::Shape input_shape, const luci_interpreter::Shape weight_time_shape,
229   const int32_t batch_size, const int32_t num_filters, const int32_t num_units)
230 {
231
232   if (input_data_type == luci_interpreter::DataType::FLOAT32 &&
233       (weight_feature_data_type == luci_interpreter::DataType::S8 ||
234        weight_feature_data_type == luci_interpreter::DataType::U8))
235   {
236     (void)input_shape;
237     (void)weight_time_shape;
238     (void)scratchpad_3;
239     (void)scratchpad_4;
240     (void)scratchpad_5;
241     (void)scratchpad_6;
242
243     assert(false && "Hybrid type is not currently supported for mcu platform");
244   }
245
246   // Resize scratchpad_1 tensor
247   scratchpad_1->resize({batch_size, num_filters});
248
249   if (input_data_type == luci_interpreter::DataType::S8)
250   {
251     // Resize scratchpad_2 for full_integer op
252     scratchpad_2->resize({batch_size, num_units});
253   }
254 }
255
256 } // namespace luci_interpreter_pal
257
258 #endif // LUCI_INTERPRETER_PAL_SVDF_H