2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19 #define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H
21 #include "kernels/UnidirectionalSequenceLSTM.h"
22 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
23 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
24 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
25 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
26 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
28 namespace luci_interpreter_pal
30 namespace lstm_internal
// Fused activation kinds. calculateLstmGate switches on this enum and only
// handles kTfLiteActSigmoid and kTfLiteActTanh; any other value reaches its
// assertion.
34 // Possible fused activation functions.
39 kTfLiteActReluN1To1, // min(max(-1, x), 1)
40 kTfLiteActRelu6, // min(max(0, x), 6)
44 } TfLiteFusedActivation;
// Clamps x to [output_activation_min, output_activation_max]. The upper bound
// is applied last, so if min > max the result is always output_activation_max.
49 inline T activationFunctionWithMinMax(T x, T output_activation_min, T output_activation_max)
53 return min(max(x, output_activation_min), output_activation_max);
// Element-wise product with clamping:
//   output[i] = clamp(input1[i] * input2[i])
// The bounds are params->quantized_activation_min/max converted to T.
// The element count comes from input1_shape.FlatSize(). input2_shape and
// output_shape are not consulted, so all three buffers must hold at least
// that many elements.
// NOTE(review): the float mul() overload forwards here, so float models are
// clamped with the *quantized* activation bounds. Confirm those fields hold
// suitable (e.g. non-restrictive) values on the float path.
57 inline void mul(const luci_interpreter::lstm::ArithmeticParams *params,
58 const tflite::RuntimeShape &input1_shape, const T *input1_data,
59 const tflite::RuntimeShape &input2_shape, const T *input2_data,
60 const tflite::RuntimeShape &output_shape, T *output_data)
62 T output_activation_min = params->quantized_activation_min;
63 T output_activation_max = params->quantized_activation_max;
65 const int flat_size = input1_shape.FlatSize();
66 for (int i = 0; i < flat_size; ++i)
68 output_data[i] = activationFunctionWithMinMax(input1_data[i] * input2_data[i],
69 output_activation_min, output_activation_max);
74 inline int32_t multiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift)
76 using gemmlowp::RoundingDivideByPOT;
77 using gemmlowp::SaturatingRoundingDoublingHighMul;
78 int left_shift = shift > 0 ? shift : 0;
79 int right_shift = shift > 0 ? 0 : -shift;
80 return RoundingDivideByPOT(
81 SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
// Reference integer fully-connected layer. For each batch b and output
// channel o:
//   acc = sum_d (filter[o][d] + weights_offset) * (input[b][d] + input_offset)
//         + bias[o]
//   output[b][o] = clamp(multiplyByQuantizedMultiplier(acc, output_multiplier,
//                                                      output_shift)
//                        + output_offset)
// clamp uses [quantized_activation_min, quantized_activation_max].
// The filter's last dimension is the accumulation depth. The output is
// written row-major as [batches][output_depth].
// NOTE(review): `¶ms` in the signature looks like a mis-encoded `&params`
// (the body reads `params.`). It is not a valid identifier; fix the
// declaration.
// NOTE(review): multiplyByQuantizedMultiplier takes int32_t, so with
// BiasType = int64_t (the int16-input overload) the accumulated value is
// narrowed to 32 bits before rescaling. TFLite's 16x8 reference kernels
// rescale in 64 bits. Confirm sums cannot exceed the int32 range.
// NOTE(review): lstmStep passes a nullptr recurrent bias tensor, so bias_data
// may be null here. Confirm the bias addition below is guarded.
84 template <typename InputType, typename WeightType, typename OutputType, typename BiasType>
85 void fullyConnectedInteger(const tflite::FullyConnectedParams ¶ms,
86 const tflite::RuntimeShape &input_shape, const InputType *input_data,
87 const tflite::RuntimeShape &filter_shape, const WeightType *filter_data,
88 const tflite::RuntimeShape &bias_shape, const BiasType *bias_data,
89 const tflite::RuntimeShape &output_shape, OutputType *output_data)
91 const int32_t input_offset = params.input_offset;
92 const int32_t filter_offset = params.weights_offset;
93 const int32_t output_offset = params.output_offset;
94 const int32_t output_multiplier = params.output_multiplier;
95 const int output_shift = params.output_shift;
96 const int32_t output_activation_min = params.quantized_activation_min;
97 const int32_t output_activation_max = params.quantized_activation_max;
98 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
99 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
101 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
102 const int filter_dim_count = filter_shape.DimensionsCount();
103 const int output_dim_count = output_shape.DimensionsCount();
// All output dimensions except the last are treated as the batch.
104 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
105 const int output_depth = output_shape.Dims(output_dim_count - 1);
106 TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2));
107 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
108 for (int b = 0; b < batches; ++b)
110 for (int out_c = 0; out_c < output_depth; ++out_c)
113 for (int d = 0; d < accum_depth; ++d)
115 int32_t input_val = input_data[b * accum_depth + d];
116 int32_t filter_val = filter_data[out_c * accum_depth + d];
117 acc += (filter_val + filter_offset) * (input_val + input_offset);
121 acc += bias_data[out_c];
// Requantize to the output scale, add the output zero point, then clamp.
123 int32_t acc_scaled = multiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
124 acc_scaled += output_offset;
125 acc_scaled = std::max(acc_scaled, output_activation_min);
126 acc_scaled = std::min(acc_scaled, output_activation_max);
127 output_data[out_c + output_depth * b] = static_cast<OutputType>(acc_scaled);
132 void fullyConnected(const tflite::FullyConnectedParams ¶ms,
133 const tflite::RuntimeShape &input_shape, const int8_t *input_data,
134 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
135 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
136 const tflite::RuntimeShape &output_shape, int16_t *output_data)
138 return fullyConnectedInteger(params, input_shape, input_data, filter_shape, filter_data,
139 bias_shape, bias_data, output_shape, output_data);
142 void fullyConnected(const tflite::FullyConnectedParams ¶ms,
143 const tflite::RuntimeShape &input_shape, const int16_t *input_data,
144 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
145 const tflite::RuntimeShape &bias_shape, const int64_t *bias_data,
146 const tflite::RuntimeShape &output_shape, int16_t *output_data)
148 return fullyConnectedInteger(params, input_shape, input_data, filter_shape, filter_data,
149 bias_shape, bias_data, output_shape, output_data);
152 template <typename InputType, typename OutputType>
153 void mulElementwise(int size, const luci_interpreter::lstm::ArithmeticParams *params,
154 const InputType *input1_data, const InputType *input2_data,
155 OutputType *output_data)
157 for (int i = 0; i < size; ++i)
159 const int32_t input1_val = params->input1_offset + input1_data[i];
160 const int32_t input2_val = params->input2_offset + input2_data[i];
161 const int32_t unclamped_result =
162 params->output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
163 params->output_multiplier,
164 params->output_shift);
165 const int32_t clamped_output =
166 std::min(params->quantized_activation_max,
167 std::max(params->quantized_activation_min, unclamped_result));
168 output_data[i] = static_cast<OutputType>(clamped_output);
172 // Input and output have the same shape in LSTM
173 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
174 const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
176 return mulElementwise<int16_t, int8_t>(shape.FlatSize(), params, input1_data, input2_data,
180 // Input and output have the same shape in LSTM
181 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
182 const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
184 return mulElementwise(shape.FlatSize(), params, input1_data, input2_data, output_data);
// Saturating element-wise add for int16 buffers laid out as [n_batch][n_input].
// Each sum is formed in 32 bits, so it cannot overflow before being clamped to
// the int16 range.
// Callers in this file pass an output that aliases input_1. This is safe
// because each element is read before it is written.
187 void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input,
190 for (int batch = 0; batch < n_batch; ++batch)
192 for (int i = 0; i < n_input; ++i)
194 const int index = batch * n_input + i;
195 int32_t sum = input_1[index] + input_2[index];
196 const int32_t sum_clamped =
197 std::min(static_cast<int32_t>(std::numeric_limits<int16_t>::max()),
198 std::max(static_cast<int32_t>(std::numeric_limits<int16_t>::min()), sum));
199 output[index] = static_cast<int16_t>(sum_clamped);
204 void tanh(int32_t cell_state_scale_power, const tflite::RuntimeShape &input_data_shape,
205 int16_t *input_data, const tflite::RuntimeShape &output_data_shape, int16_t *output_data)
207 int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
208 int32_t input_multiplier = 0;
209 if (tanh_input_left_shift < 0) /* handling negative shift value */
211 tanh_input_left_shift = -tanh_input_left_shift;
212 input_multiplier = 3;
214 tflite::reference_integer_ops::Tanh(input_multiplier, tanh_input_left_shift, input_data_shape,
215 input_data, output_data_shape, output_data);
218 void sigmoid(const tflite::RuntimeShape &data_shape, int16_t *data)
220 tflite::reference_integer_ops::Logistic(0 /*data->input_multiplier*/,
221 0 /*data->input_left_shift */,
222 data_shape.FlatSize() /*NumElements(input->dims)*/,
223 data /* tflite::micro::GetTensorData<int16_t>(input) */,
224 data /*tflite::micro::GetTensorData<int16_t>(output) */);
// Clips the first v_size int16 cell-state values in place to
// [-quantized_cell_clip, quantized_cell_clip].
227 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
230 for (int i = 0; i < v_size; i++)
232 vector[i] = std::max(std::min(cell_state_info->quantized_cell_clip, vector[i]),
233 static_cast<int16_t>(-cell_state_info->quantized_cell_clip));
239 void fullyConnected(const tflite::FullyConnectedParams ¶ms,
240 const tflite::RuntimeShape &input_shape, const float *input_data,
241 const tflite::RuntimeShape &filter_shape, const float *filter_data,
242 const tflite::RuntimeShape &bias_shape, const float *bias_data,
243 const tflite::RuntimeShape &output_shape, float *output_data)
245 return tflite::reference_ops::FullyConnected(params, input_shape, input_data, filter_shape,
246 filter_data, bias_shape, bias_data, output_shape,
250 // Input and output have the same shape in LSTM
251 void mul(const tflite::RuntimeShape &shape, const luci_interpreter::lstm::ArithmeticParams *params,
252 const float *input1_data, const float *input2_data, float *output_data)
254 return mul(params, shape, input1_data, shape, input2_data, shape, output_data);
// Element-wise float add for buffers laid out as [n_batch][n_input]; no
// clamping.
// Callers in this file pass an output that aliases input_1. This is safe
// because each element is read before it is written.
257 void addElementWise(const float *input_1, const float *input_2, int n_batch, int n_input,
260 for (int batch = 0; batch < n_batch; ++batch)
262 for (int i = 0; i < n_input; ++i)
264 const int index = batch * n_input + i;
265 output[index] = input_1[index] + input_2[index];
270 void tanh(int32_t cell_state_scale_power, const tflite::RuntimeShape &input_data_shape,
271 float *input_data, const tflite::RuntimeShape &output_data_shape, float *output_data)
273 tflite::reference_ops::Tanh(input_data_shape, input_data, output_data_shape, output_data);
276 void sigmoid(const tflite::RuntimeShape &data_shape, float *data)
278 tflite::reference_ops::Logistic(data_shape, data, data_shape, data);
// Clips the first v_size float cell-state values to [-cell_clip, cell_clip].
281 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
284 for (int i = 0; i < v_size; i++)
287 std::max(std::min(cell_state_info->cell_clip, vector[i]), -cell_state_info->cell_clip);
292 // Size information about the LSTM kernel, deduced from the tensors stored
293 // in the circle model flatbuffer.
// Filled in by evalLSTM: input_dimension is dim 2 of the input tensor.
299 int32_t input_dimension;
// Filled in by evalLSTM: state_dimension is dim 1 of the output_state tensor.
300 int32_t state_dimension;
// Tracks where the current time step / batch lives inside the flat input,
// output, hidden-state and cell-state buffers while evalLSTM iterates.
// Offsets are element counts, not bytes: callers add them to typed pointers.
303 class LstmStepManager
306 LstmStepManager() = delete;
307 // Does not take ownership: only a reference to size_info is stored, so
308 // size_info must outlive the constructed manager.
309 explicit LstmStepManager(const LstmSizeInfo &size_info) : size_info_(size_info) {}
// Time-step advance: the input offset moves by input_dimension elements and
// the output offset by state_dimension elements. Both are scaled by
// batch_size in time-major mode, where every step covers all batches.
314 // default as one batch per inference
315 int input_step = size_info_.input_dimension;
316 int output_step = size_info_.state_dimension;
317 // time major: batch inference
318 if (size_info_.time_major)
320 input_step = input_step * size_info_.batch_size;
321 output_step = output_step * size_info_.batch_size;
324 input_offset_ += input_step;
325 output_offset_ += output_step;
// Batch advance: only batch-major mode moves the hidden/cell-state offsets to
// the next batch row.
331 TFLITE_DCHECK_LE(current_batch_, size_info_.batch_size);
332 // batch inference for time major: no action needed
333 if (size_info_.time_major)
337 // otherwise: single-batch inference, go to the next batch
338 hidden_state_offset_ += size_info_.state_dimension;
339 cell_state_offset_ += size_info_.state_dimension;
342 void resetTime() { current_time_ = 0; }
// Shape of one step's input slice: {batch_size, input_dimension}. batch_size
// is size_info_.batch_size in time-major mode, otherwise the local's default
// (presumably 1 -- confirm).
// NOTE(review): reinterpreting int[] as int32_t[] assumes int is 32 bits;
// declaring dims as int32_t would remove the cast.
344 tflite::RuntimeShape inputShape() const
347 if (size_info_.time_major)
349 batch_size = size_info_.batch_size;
351 const int dims[2] = {batch_size, size_info_.input_dimension};
352 const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
353 return tflite::RuntimeShape(2, dims_data);
// Shape of one step's state slice (hidden/cell state, gate outputs):
// {batch_size, state_dimension}, with batch_size chosen as in inputShape().
356 tflite::RuntimeShape stateShape() const
359 if (size_info_.time_major)
361 batch_size = size_info_.batch_size;
363 const int dims[2] = {batch_size, size_info_.state_dimension};
364 const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
365 return tflite::RuntimeShape(2, dims_data);
368 int inputOffset() const { return input_offset_; }
370 int outputOffset() const { return output_offset_; }
372 int hiddenStateOffset() const { return hidden_state_offset_; }
374 int cellStateOffset() const { return cell_state_offset_; }
377 int32_t current_time_ = 0;
378 int32_t current_batch_ = 0;
379 int32_t input_offset_ = 0;
380 int32_t output_offset_ = 0;
381 int32_t hidden_state_offset_ = 0;
382 int32_t cell_state_offset_ = 0;
// Non-owning reference; see the constructor comment.
384 const LstmSizeInfo &size_info_;
387 // Calculates a single LSTM gate.
388 // Implements the following formula:
389 // gate = activate(FC(input) + FC(recurrent))
390 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
// Parameters:
//  step_info        supplies the input/state shapes and this step's offsets.
//  gate_params      quantization/activation parameters of the input and
//                   recurrent FCs.
//  input_data       base pointer of the input tensor; offset by
//                   step_info->inputOffset().
//  recurrent_data   base pointer of the hidden state; offset by
//                   step_info->hiddenStateOffset().
//  *_weight/*_bias  constant tensors resolved through runtime_graph. lstmStep
//                   passes nullptr for recurrent_bias.
//  gate_output      receives the activated gate (stateShape() elements).
//  fc_output_buffer scratch for the recurrent FC result (stateShape()
//                   elements).
//  activation       kTfLiteActSigmoid or kTfLiteActTanh.
391 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
392 void calculateLstmGate(const LstmStepManager *step_info,
393 const luci_interpreter::lstm::GateParameters *gate_params,
395 ActivationType *input_data, const circle::Tensor *input_weight,
396 const circle::Tensor *input_bias,
398 ActivationType *recurrent_data, const circle::Tensor *recurrent_weight,
399 const circle::Tensor *recurrent_bias,
401 CellType *gate_output,
403 CellType *fc_output_buffer, const TfLiteFusedActivation activation,
404 luci_interpreter::BaseRuntimeGraph *runtime_graph)
407 const auto gate_output_shape = step_info->stateShape();
// Input FC: gate_output = FC(input slice).
409 tflite::FullyConnectedParams op_params{};
410 op_params.input_offset = gate_params->input_fc_params.input_offset;
411 op_params.weights_offset = gate_params->input_fc_params.weights_offset;
412 op_params.output_offset = gate_params->input_fc_params.output_offset;
413 op_params.output_multiplier = gate_params->input_fc_params.output_multiplier;
414 op_params.output_shift = gate_params->input_fc_params.output_shift;
415 op_params.quantized_activation_min = gate_params->input_fc_params.quantized_activation_min;
416 op_params.quantized_activation_max = gate_params->input_fc_params.quantized_activation_max;
417 op_params.float_activation_max = gate_params->input_fc_params.float_activation_max;
418 op_params.float_activation_min = gate_params->input_fc_params.float_activation_min;
420 fullyConnected(op_params, step_info->inputShape(), input_data + step_info->inputOffset(),
421 luci_interpreter::kernels::getTensorShape(input_weight),
422 luci_interpreter::kernels::getTensorData<WeightType>(
423 runtime_graph->getConstDataByTensor(input_weight)),
424 luci_interpreter::kernels::getTensorShape(input_bias),
425 luci_interpreter::kernels::getTensorData<BiasType>(
426 runtime_graph->getConstDataByTensor(input_bias)),
427 gate_output_shape, gate_output);
// Recurrent FC: fc_output_buffer = FC(hidden-state slice).
432 tflite::FullyConnectedParams op_params{};
433 op_params.input_offset = gate_params->recurrent_fc_params.input_offset;
434 op_params.weights_offset = gate_params->recurrent_fc_params.weights_offset;
435 op_params.output_offset = gate_params->recurrent_fc_params.output_offset;
436 op_params.output_multiplier = gate_params->recurrent_fc_params.output_multiplier;
437 op_params.output_shift = gate_params->recurrent_fc_params.output_shift;
438 op_params.quantized_activation_min = gate_params->recurrent_fc_params.quantized_activation_min;
439 op_params.quantized_activation_max = gate_params->recurrent_fc_params.quantized_activation_max;
440 op_params.float_activation_max = gate_params->recurrent_fc_params.float_activation_max;
441 op_params.float_activation_min = gate_params->recurrent_fc_params.float_activation_min;
443 fullyConnected(op_params, step_info->stateShape(),
444 recurrent_data + step_info->hiddenStateOffset(),
445 luci_interpreter::kernels::getTensorShape(recurrent_weight),
446 luci_interpreter::kernels::getTensorData<WeightType>(
447 runtime_graph->getConstDataByTensor(recurrent_weight)),
448 luci_interpreter::kernels::getTensorShape(recurrent_bias),
449 luci_interpreter::kernels::getTensorData<BiasType>(
450 runtime_graph->getConstDataByTensor(recurrent_bias)),
451 gate_output_shape, fc_output_buffer);
// gate_output += fc_output_buffer (saturating for int16), computed in place.
453 addElementWise(gate_output, fc_output_buffer, /*n_batch=*/gate_output_shape.DimsData()[0],
454 /*n_state=*/gate_output_shape.DimsData()[1], gate_output);
// Apply the gate activation in place on gate_output.
458 case TfLiteFusedActivation::kTfLiteActSigmoid:
459 sigmoid(gate_output_shape, gate_output);
461 case TfLiteFusedActivation::kTfLiteActTanh:
463 // Scale power -12 makes the int16 tanh input shift (15 - 12) - 3 == 0,
// i.e. no extra shift is applied (ignored by the float overload).
464 tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output, gate_output_shape,
469 // Only Sigmoid or Tanh is used.
// NOTE(review): assert() is compiled out under NDEBUG, so in release builds an
// unsupported activation silently leaves gate_output un-activated.
470 assert(false && "Only Sigmoid or Tanh is used");
475 // Update the hidden state of the LSTM kernel using the following formula:
476 // updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
477 // element wise multiplication
478 template <typename CellType, typename ActivationType>
479 void updateLstmHidden(const LstmStepManager *step_info, CellType *cell_state_data_base,
480 ActivationType *hidden_state_data, const CellType *output_gate_output,
481 const luci_interpreter::lstm::ArithmeticParams *mul_params,
482 int32_t cell_state_scale_power, CellType *buffer)
484 auto cell_state_shape = step_info->stateShape();
485 CellType *cell_state_data = cell_state_data_base + step_info->cellStateOffset();
487 tanh(cell_state_scale_power, cell_state_shape, cell_state_data, cell_state_shape, buffer);
488 // Update the hidden state
489 mul(cell_state_shape, mul_params, buffer, output_gate_output,
490 hidden_state_data + step_info->hiddenStateOffset());
493 // Update the cell state using the output from the forget gate, input gate, and
494 // cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
495 // input_gate_output * cell_gate_output, where * denotes element wise
497 template <typename CellType>
498 void updateLstmCell(const LstmStepManager *step_info, CellType *cell_state_data,
500 CellType *forget_gate_output, const CellType *input_gate_output,
501 const CellType *cell_gate_output,
503 const luci_interpreter::lstm::ArithmeticParams &forget_cell_mul_params,
504 const luci_interpreter::lstm::ArithmeticParams &input_mul_params,
505 const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
507 auto cell_state_shape = step_info->stateShape();
508 // Forget Gate x Cell State
509 mul(cell_state_shape, &forget_cell_mul_params, forget_gate_output,
510 cell_state_data + step_info->cellStateOffset(),
511 cell_state_data + step_info->cellStateOffset());
512 // Input Gate x Cell Gate
513 mul(cell_state_shape, &input_mul_params, input_gate_output, cell_gate_output, buffer);
515 // Update the cell state
516 addElementWise(cell_state_data + step_info->cellStateOffset(), buffer,
517 /*n_batch=*/cell_state_shape.DimsData()[0],
518 /*n_state=*/cell_state_shape.DimsData()[1],
519 cell_state_data + step_info->cellStateOffset());
521 if (cell_state_info->cell_clip > 0)
523 clipping(cell_state_shape.FlatSize(), cell_state_info,
524 cell_state_data + step_info->cellStateOffset());
// Runs one LSTM step for the slice selected by step_info:
//   1. forget, input and cell gates (calculateLstmGate),
//   2. cell-state update (updateLstmCell),
//   3. output gate and hidden-state update (updateLstmHidden),
// then copies the new hidden state into the output tensor at outputOffset().
// Scratch buffers (each presumably sized for stateShape().FlatSize()
// elements by the caller -- confirm):
//   scratch0: forget gate output, later reused for Tanh(cell_state)
//   scratch1: input gate output, then the input*cell product, then the
//             output gate
//   scratch2: cell gate output
//   scratch3: recurrent-FC scratch inside calculateLstmGate
// Passing scratch1 as both an input and the output of updateLstmCell's second
// multiply is safe: the element-wise helpers read element i before writing it.
// All recurrent FCs are invoked without a bias (nullptr).
528 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
529 void lstmStep(luci_interpreter::lstm::LSTMStruct *lstm_struct,
530 luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info,
531 luci_interpreter::lstm::CellStateInfo *cell_state_info,
532 ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
533 CellType *scratch1, CellType *scratch2, CellType *scratch3,
534 luci_interpreter::BaseRuntimeGraph *runtime_graph)
536 /*Step1: Calculate gate outputs to prepare cell state update*/
537 CellType *gate_internal_buffer = scratch3;
538 CellType *forget_gate_output = scratch0;
540 auto input_data = luci_interpreter::kernels::getTensorData<ActivationType>(
541 runtime_graph->getDataByTensor(lstm_struct->input()));
543 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
544 step_info, &lstm_params->forget_gate_parameters,
546 input_data, lstm_struct->input_to_forget_weights(), lstm_struct->forget_gate_bias(),
548 output_state_data, lstm_struct->recurrent_to_forget_weights(), nullptr,
550 forget_gate_output, gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid,
553 // Input Gate calculation
554 CellType *input_gate_output = scratch1;
555 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
556 step_info, &lstm_params->input_gate_parameters,
558 input_data, lstm_struct->input_to_input_weights(), lstm_struct->input_gate_bias(),
560 output_state_data, lstm_struct->recurrent_to_input_weights(),
561 /*recurrent_bias*/ nullptr,
565 gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid, runtime_graph);
567 // Cell Gate calculation
568 CellType *cell_gate_output = scratch2;
569 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
570 step_info, &lstm_params->cell_gate_parameters,
572 input_data, lstm_struct->input_to_cell_weights(), lstm_struct->cell_gate_bias(),
574 output_state_data, lstm_struct->recurrent_to_cell_weights(),
575 /*recurrent_bias*/ nullptr,
579 gate_internal_buffer, TfLiteFusedActivation::kTfLiteActTanh, runtime_graph);
581 /*Step2: update the cell state */
583 // The multiply parameters come from lstm_params->inter_gate_parameters.
584 CellType *updated_input_buffer = scratch1; // reuse buffer
586 updateLstmCell<CellType>(
587 step_info, cell_state_data, forget_gate_output, input_gate_output, cell_gate_output,
588 lstm_params->inter_gate_parameters.forget_cell_mul_params,
589 lstm_params->inter_gate_parameters.input_mul_params, cell_state_info, updated_input_buffer);
593 /*Step3: update the hidden state */
594 CellType *output_gate_output = scratch1; // reuse buffer
595 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
596 step_info, &lstm_params->output_gate_parameters,
598 input_data, lstm_struct->input_to_output_weights(), lstm_struct->output_gate_bias(),
600 output_state_data, lstm_struct->recurrent_to_output_weights(), nullptr,
604 gate_internal_buffer, TfLiteFusedActivation::kTfLiteActSigmoid, runtime_graph);
605 CellType *tanh_activated_cell_buffer = scratch0; // reuse buffer
606 updateLstmHidden<CellType, ActivationType>(
607 step_info, cell_state_data, output_state_data, output_gate_output,
608 &lstm_params->inter_gate_parameters.output_mul_params,
609 cell_state_info->cell_state_scale_power, tanh_activated_cell_buffer);
// Copy this step's hidden state (stateShape().FlatSize() elements) into the
// output tensor.
611 ActivationType *output_ptr = luci_interpreter::kernels::getTensorData<ActivationType>(
612 runtime_graph->getDataByTensor(lstm_struct->output()));
613 std::memcpy(output_ptr + step_info->outputOffset(),
614 output_state_data + step_info->hiddenStateOffset(),
615 step_info->stateShape().FlatSize() * sizeof(ActivationType));
619 } // namespace lstm_internal
621 // Evaluate the LSTM kernel with (potential) multi-steps and multi-batch input
// Builds LstmSizeInfo from the tensor shapes, then calls lstmStep for every
// time step (and, in batch-major mode, for every batch):
//  - time_major: input dims are [time, batch, input_dim]; each step processes
//    all batches at once.
//  - otherwise:  input dims are [batch, time, input_dim]; batches are processed
//    one at a time, and the hidden/cell-state offsets move to the next batch
//    after its time steps.
// output_state_data / cell_state_data carry the recurrent state across steps.
// scratch0..scratch3 are forwarded unchanged to lstmStep.
// step_info keeps a reference to size_info; both are locals of this function.
622 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
623 void evalLSTM(luci_interpreter::lstm::LSTMStruct *lstm_struct,
624 luci_interpreter::lstm::LSTMParameters *lstm_params,
625 luci_interpreter::lstm::CellStateInfo *cell_state_info,
626 ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
627 CellType *scratch1, CellType *scratch2, CellType *scratch3,
628 luci_interpreter::BaseRuntimeGraph *runtime_graph)
630 lstm_internal::LstmSizeInfo size_info;
632 size_info.time_major = lstm_struct->options->time_major();
633 size_info.batch_size = size_info.time_major
634 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 1)
635 : luci_interpreter::Tensor::dim(lstm_struct->input(), 0);
636 size_info.time_steps = size_info.time_major
637 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 0)
638 : luci_interpreter::Tensor::dim(lstm_struct->input(), 1);
639 size_info.input_dimension = luci_interpreter::Tensor::dim(lstm_struct->input(), 2);
640 size_info.state_dimension = luci_interpreter::Tensor::dim(lstm_struct->output_state(), 1);
642 lstm_internal::LstmStepManager step_info(size_info);
644 // time is the first dimension: every step processes all batches together
645 if (size_info.time_major)
647 for (int t = 0; t < size_info.time_steps; t++)
649 lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
650 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
651 scratch0, scratch1, scratch2, scratch3, runtime_graph);
652 // prepare for the next time step
653 step_info.updateTime();
658 // batch is the first dimension: run single-batch inference, one batch at a time
659 for (int b = 0; b < size_info.batch_size; b++)
661 for (int t = 0; t < size_info.time_steps; t++)
663 lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
664 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
665 scratch0, scratch1, scratch2, scratch3, runtime_graph);
666 // prepare for the next time step
667 step_info.updateTime();
669 // prepare for the next batch
670 step_info.updateBatch();
671 step_info.resetTime();
676 } // namespace luci_interpreter_pal
678 #endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_H