287f30314ea625f22cc8fcde5757a8278cf1bcaf
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / mcu / PALUnidirectionalSequenceLSTM.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19 #define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
20
#include "kernels/UnidirectionalSequenceLSTM.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>
27
28 namespace luci_interpreter_pal
29 {
30 namespace lstm_internal
31 {
// Fused activation functions a gate can apply; the enumerator names follow
// TFLite's C API enum of the same name. Only kTfLiteActSigmoid and
// kTfLiteActTanh are handled by calculateLstmGate in this file.
// Declared directly in lstm_internal rather than in an anonymous namespace:
// an anonymous namespace in a header gives every including translation unit a
// distinct copy of the type, which the externally-linked templates below
// (e.g. lstmStep) would then reference.
enum TfLiteFusedActivation
{
  kTfLiteActNone = 0,
  kTfLiteActRelu,
  kTfLiteActReluN1To1, // min(max(-1, x), 1)
  kTfLiteActRelu6,     // min(max(0, x), 6)
  kTfLiteActTanh,
  kTfLiteActSignBit,
  kTfLiteActSigmoid,
};
47
// Clamps x into [output_activation_min, output_activation_max].
// The lower bound is applied first and the upper bound second, so if the
// bounds are inverted the upper bound wins.
template <typename T>
inline T activationFunctionWithMinMax(T x, T output_activation_min, T output_activation_max)
{
  const T at_least_min = (x < output_activation_min) ? output_activation_min : x;
  return (output_activation_max < at_least_min) ? output_activation_max : at_least_min;
}
55
56 template <typename T>
57 inline void mul(const luci_interpreter::lstm::ArithmeticParams *params,
58                 const tflite::RuntimeShape &input1_shape, const T *input1_data,
59                 const tflite::RuntimeShape &input2_shape, const T *input2_data,
60                 const tflite::RuntimeShape &output_shape, T *output_data)
61 {
62   T output_activation_min = params->quantized_activation_min;
63   T output_activation_max = params->quantized_activation_max;
64
65   const int flat_size = input1_shape.FlatSize();
66   for (int i = 0; i < flat_size; ++i)
67   {
68     output_data[i] = activationFunctionWithMinMax(input1_data[i] * input2_data[i],
69                                                   output_activation_min, output_activation_max);
70   }
71 }
72
73 #ifndef DIS_QUANT
74 inline int32_t multiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
75 {
76   using gemmlowp::RoundingDivideByPOT;
77   using gemmlowp::SaturatingRoundingDoublingHighMul;
78   int left_shift = shift > 0 ? shift : 0;
79   int right_shift = shift > 0 ? 0 : -shift;
80   return RoundingDivideByPOT(
81     SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
82 }
83
84 template <typename InputType, typename WeightType, typename OutputType, typename BiasType>
85 void fullyConnectedInteger(const tflite::FullyConnectedParams &params,
86                            const tflite::RuntimeShape &input_shape, const InputType *input_data,
87                            const tflite::RuntimeShape &filter_shape, const WeightType *filter_data,
88                            const tflite::RuntimeShape &bias_shape, const BiasType *bias_data,
89                            const tflite::RuntimeShape &output_shape, OutputType *output_data)
90 {
91   const int32_t input_offset = params.input_offset;
92   const int32_t filter_offset = params.weights_offset;
93   const int32_t output_offset = params.output_offset;
94   const int32_t output_multiplier = params.output_multiplier;
95   const int output_shift = params.output_shift;
96   const int32_t output_activation_min = params.quantized_activation_min;
97   const int32_t output_activation_max = params.quantized_activation_max;
98   TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
99   TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
100
101   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
102   const int filter_dim_count = filter_shape.DimensionsCount();
103   const int output_dim_count = output_shape.DimensionsCount();
104   const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
105   const int output_depth = output_shape.Dims(output_dim_count - 1);
106   TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
107   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
108   for (int b = 0; b < batches; ++b)
109   {
110     for (int out_c = 0; out_c < output_depth; ++out_c)
111     {
112       BiasType acc = 0;
113       for (int d = 0; d < accum_depth; ++d)
114       {
115         int32_t input_val = input_data[b * accum_depth + d];
116         int32_t filter_val = filter_data[out_c * accum_depth + d];
117         acc += (filter_val + filter_offset) * (input_val + input_offset);
118       }
119       if (bias_data)
120       {
121         acc += bias_data[out_c];
122       }
123       int32_t acc_scaled = multiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
124       acc_scaled += output_offset;
125       acc_scaled = std::max(acc_scaled, output_activation_min);
126       acc_scaled = std::min(acc_scaled, output_activation_max);
127       output_data[out_c + output_depth * b] = static_cast<OutputType>(acc_scaled);
128     }
129   }
130 }
131
132 void fullyConnected(const tflite::FullyConnectedParams &params,
133                     const tflite::RuntimeShape &input_shape, const int8_t *input_data,
134                     const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
135                     const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
136                     const tflite::RuntimeShape &output_shape, int16_t *output_data)
137 {
138   return fullyConnectedInteger(params, input_shape, input_data, filter_shape, filter_data,
139                                bias_shape, bias_data, output_shape, output_data);
140 }
141
142 void fullyConnected(const tflite::FullyConnectedParams &params,
143                     const tflite::RuntimeShape &input_shape, const int16_t *input_data,
144                     const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
145                     const tflite::RuntimeShape &bias_shape, const int64_t *bias_data,
146                     const tflite::RuntimeShape &output_shape, int16_t *output_data)
147 {
148   return fullyConnectedInteger(params, input_shape, input_data, filter_shape, filter_data,
149                                bias_shape, bias_data, output_shape, output_data);
150 }
151
152 template <typename InputType, typename OutputType>
153 void mulElementwise(int size, const luci_interpreter::lstm::ArithmeticParams *params,
154                     const InputType *input1_data, const InputType *input2_data,
155                     OutputType *output_data)
156 {
157   for (int i = 0; i < size; ++i)
158   {
159     const int32_t input1_val = params->input1_offset + input1_data[i];
160     const int32_t input2_val = params->input2_offset + input2_data[i];
161     const int32_t unclamped_result =
162       params->output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
163                                                             params->output_multiplier,
164                                                             params->output_shift);
165     const int32_t clamped_output =
166       std::min(params->quantized_activation_max,
167                std::max(params->quantized_activation_min, unclamped_result));
168     output_data[i] = static_cast<OutputType>(clamped_output);
169   }
170 }
171
172 // Input and output have the same shape in LSTM
173 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
174          const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
175 {
176   return mulElementwise<int16_t, int8_t>(shape.FlatSize(), params, input1_data, input2_data,
177                                          output_data);
178 }
179
180 // Input and output have the same shape in LSTM
181 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
182          const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
183 {
184   return mulElementwise(shape.FlatSize(), params, input1_data, input2_data, output_data);
185 }
186
// Saturating element-wise add of two int16 buffers laid out as
// [n_batch, n_input]. Each sum is formed in 32 bits and clamped to the int16
// range. output may alias input_1 or input_2, because each element is read
// before it is written. `inline` because this non-template function is
// defined in a header (avoids multiple definitions across translation units).
inline void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch,
                           int n_input, int16_t *output)
{
  constexpr int32_t kInt16Min = std::numeric_limits<int16_t>::min();
  constexpr int32_t kInt16Max = std::numeric_limits<int16_t>::max();
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      const int32_t sum = static_cast<int32_t>(input_1[index]) + input_2[index];
      output[index] = static_cast<int16_t>(std::min(kInt16Max, std::max(kInt16Min, sum)));
    }
  }
}
203
204 void tanh(int32_t cell_state_scale_power, const tflite::RuntimeShape &input_data_shape,
205           int16_t *input_data, const tflite::RuntimeShape &output_data_shape, int16_t *output_data)
206 {
207   int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
208   int32_t input_multiplier = 0;
209   if (tanh_input_left_shift < 0) /* handling negative shift value */
210   {
211     tanh_input_left_shift = -tanh_input_left_shift;
212     input_multiplier = 3;
213   }
214   tflite::reference_integer_ops::Tanh(input_multiplier, tanh_input_left_shift, input_data_shape,
215                                       input_data, output_data_shape, output_data);
216 }
217
218 void sigmoid(const tflite::RuntimeShape &data_shape, int16_t *data)
219 {
220   tflite::reference_integer_ops::Logistic(0 /*data->input_multiplier*/,
221                                           0 /*data->input_left_shift */,
222                                           data_shape.FlatSize() /*NumElements(input->dims)*/,
223                                           data /* tflite::micro::GetTensorData<int16_t>(input) */,
224                                           data /*tflite::micro::GetTensorData<int16_t>(output) */);
225 }
226
227 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
228               int16_t *vector)
229 {
230   for (int i = 0; i < v_size; i++)
231   {
232     vector[i] = std::max(std::min(cell_state_info->quantized_cell_clip, vector[i]),
233                          static_cast<int16_t>(-cell_state_info->quantized_cell_clip));
234   }
235 }
236 #endif // DIS_QUANT
237
238 #ifndef DIS_FLOAT
239 void fullyConnected(const tflite::FullyConnectedParams &params,
240                     const tflite::RuntimeShape &input_shape, const float *input_data,
241                     const tflite::RuntimeShape &filter_shape, const float *filter_data,
242                     const tflite::RuntimeShape &bias_shape, const float *bias_data,
243                     const tflite::RuntimeShape &output_shape, float *output_data)
244 {
245   return tflite::reference_ops::FullyConnected(params, input_shape, input_data, filter_shape,
246                                                filter_data, bias_shape, bias_data, output_shape,
247                                                output_data);
248 }
249
250 // Input and output have the same shape in LSTM
251 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
252          const float *input1_data, const float *input2_data, float *output_data)
253 {
254   return mul(params, shape, input1_data, shape, input2_data, shape, output_data);
255 }
256
// Element-wise add of two float buffers laid out as [n_batch, n_input].
// output may alias input_1 or input_2, because each element is read before it
// is written. `inline` because this non-template function is defined in a
// header (avoids multiple definitions across translation units).
inline void addElementWise(const float *input_1, const float *input_2, int n_batch, int n_input,
                           float *output)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      output[index] = input_1[index] + input_2[index];
    }
  }
}
269
270 void tanh(int32_t cell_state_scale_power, const tflite::RuntimeShape &input_data_shape,
271           float *input_data, const tflite::RuntimeShape &output_data_shape, float *output_data)
272 {
273   tflite::reference_ops::Tanh(input_data_shape, input_data, output_data_shape, output_data);
274 }
275
276 void sigmoid(const tflite::RuntimeShape &data_shape, float *data)
277 {
278   tflite::reference_ops::Logistic(data_shape, data, data_shape, data);
279 }
280
281 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
282               float *vector)
283 {
284   for (int i = 0; i < v_size; i++)
285   {
286     vector[i] =
287       std::max(std::min(cell_state_info->cell_clip, vector[i]), -cell_state_info->cell_clip);
288   }
289 }
290 #endif // DIS_FLOAT
291
// Size information about the LSTM kernel, deduced from the shapes of the
// tensors stored in the flat buffer file (see evalLSTM). Every member has a
// default so a default-constructed instance is never read uninitialized.
struct LstmSizeInfo
{
  // true: input laid out [time, batch, input]; false: [batch, time, input].
  bool time_major = false;
  int32_t batch_size = 0;
  int32_t time_steps = 0;
  int32_t input_dimension = 0;
  // Width of the hidden/cell state for one batch entry.
  int32_t state_dimension = 0;
};
302
303 class LstmStepManager
304 {
305 public:
306   LstmStepManager() = delete;
307   // Does not take any ownership, and all pointers must refer to valid objects
308   // that outlive the one constructed.
309   explicit LstmStepManager(const LstmSizeInfo &size_info) : size_info_(size_info) {}
310
311   void updateTime()
312   {
313     current_time_ += 1;
314     // default as one batch per inference
315     int input_step = size_info_.input_dimension;
316     int output_step = size_info_.state_dimension;
317     // time major: batch inference
318     if (size_info_.time_major)
319     {
320       input_step = input_step * size_info_.batch_size;
321       output_step = output_step * size_info_.batch_size;
322     }
323
324     input_offset_ += input_step;
325     output_offset_ += output_step;
326   }
327
328   void updateBatch()
329   {
330     current_batch_ += 1;
331     TFLITE_DCHECK_LE(current_batch_, size_info_.batch_size);
332     // batch inference for time major: no action needed
333     if (size_info_.time_major)
334     {
335       return;
336     }
337     // otherwise: singe batch inference, go to the next batch
338     hidden_state_offset_ += size_info_.state_dimension;
339     cell_state_offset_ += size_info_.state_dimension;
340   }
341
342   void resetTime() { current_time_ = 0; }
343
344   tflite::RuntimeShape inputShape() const
345   {
346     int batch_size = 1;
347     if (size_info_.time_major)
348     {
349       batch_size = size_info_.batch_size;
350     }
351     const int dims[2] = {batch_size, size_info_.input_dimension};
352     const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
353     return tflite::RuntimeShape(2, dims_data);
354   }
355
356   tflite::RuntimeShape stateShape() const
357   {
358     int batch_size = 1;
359     if (size_info_.time_major)
360     {
361       batch_size = size_info_.batch_size;
362     }
363     const int dims[2] = {batch_size, size_info_.state_dimension};
364     const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
365     return tflite::RuntimeShape(2, dims_data);
366   }
367
368   int inputOffset() const { return input_offset_; }
369
370   int outputOffset() const { return output_offset_; }
371
372   int hiddenStateOffset() const { return hidden_state_offset_; }
373
374   int cellStateOffset() const { return cell_state_offset_; }
375
376 private:
377   int32_t current_time_ = 0;
378   int32_t current_batch_ = 0;
379   int32_t input_offset_ = 0;
380   int32_t output_offset_ = 0;
381   int32_t hidden_state_offset_ = 0;
382   int32_t cell_state_offset_ = 0;
383
384   const LstmSizeInfo &size_info_;
385 };
386
387 // Calculates a single LSTM gate.
388 // Implements the following formula:
389 //   gate = activate(FC(input) + FC(recurrent))
390 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
391 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
392 void calculateLstmGate(const LstmStepManager *step_info,
393                        const luci_interpreter::lstm::GateParameters *gate_params,
394                        // Input FC
395                        ActivationType *input_data, const circle::Tensor *input_weight,
396                        const circle::Tensor *input_bias,
397                        // Recurrent FC
398                        ActivationType *recurrent_data, const circle::Tensor *recurrent_weight,
399                        const circle::Tensor *recurrent_bias,
400                        // Output
401                        CellType *gate_output,
402                        // Scratch arrays
403                        CellType *fc_output_buffer, const TfLiteFusedActivation activation,
404                        luci_interpreter::BaseRuntimeGraph *runtime_graph)
405 {
406   // Input FC
407   const auto gate_output_shape = step_info->stateShape();
408   {
409     tflite::FullyConnectedParams op_params{};
410     op_params.input_offset = gate_params->input_fc_params.input_offset;
411     op_params.weights_offset = gate_params->input_fc_params.weights_offset;
412     op_params.output_offset = gate_params->input_fc_params.output_offset;
413     op_params.output_multiplier = gate_params->input_fc_params.output_multiplier;
414     op_params.output_shift = gate_params->input_fc_params.output_shift;
415     op_params.quantized_activation_min = gate_params->input_fc_params.quantized_activation_min;
416     op_params.quantized_activation_max = gate_params->input_fc_params.quantized_activation_max;
417     op_params.float_activation_max = gate_params->input_fc_params.float_activation_max;
418     op_params.float_activation_min = gate_params->input_fc_params.float_activation_min;
419
420     fullyConnected(op_params, step_info->inputShape(), input_data + step_info->inputOffset(),
421                    luci_interpreter::kernels::getTensorShape(input_weight),
422                    luci_interpreter::kernels::getTensorData<WeightType>(
423                      runtime_graph->getConstDataByTensor(input_weight)),
424                    luci_interpreter::kernels::getTensorShape(input_bias),
425                    luci_interpreter::kernels::getTensorData<BiasType>(
426                      runtime_graph->getConstDataByTensor(input_bias)),
427                    gate_output_shape, gate_output);
428   }
429
430   // Recurrent FC
431   {
432     tflite::FullyConnectedParams op_params{};
433     op_params.input_offset = gate_params->recurrent_fc_params.input_offset;
434     op_params.weights_offset = gate_params->recurrent_fc_params.weights_offset;
435     op_params.output_offset = gate_params->recurrent_fc_params.output_offset;
436     op_params.output_multiplier = gate_params->recurrent_fc_params.output_multiplier;
437     op_params.output_shift = gate_params->recurrent_fc_params.output_shift;
438     op_params.quantized_activation_min = gate_params->recurrent_fc_params.quantized_activation_min;
439     op_params.quantized_activation_max = gate_params->recurrent_fc_params.quantized_activation_max;
440     op_params.float_activation_max = gate_params->recurrent_fc_params.float_activation_max;
441     op_params.float_activation_min = gate_params->recurrent_fc_params.float_activation_min;
442
443     fullyConnected(op_params, step_info->stateShape(),
444                    recurrent_data + step_info->hiddenStateOffset(),
445                    luci_interpreter::kernels::getTensorShape(recurrent_weight),
446                    luci_interpreter::kernels::getTensorData<WeightType>(
447                      runtime_graph->getConstDataByTensor(recurrent_weight)),
448                    luci_interpreter::kernels::getTensorShape(recurrent_bias),
449                    luci_interpreter::kernels::getTensorData<BiasType>(
450                      runtime_graph->getConstDataByTensor(recurrent_bias)),
451                    gate_output_shape, fc_output_buffer);
452
453     addElementWise(gate_output, fc_output_buffer, /*n_batch=*/gate_output_shape.DimsData()[0],
454                    /*n_state=*/gate_output_shape.DimsData()[1], gate_output);
455
456     switch (activation)
457     {
458       case TfLiteFusedActivation::kTfLiteActSigmoid:
459         sigmoid(gate_output_shape, gate_output);
460         break;
461       case TfLiteFusedActivation::kTfLiteActTanh:
462       {
463         // Set the scale power to -12 to avoid shift
464         tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output, gate_output_shape,
465              gate_output);
466       }
467       break;
468       default:
469         // Only Sigmoid or Tanh is used.
470         assert(false && "Only Sigmoid or Tanh is used");
471     }
472   }
473 }
474
475 // Update the hidden state of the LSTM kernel using the following formula:
476 // updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
477 // element wise multiplication
478 template <typename CellType, typename ActivationType>
479 void updateLstmHidden(const LstmStepManager *step_info, CellType *cell_state_data_base,
480                       ActivationType *hidden_state_data, const CellType *output_gate_output,
481                       const luci_interpreter::lstm::ArithmeticParams *mul_params,
482                       int32_t cell_state_scale_power, CellType *buffer)
483 {
484   auto cell_state_shape = step_info->stateShape();
485   CellType *cell_state_data = cell_state_data_base + step_info->cellStateOffset();
486   // Tanh(cell_state)
487   tanh(cell_state_scale_power, cell_state_shape, cell_state_data, cell_state_shape, buffer);
488   // Update the hidden state
489   mul(cell_state_shape, mul_params, buffer, output_gate_output,
490       hidden_state_data + step_info->hiddenStateOffset());
491 }
492
493 // Update the cell state using the output from the forget gate, input gate, and
494 // cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
495 // input_gate_output * cell_gate_output, where * denotes element wise
496 // multiplication
497 template <typename CellType>
498 void updateLstmCell(const LstmStepManager *step_info, CellType *cell_state_data,
499                     // Gate outputs
500                     CellType *forget_gate_output, const CellType *input_gate_output,
501                     const CellType *cell_gate_output,
502                     // Mul parameters
503                     const luci_interpreter::lstm::ArithmeticParams &forget_cell_mul_params,
504                     const luci_interpreter::lstm::ArithmeticParams &input_mul_params,
505                     const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
506 {
507   auto cell_state_shape = step_info->stateShape();
508   // Forget Gate x Cell State
509   mul(cell_state_shape, &forget_cell_mul_params, forget_gate_output,
510       cell_state_data + step_info->cellStateOffset(),
511       cell_state_data + step_info->cellStateOffset());
512   // Input Gate x Cell Gate
513   mul(cell_state_shape, &input_mul_params, input_gate_output, cell_gate_output, buffer);
514
515   // Update the cell state
516   addElementWise(cell_state_data + step_info->cellStateOffset(), buffer,
517                  /*n_batch=*/cell_state_shape.DimsData()[0],
518                  /*n_state=*/cell_state_shape.DimsData()[1],
519                  cell_state_data + step_info->cellStateOffset());
520
521   if (cell_state_info->cell_clip > 0)
522   {
523     clipping(cell_state_shape.FlatSize(), cell_state_info,
524              cell_state_data + step_info->cellStateOffset());
525   }
526 }
527
528 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
529 void lstmStep(luci_interpreter::lstm::LSTMStruct *lstm_struct,
530               luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info,
531               luci_interpreter::lstm::CellStateInfo *cell_state_info,
532               ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
533               CellType *scratch1, CellType *scratch2, CellType *scratch3,
534               luci_interpreter::BaseRuntimeGraph *runtime_graph)
535 {
536   /*Step1: Calculate gate outputs to prepare cell state update*/
537   CellType *gate_internal_buffer = scratch3;
538   CellType *forget_gate_output = scratch0;
539
540   auto input_data = luci_interpreter::kernels::getTensorData<ActivationType>(
541     runtime_graph->getDataByTensor(lstm_struct->input()));
542
543   calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
544     step_info, &lstm_params->forget_gate_parameters,
545     // Input FC
546     input_data, lstm_struct->input_to_forget_weights(), lstm_struct->forget_gate_bias(),
547     // Recurrent FC
548     output_state_data, lstm_struct->recurrent_to_forget_weights(), nullptr,
549     // Output
550     forget_gate_output, gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid,
551     runtime_graph);
552
553   // Input Gate calculation;
554   CellType *input_gate_output = scratch1;
555   calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
556     step_info, &lstm_params->input_gate_parameters,
557     // Input FC
558     input_data, lstm_struct->input_to_input_weights(), lstm_struct->input_gate_bias(),
559     // Recurrent FC
560     output_state_data, lstm_struct->recurrent_to_input_weights(),
561     /*recurrent_bias*/ nullptr,
562     // Output
563     input_gate_output,
564     // Scratch arrays
565     gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid, runtime_graph);
566
567   // Cell Gate calculation
568   CellType *cell_gate_output = scratch2;
569   calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
570     step_info, &lstm_params->cell_gate_parameters,
571     // Input FC
572     input_data, lstm_struct->input_to_cell_weights(), lstm_struct->cell_gate_bias(),
573     // Recurrent FC
574     output_state_data, lstm_struct->recurrent_to_cell_weights(),
575     /*recurrent_bias*/ nullptr,
576     // Output
577     cell_gate_output,
578     // Scratch arrays
579     gate_internal_buffer, TfLiteFusedActivation::kTfLiteActTanh, runtime_graph);
580
581   /*Step2: update the cell state */
582   {
583     // const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
584     CellType *updated_input_buffer = scratch1; // reuse buffer
585
586     updateLstmCell<CellType>(
587       step_info, cell_state_data, forget_gate_output, input_gate_output, cell_gate_output,
588       lstm_params->inter_gate_parameters.forget_cell_mul_params,
589       lstm_params->inter_gate_parameters.input_mul_params, cell_state_info, updated_input_buffer);
590   }
591
592   {
593     /*Step3: update the hidden state */
594     CellType *output_gate_output = scratch1; // reuse buffer
595     calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
596       step_info, &lstm_params->output_gate_parameters,
597       // Input FC
598       input_data, lstm_struct->input_to_output_weights(), lstm_struct->output_gate_bias(),
599       // Recurrent FC
600       output_state_data, lstm_struct->recurrent_to_output_weights(), nullptr,
601       // Output
602       output_gate_output,
603       // Scratch arrays
604       gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid, runtime_graph);
605     CellType *tanh_activated_cell_buffer = scratch0; // reuse buffer
606     updateLstmHidden<CellType, ActivationType>(
607       step_info, cell_state_data, output_state_data, output_gate_output,
608       &lstm_params->inter_gate_parameters.output_mul_params,
609       cell_state_info->cell_state_scale_power, tanh_activated_cell_buffer);
610
611     ActivationType *output_ptr = luci_interpreter::kernels::getTensorData<ActivationType>(
612       runtime_graph->getDataByTensor(lstm_struct->output()));
613     std::memcpy(output_ptr + step_info->outputOffset(),
614                 output_state_data + step_info->hiddenStateOffset(),
615                 step_info->stateShape().FlatSize() * sizeof(ActivationType));
616   }
617 }
618
619 } // namespace lstm_internal
620
621 // Evaluate the LSTM kernel with (potential) multi-steps and multi-batch input
622 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
623 void evalLSTM(luci_interpreter::lstm::LSTMStruct *lstm_struct,
624               luci_interpreter::lstm::LSTMParameters *lstm_params,
625               luci_interpreter::lstm::CellStateInfo *cell_state_info,
626               ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
627               CellType *scratch1, CellType *scratch2, CellType *scratch3,
628               luci_interpreter::BaseRuntimeGraph *runtime_graph)
629 {
630   lstm_internal::LstmSizeInfo size_info;
631
632   size_info.time_major = lstm_struct->options->time_major();
633   size_info.batch_size = size_info.time_major
634                            ? luci_interpreter::Tensor::dim(lstm_struct->input(), 1)
635                            : luci_interpreter::Tensor::dim(lstm_struct->input(), 0);
636   size_info.time_steps = size_info.time_major
637                            ? luci_interpreter::Tensor::dim(lstm_struct->input(), 0)
638                            : luci_interpreter::Tensor::dim(lstm_struct->input(), 1);
639   size_info.input_dimension = luci_interpreter::Tensor::dim(lstm_struct->input(), 2);
640   size_info.state_dimension = luci_interpreter::Tensor::dim(lstm_struct->output_state(), 1);
641
642   lstm_internal::LstmStepManager step_info(size_info);
643
644   // time is the first dimention, enable batch computation
645   if (size_info.time_major)
646   {
647     for (int t = 0; t < size_info.time_steps; t++)
648     {
649       lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
650         lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
651         scratch0, scratch1, scratch2, scratch3, runtime_graph);
652       // prepare for the next time step
653       step_info.updateTime();
654     }
655   }
656   else
657   {
658     // batch first, unable to size the input data. single batch inference
659     for (int b = 0; b < size_info.batch_size; b++)
660     {
661       for (int t = 0; t < size_info.time_steps; t++)
662       {
663         lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
664           lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
665           scratch0, scratch1, scratch2, scratch3, runtime_graph);
666         // prepare for the next time step
667         step_info.updateTime();
668       }
669       // prepare for the next batch
670       step_info.updateBatch();
671       step_info.resetTime();
672     }
673   }
674 }
675
676 } // namespace luci_interpreter_pal
677
678 #endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H