onert-micro/luci-interpreter/pal/common/PALUnidirectionalSequenceLSTMCommon.h
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H
#define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H

#include "kernels/UnidirectionalSequenceLSTM.h"
#include "PALTanh.h"
#include "PALLogistic.h"
#include "PALFullyConnected.h"
#include "PALMul.h"
#include "PALUtils.h"

#include <algorithm> // std::min, std::max
#include <cassert>   // assert
#include <cstring>   // std::memcpy
#include <limits>    // std::numeric_limits

namespace luci_interpreter_pal
{
namespace lstm_internal
{
namespace
{
// Possible fused activation functions.
typedef enum
{
  kTfLiteActNone = 0,
  kTfLiteActRelu,
  kTfLiteActReluN1To1, // min(max(-1, x), 1)
  kTfLiteActRelu6,     // min(max(0, x), 6)
  kTfLiteActTanh,
  kTfLiteActSignBit,
  kTfLiteActSigmoid,
} FusedActivation;

} // namespace

#ifndef DIS_QUANT

template <typename InputType, typename OutputType>
void mulElementwise(int size, const ArithmeticParams *params, const InputType *input1_data,
                    const InputType *input2_data, OutputType *output_data)
{
  for (int i = 0; i < size; ++i)
  {
    const int32_t input1_val = params->input1_offset + input1_data[i];
    const int32_t input2_val = params->input2_offset + input2_data[i];
    const int32_t unclamped_result =
      params->output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
                                                            params->output_multiplier,
                                                            params->output_shift);
    const int32_t clamped_output =
      std::min(params->quantized_activation_max,
               std::max(params->quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<OutputType>(clamped_output);
  }
}

// Input and output have the same shape in LSTM
void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
         const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
{
  return mulElementwise<int16_t, int8_t>(shape.flatSize(), params, input1_data, input2_data,
                                         output_data);
}

// Input and output have the same shape in LSTM
void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
         const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
{
  return mulElementwise(shape.flatSize(), params, input1_data, input2_data, output_data);
}
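
// The two int16 overloads above differ only in the output type: the
// int16 * int16 -> int16 variant serves the cell-state update, while the
// int16 * int16 -> int8 variant lets updateLstmHidden() write the int8
// activation type directly. Both rescale the raw product with
// output_multiplier/output_shift and clamp it to the quantized activation range.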

// Saturating element-wise add: the int32 sum is clamped to the int16 range
// before being written back.
void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input,
                    int16_t *output)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      int32_t sum = input_1[index] + input_2[index];
      const int32_t sum_clamped =
        std::min(static_cast<int32_t>(std::numeric_limits<int16_t>::max()),
                 std::max(static_cast<int32_t>(std::numeric_limits<int16_t>::min()), sum));
      output[index] = static_cast<int16_t>(sum_clamped);
    }
  }
}

// Rescale the int16 cell state (scale 2^cell_state_scale_power) to the Q3.12
// format expected by the PAL Tanh.
void tanh(int32_t cell_state_scale_power, const luci_interpreter::RuntimeShape &input_data_shape,
          int16_t *input_data, const luci_interpreter::RuntimeShape &output_data_shape,
          int16_t *output_data)
{
  int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
  int32_t input_multiplier = 0;
  if (tanh_input_left_shift < 0) /* handling negative shift value */
  {
    tanh_input_left_shift = -tanh_input_left_shift;
    input_multiplier = 3;
  }
  const int flat_size = input_data_shape.flatSize();
  luci_interpreter_pal::Tanh(input_multiplier, tanh_input_left_shift, flat_size, input_data,
                             output_data);
}
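
// Worked example for the shift computation above: with a cell state of scale
// 2^-12 (cell_state_scale_power == -12), the required left shift is
// (15 + (-12)) - 3 == 0, so the data already matches the Q3.12 input format.
// For a coarser scale such as 2^-10 the shift is (15 + (-10)) - 3 == 2; a
// finer scale like 2^-14 yields -2, in which case the code flips the shift's
// sign and signals this to the PAL Tanh via input_multiplier == 3.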

void sigmoid(const luci_interpreter::RuntimeShape &data_shape, int16_t *data)
{
  luci_interpreter_pal::Logistic(0, 0, data_shape.flatSize(), data, data);
}

void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
              int16_t *vector)
{
  for (int i = 0; i < v_size; i++)
  {
    vector[i] = std::max(std::min(cell_state_info->quantized_cell_clip, vector[i]),
                         static_cast<int16_t>(-cell_state_info->quantized_cell_clip));
  }
}
#endif // DIS_QUANT

#ifndef DIS_FLOAT
// Input and output have the same shape in LSTM
void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
         const float *input1_data, const float *input2_data, float *output_data)
{
  const int flat_size = shape.flatSize();
  return luci_interpreter_pal::Mul(*params, flat_size, input1_data, input2_data, output_data);
}

void addElementWise(const float *input_1, const float *input_2, int n_batch, int n_input,
                    float *output)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      output[index] = input_1[index] + input_2[index];
    }
  }
}

void tanh(int32_t, const luci_interpreter::RuntimeShape &input_data_shape, float *input_data,
          const luci_interpreter::RuntimeShape &output_data_shape, float *output_data)
{
  const int flat_size = input_data_shape.flatSize();
  luci_interpreter_pal::Tanh(flat_size, input_data, output_data);
}

void sigmoid(const luci_interpreter::RuntimeShape &data_shape, float *data)
{
  const int flat_size = data_shape.flatSize();
  luci_interpreter_pal::Logistic(flat_size, data, data);
}

void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
              float *vector)
{
  for (int i = 0; i < v_size; i++)
  {
    vector[i] =
      std::max(std::min(cell_state_info->cell_clip, vector[i]), -cell_state_info->cell_clip);
  }
}
#endif // DIS_FLOAT

// Size information about the LSTM kernel, which is deduced from tensors stored
// in the flat buffer file.
struct LstmSizeInfo
{
  bool time_major;
  int32_t batch_size;
  int32_t time_steps;
  int32_t input_dimension;
  int32_t state_dimension;
};
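
// For example, a batch-major input tensor of shape
// [batch_size, time_steps, input_dimension] = [2, 3, 4] with an output state
// of shape [batch_size, state_dimension] = [2, 5] yields
// {time_major = false, batch_size = 2, time_steps = 3, input_dimension = 4,
// state_dimension = 5}. With time_major == true the first two input
// dimensions are swapped to [time_steps, batch_size, input_dimension].
// evalLSTM() below fills this struct from the input and output_state tensors.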

class LstmStepManager
{
public:
  LstmStepManager() = delete;
  // Does not take any ownership, and all pointers must refer to valid objects
  // that outlive the one constructed.
  explicit LstmStepManager(const LstmSizeInfo &size_info) : size_info_(size_info) {}

  void updateTime()
  {
    current_time_ += 1;
    // default as one batch per inference
    int input_step = size_info_.input_dimension;
    int output_step = size_info_.state_dimension;
    // time major: batch inference
    if (size_info_.time_major)
    {
      input_step = input_step * size_info_.batch_size;
      output_step = output_step * size_info_.batch_size;
    }

    input_offset_ += input_step;
    output_offset_ += output_step;
  }

  void updateBatch()
  {
    current_batch_ += 1;
    // batch inference for time major: no action needed
    if (size_info_.time_major)
    {
      return;
    }
    // otherwise: single-batch inference, go to the next batch
    hidden_state_offset_ += size_info_.state_dimension;
    cell_state_offset_ += size_info_.state_dimension;
  }

  void resetTime() { current_time_ = 0; }

  luci_interpreter::RuntimeShape inputShape() const
  {
    int batch_size = 1;
    if (size_info_.time_major)
    {
      batch_size = size_info_.batch_size;
    }
    const int dims[2] = {batch_size, size_info_.input_dimension};
    const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
    return luci_interpreter::RuntimeShape(2, dims_data);
  }

  luci_interpreter::RuntimeShape stateShape() const
  {
    int batch_size = 1;
    if (size_info_.time_major)
    {
      batch_size = size_info_.batch_size;
    }
    const int dims[2] = {batch_size, size_info_.state_dimension};
    const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
    return luci_interpreter::RuntimeShape(2, dims_data);
  }

  int inputOffset() const { return input_offset_; }

  int outputOffset() const { return output_offset_; }

  int hiddenStateOffset() const { return hidden_state_offset_; }

  int cellStateOffset() const { return cell_state_offset_; }

private:
  int32_t current_time_ = 0;
  int32_t current_batch_ = 0;
  int32_t input_offset_ = 0;
  int32_t output_offset_ = 0;
  int32_t hidden_state_offset_ = 0;
  int32_t cell_state_offset_ = 0;

  const LstmSizeInfo &size_info_;
};
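
// Offset bookkeeping example, using the sizes from above (batch-major,
// batch_size = 2, input_dimension = 4, state_dimension = 5): each
// updateTime() advances input_offset_ by 4 and output_offset_ by 5, and each
// updateBatch() advances hidden_state_offset_ and cell_state_offset_ by 5.
// In time-major mode updateTime() instead advances the input/output offsets
// by batch_size * dimension (8 and 10 here), and the state offsets stay at 0
// because every time step already covers all batches.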

// Calculates a single LSTM gate.
// Implements the following formula:
//   gate = activate(FC(input) + FC(recurrent))
// Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
void calculateLstmGate(const LstmStepManager *step_info,
                       const luci_interpreter::lstm::GateParameters *gate_params,
                       // Input FC
                       ActivationType *input_data, const circle::Tensor *input_weight,
                       const circle::Tensor *input_bias,
                       // Recurrent FC
                       ActivationType *recurrent_data, const circle::Tensor *recurrent_weight,
                       const circle::Tensor *recurrent_bias,
                       // Output
                       CellType *gate_output,
                       // Scratch arrays
                       CellType *fc_output_buffer, const FusedActivation activation,
                       luci_interpreter::BaseRuntimeGraph *runtime_graph)
{
  // Input FC
  const auto gate_output_shape = step_info->stateShape();
  {
    FullyConnectedParams op_params{};
    op_params.input_offset = gate_params->input_fc_params.input_offset;
    op_params.weights_offset = gate_params->input_fc_params.weights_offset;
    op_params.output_offset = gate_params->input_fc_params.output_offset;
    op_params.output_multiplier = gate_params->input_fc_params.output_multiplier;
    op_params.output_shift = gate_params->input_fc_params.output_shift;
    op_params.quantized_activation_min = gate_params->input_fc_params.quantized_activation_min;
    op_params.quantized_activation_max = gate_params->input_fc_params.quantized_activation_max;
    op_params.float_activation_max = gate_params->input_fc_params.float_activation_max;
    op_params.float_activation_min = gate_params->input_fc_params.float_activation_min;

    int32_t input_weight_shape[luci_interpreter::kMaxSmallSize];
    luci_interpreter::kernels::getTensorDims(input_weight, runtime_graph, input_weight_shape);

    FullyConnected(op_params, step_info->inputShape().dimsData(),
                   input_data + step_info->inputOffset(), input_weight_shape,
                   luci_interpreter::kernels::getTensorData<WeightType>(
                     runtime_graph->getConstDataByTensor(input_weight)),
                   luci_interpreter::kernels::getTensorData<BiasType>(
                     runtime_graph->getConstDataByTensor(input_bias)),
                   gate_output_shape.dimsData(), gate_output);
  }

  // Recurrent FC
  {
    FullyConnectedParams op_params{};
    op_params.input_offset = gate_params->recurrent_fc_params.input_offset;
    op_params.weights_offset = gate_params->recurrent_fc_params.weights_offset;
    op_params.output_offset = gate_params->recurrent_fc_params.output_offset;
    op_params.output_multiplier = gate_params->recurrent_fc_params.output_multiplier;
    op_params.output_shift = gate_params->recurrent_fc_params.output_shift;
    op_params.quantized_activation_min = gate_params->recurrent_fc_params.quantized_activation_min;
    op_params.quantized_activation_max = gate_params->recurrent_fc_params.quantized_activation_max;
    op_params.float_activation_max = gate_params->recurrent_fc_params.float_activation_max;
    op_params.float_activation_min = gate_params->recurrent_fc_params.float_activation_min;

    int32_t recurrent_weight_shape[luci_interpreter::kMaxSmallSize];
    luci_interpreter::kernels::getTensorDims(recurrent_weight, runtime_graph,
                                             recurrent_weight_shape);

    FullyConnected(op_params, step_info->stateShape().dimsData(),
                   recurrent_data + step_info->hiddenStateOffset(), recurrent_weight_shape,
                   luci_interpreter::kernels::getTensorData<WeightType>(
                     runtime_graph->getConstDataByTensor(recurrent_weight)),
                   luci_interpreter::kernels::getTensorData<BiasType>(
                     runtime_graph->getConstDataByTensor(recurrent_bias)),
                   gate_output_shape.dimsData(), fc_output_buffer);

    addElementWise(gate_output, fc_output_buffer, /*n_batch=*/gate_output_shape.dimsData()[0],
                   /*n_state=*/gate_output_shape.dimsData()[1], gate_output);

    switch (activation)
    {
      case FusedActivation::kTfLiteActSigmoid:
        sigmoid(gate_output_shape, gate_output);
        break;
      case FusedActivation::kTfLiteActTanh:
      {
        // Set the scale power to -12 to avoid shift
        tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output, gate_output_shape,
             gate_output);
      }
      break;
      default:
        // Only Sigmoid or Tanh is used.
        assert(false && "Only Sigmoid or Tanh is used");
    }
  }
}
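
// A minimal sketch of a gate invocation (the real call sites are in
// lstmStep() below; the names mirror the ones used there):
//
//   calculateLstmGate<float, float, float, float>(
//     &step_info, &lstm_params->forget_gate_parameters,
//     /* Input FC */ input_data, lstm_struct->input_to_forget_weights(),
//     lstm_struct->forget_gate_bias(),
//     /* Recurrent FC */ output_state_data,
//     lstm_struct->recurrent_to_forget_weights(), /*recurrent_bias=*/nullptr,
//     /* Output */ forget_gate_output,
//     /* Scratch */ gate_internal_buffer,
//     FusedActivation::kTfLiteActSigmoid, runtime_graph);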

// Update the hidden state of the LSTM kernel using the following formula:
//   updated_hidden_state = Tanh(updated_cell_state) * output_gate_output,
// where '*' denotes element-wise multiplication
template <typename CellType, typename ActivationType>
void updateLstmHidden(const LstmStepManager *step_info, CellType *cell_state_data_base,
                      ActivationType *hidden_state_data, const CellType *output_gate_output,
                      const ArithmeticParams *mul_params, int32_t cell_state_scale_power,
                      CellType *buffer)
{
  auto cell_state_shape = step_info->stateShape();
  CellType *cell_state_data = cell_state_data_base + step_info->cellStateOffset();
  // Tanh(cell_state)
  tanh(cell_state_scale_power, cell_state_shape, cell_state_data, cell_state_shape, buffer);
  // Update the hidden state
  mul(cell_state_shape, mul_params, buffer, output_gate_output,
      hidden_state_data + step_info->hiddenStateOffset());
}

// Update the cell state using the outputs of the forget, input, and cell gates.
// Formula: updated_cell_state = forget_gate_output * cell_state +
//          input_gate_output * cell_gate_output,
// where '*' denotes element-wise multiplication
template <typename CellType>
void updateLstmCell(const LstmStepManager *step_info, CellType *cell_state_data,
                    // Gate outputs
                    CellType *forget_gate_output, const CellType *input_gate_output,
                    const CellType *cell_gate_output,
                    // Mul parameters
                    const ArithmeticParams &forget_cell_mul_params,
                    const ArithmeticParams &input_mul_params,
                    const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
{
  auto cell_state_shape = step_info->stateShape();
  // Forget Gate x Cell State
  mul(cell_state_shape, &forget_cell_mul_params, forget_gate_output,
      cell_state_data + step_info->cellStateOffset(),
      cell_state_data + step_info->cellStateOffset());
  // Input Gate x Cell Gate
  mul(cell_state_shape, &input_mul_params, input_gate_output, cell_gate_output, buffer);

  // Update the cell state
  addElementWise(cell_state_data + step_info->cellStateOffset(), buffer,
                 /*n_batch=*/cell_state_shape.dimsData()[0],
                 /*n_state=*/cell_state_shape.dimsData()[1],
                 cell_state_data + step_info->cellStateOffset());

  if (cell_state_info->cell_clip > 0)
  {
    clipping(cell_state_shape.flatSize(), cell_state_info,
             cell_state_data + step_info->cellStateOffset());
  }
}

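// Scratch buffer usage inside lstmStep() (derived from the reuse below):
//   scratch0 - forget gate output, later reused for Tanh(cell_state)
//   scratch1 - input gate output, later reused for the cell-update buffer
//              and the output gate output
//   scratch2 - cell gate output
//   scratch3 - FC scratch shared by all four gate computations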
template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
void lstmStep(luci_interpreter::lstm::LSTMStruct *lstm_struct,
              luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info,
              luci_interpreter::lstm::CellStateInfo *cell_state_info,
              ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
              CellType *scratch1, CellType *scratch2, CellType *scratch3,
              luci_interpreter::BaseRuntimeGraph *runtime_graph)
{
  /* Step 1: calculate gate outputs to prepare the cell state update */
  CellType *gate_internal_buffer = scratch3;
  CellType *forget_gate_output = scratch0;

  auto input_data = luci_interpreter::kernels::getTensorData<ActivationType>(
    runtime_graph->getDataByTensor(lstm_struct->input()));

  calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
    step_info, &lstm_params->forget_gate_parameters,
    // Input FC
    input_data, lstm_struct->input_to_forget_weights(), lstm_struct->forget_gate_bias(),
    // Recurrent FC
    output_state_data, lstm_struct->recurrent_to_forget_weights(), nullptr,
    // Output
    forget_gate_output, gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);

  // Input Gate calculation
  CellType *input_gate_output = scratch1;
  calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
    step_info, &lstm_params->input_gate_parameters,
    // Input FC
    input_data, lstm_struct->input_to_input_weights(), lstm_struct->input_gate_bias(),
    // Recurrent FC
    output_state_data, lstm_struct->recurrent_to_input_weights(),
    /*recurrent_bias*/ nullptr,
    // Output
    input_gate_output,
    // Scratch arrays
    gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);

  // Cell Gate calculation
  CellType *cell_gate_output = scratch2;
  calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
    step_info, &lstm_params->cell_gate_parameters,
    // Input FC
    input_data, lstm_struct->input_to_cell_weights(), lstm_struct->cell_gate_bias(),
    // Recurrent FC
    output_state_data, lstm_struct->recurrent_to_cell_weights(),
    /*recurrent_bias*/ nullptr,
    // Output
    cell_gate_output,
    // Scratch arrays
    gate_internal_buffer, FusedActivation::kTfLiteActTanh, runtime_graph);

  /* Step 2: update the cell state */
  {
    CellType *updated_input_buffer = scratch1; // reuse buffer

    updateLstmCell<CellType>(
      step_info, cell_state_data, forget_gate_output, input_gate_output, cell_gate_output,
      lstm_params->inter_gate_parameters.forget_cell_mul_params,
      lstm_params->inter_gate_parameters.input_mul_params, cell_state_info, updated_input_buffer);
  }

  /* Step 3: update the hidden state */
  {
    CellType *output_gate_output = scratch1; // reuse buffer
    calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
      step_info, &lstm_params->output_gate_parameters,
      // Input FC
      input_data, lstm_struct->input_to_output_weights(), lstm_struct->output_gate_bias(),
      // Recurrent FC
      output_state_data, lstm_struct->recurrent_to_output_weights(), nullptr,
      // Output
      output_gate_output,
      // Scratch arrays
      gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
    CellType *tanh_activated_cell_buffer = scratch0; // reuse buffer
    updateLstmHidden<CellType, ActivationType>(
      step_info, cell_state_data, output_state_data, output_gate_output,
      &lstm_params->inter_gate_parameters.output_mul_params,
      cell_state_info->cell_state_scale_power, tanh_activated_cell_buffer);

    ActivationType *output_ptr = luci_interpreter::kernels::getTensorData<ActivationType>(
      runtime_graph->getDataByTensor(lstm_struct->output()));
    std::memcpy(output_ptr + step_info->outputOffset(),
                output_state_data + step_info->hiddenStateOffset(),
                step_info->stateShape().flatSize() * sizeof(ActivationType));
  }
}

} // namespace lstm_internal

// Evaluate the LSTM kernel over (potentially) multiple time steps and batches
template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
void evalLSTM(luci_interpreter::lstm::LSTMStruct *lstm_struct,
              luci_interpreter::lstm::LSTMParameters *lstm_params,
              luci_interpreter::lstm::CellStateInfo *cell_state_info,
              ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
              CellType *scratch1, CellType *scratch2, CellType *scratch3,
              luci_interpreter::BaseRuntimeGraph *runtime_graph)
{
  lstm_internal::LstmSizeInfo size_info;

  size_info.time_major = lstm_struct->options->time_major();
  size_info.batch_size = size_info.time_major
                           ? luci_interpreter::Tensor::dim(lstm_struct->input(), 1)
                           : luci_interpreter::Tensor::dim(lstm_struct->input(), 0);
  size_info.time_steps = size_info.time_major
                           ? luci_interpreter::Tensor::dim(lstm_struct->input(), 0)
                           : luci_interpreter::Tensor::dim(lstm_struct->input(), 1);
  size_info.input_dimension = luci_interpreter::Tensor::dim(lstm_struct->input(), 2);
  size_info.state_dimension = luci_interpreter::Tensor::dim(lstm_struct->output_state(), 1);

  lstm_internal::LstmStepManager step_info(size_info);

  // time is the first dimension, so the whole batch can be computed per time step
  if (size_info.time_major)
  {
    for (int t = 0; t < size_info.time_steps; t++)
    {
      lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
        lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
        scratch0, scratch1, scratch2, scratch3, runtime_graph);
      // prepare for the next time step
      step_info.updateTime();
    }
  }
  else
  {
    // batch is the first dimension, so batched computation is not possible;
    // run single-batch inference, one batch at a time
    for (int b = 0; b < size_info.batch_size; b++)
    {
      for (int t = 0; t < size_info.time_steps; t++)
      {
        lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
          lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
          scratch0, scratch1, scratch2, scratch3, runtime_graph);
        // prepare for the next time step
        step_info.updateTime();
      }
      // prepare for the next batch
      step_info.updateBatch();
      step_info.resetTime();
    }
  }
}
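
// Typical instantiations (a sketch; the exact template arguments are chosen
// by the UnidirectionalSequenceLSTM kernel that includes this header):
//   float LSTM:                          evalLSTM<float, float, float, float>(...)
//   int8 LSTM with an int16 cell state:  evalLSTM<int8_t, int8_t, int16_t, int32_t>(...)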

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H