2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H
19 #define LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H
21 #include "kernels/UnidirectionalSequenceLSTM.h"
23 #include "PALLogistic.h"
24 #include "PALFullyConnected.h"
28 namespace luci_interpreter_pal
30 namespace lstm_internal
34 // Possible fused activation functions.
39 kTfLiteActReluN1To1, // min(max(-1, x), 1)
40 kTfLiteActRelu6, // min(max(0, x), 6)
50 template <typename InputType, typename OutputType>
51 void mulElementwise(int size, const ArithmeticParams *params, const InputType *input1_data,
52 const InputType *input2_data, OutputType *output_data)
54 for (int i = 0; i < size; ++i)
56 const int32_t input1_val = params->input1_offset + input1_data[i];
57 const int32_t input2_val = params->input2_offset + input2_data[i];
58 const int32_t unclamped_result =
59 params->output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
60 params->output_multiplier,
61 params->output_shift);
62 const int32_t clamped_output =
63 std::min(params->quantized_activation_max,
64 std::max(params->quantized_activation_min, unclamped_result));
65 output_data[i] = static_cast<OutputType>(clamped_output);
69 // Input and output have the same shape in LSTM
70 void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
71 const int16_t *input1_data, const int16_t *input2_data, int8_t *output_data)
73 return mulElementwise<int16_t, int8_t>(shape.flatSize(), params, input1_data, input2_data,
77 // Input and output have the same shape in LSTM
78 void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
79 const int16_t *input1_data, const int16_t *input2_data, int16_t *output_data)
81 return mulElementwise(shape.flatSize(), params, input1_data, input2_data, output_data);
// Element-wise saturating addition of two [n_batch, n_input] int16 buffers.
// The 32-bit sum is clamped to the int16 range before being stored.
void addElementWise(const int16_t *input_1, const int16_t *input_2, int n_batch, int n_input,
                    int16_t *output)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      int32_t sum = input_1[index] + input_2[index];
      const int32_t sum_clamped =
        std::min(static_cast<int32_t>(std::numeric_limits<int16_t>::max()),
                 std::max(static_cast<int32_t>(std::numeric_limits<int16_t>::min()), sum));
      output[index] = static_cast<int16_t>(sum_clamped);
    }
  }
}
101 void tanh(int32_t cell_state_scale_power, const luci_interpreter::RuntimeShape &input_data_shape,
102 int16_t *input_data, const luci_interpreter::RuntimeShape &output_data_shape,
103 int16_t *output_data)
105 int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
106 int32_t input_multiplier = 0;
107 if (tanh_input_left_shift < 0) /* handling negative shift value */
109 tanh_input_left_shift = -tanh_input_left_shift;
110 input_multiplier = 3;
112 const int flat_size = input_data_shape.flatSize();
113 luci_interpreter_pal::Tanh(input_multiplier, tanh_input_left_shift, flat_size, input_data,
117 void sigmoid(const luci_interpreter::RuntimeShape &data_shape, int16_t *data)
119 luci_interpreter_pal::Logistic(0, 0, data_shape.flatSize(), data, data);
122 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
125 for (int i = 0; i < v_size; i++)
127 vector[i] = std::max(std::min(cell_state_info->quantized_cell_clip, vector[i]),
128 static_cast<int16_t>(-cell_state_info->quantized_cell_clip));
134 // Input and output have the same shape in LSTM
135 void mul(const luci_interpreter::RuntimeShape &shape, const ArithmeticParams *params,
136 const float *input1_data, const float *input2_data, float *output_data)
138 const int flat_size = shape.flatSize();
139 return luci_interpreter_pal::Mul(*params, flat_size, input1_data, input2_data, output_data);
// Element-wise addition of two [n_batch, n_input] float buffers (no clamping).
void addElementWise(const float *input_1, const float *input_2, int n_batch, int n_input,
                    float *output)
{
  for (int batch = 0; batch < n_batch; ++batch)
  {
    for (int i = 0; i < n_input; ++i)
    {
      const int index = batch * n_input + i;
      output[index] = input_1[index] + input_2[index];
    }
  }
}
155 void tanh(int32_t, const luci_interpreter::RuntimeShape &input_data_shape, float *input_data,
156 const luci_interpreter::RuntimeShape &output_data_shape, float *output_data)
158 const int flat_size = input_data_shape.flatSize();
159 luci_interpreter_pal::Tanh(flat_size, input_data, output_data);
162 void sigmoid(const luci_interpreter::RuntimeShape &data_shape, float *data)
164 const int flat_size = data_shape.flatSize();
165 luci_interpreter_pal::Logistic(flat_size, data, data);
168 void clipping(const int v_size, const luci_interpreter::lstm::CellStateInfo *cell_state_info,
171 for (int i = 0; i < v_size; i++)
174 std::max(std::min(cell_state_info->cell_clip, vector[i]), -cell_state_info->cell_clip);
// Size information about the LSTM kernel, which is deduced from tensors stored
// in the flat buffer file.
struct LstmSizeInfo
{
  bool time_major;         // true: input is [time, batch, input_dim]; false: [batch, time, input_dim]
  int32_t batch_size;
  int32_t time_steps;
  int32_t input_dimension; // size of the per-step input vector
  int32_t state_dimension; // size of the hidden/cell state vector
};
190 class LstmStepManager
193 LstmStepManager() = delete;
194 // Does not take any ownership, and all pointers must refer to valid objects
195 // that outlive the one constructed.
196 explicit LstmStepManager(const LstmSizeInfo &size_info) : size_info_(size_info) {}
201 // default as one batch per inference
202 int input_step = size_info_.input_dimension;
203 int output_step = size_info_.state_dimension;
204 // time major: batch inference
205 if (size_info_.time_major)
207 input_step = input_step * size_info_.batch_size;
208 output_step = output_step * size_info_.batch_size;
211 input_offset_ += input_step;
212 output_offset_ += output_step;
218 // batch inference for time major: no action needed
219 if (size_info_.time_major)
223 // otherwise: singe batch inference, go to the next batch
224 hidden_state_offset_ += size_info_.state_dimension;
225 cell_state_offset_ += size_info_.state_dimension;
228 void resetTime() { current_time_ = 0; }
230 luci_interpreter::RuntimeShape inputShape() const
233 if (size_info_.time_major)
235 batch_size = size_info_.batch_size;
237 const int dims[2] = {batch_size, size_info_.input_dimension};
238 const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
239 return luci_interpreter::RuntimeShape(2, dims_data);
242 luci_interpreter::RuntimeShape stateShape() const
245 if (size_info_.time_major)
247 batch_size = size_info_.batch_size;
249 const int dims[2] = {batch_size, size_info_.state_dimension};
250 const int32_t *dims_data = reinterpret_cast<const int32_t *>(dims);
251 return luci_interpreter::RuntimeShape(2, dims_data);
254 int inputOffset() const { return input_offset_; }
256 int outputOffset() const { return output_offset_; }
258 int hiddenStateOffset() const { return hidden_state_offset_; }
260 int cellStateOffset() const { return cell_state_offset_; }
263 int32_t current_time_ = 0;
264 int32_t current_batch_ = 0;
265 int32_t input_offset_ = 0;
266 int32_t output_offset_ = 0;
267 int32_t hidden_state_offset_ = 0;
268 int32_t cell_state_offset_ = 0;
270 const LstmSizeInfo &size_info_;
273 // Calculates a single LSTM gate.
274 // Implements the following formula:
275 // gate = activate(FC(input) + FC(recurrent))
276 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
277 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
278 void calculateLstmGate(const LstmStepManager *step_info,
279 const luci_interpreter::lstm::GateParameters *gate_params,
281 ActivationType *input_data, const circle::Tensor *input_weight,
282 const circle::Tensor *input_bias,
284 ActivationType *recurrent_data, const circle::Tensor *recurrent_weight,
285 const circle::Tensor *recurrent_bias,
287 CellType *gate_output,
289 CellType *fc_output_buffer, const FusedActivation activation,
290 luci_interpreter::BaseRuntimeGraph *runtime_graph)
293 const auto gate_output_shape = step_info->stateShape();
295 FullyConnectedParams op_params{};
296 op_params.input_offset = gate_params->input_fc_params.input_offset;
297 op_params.weights_offset = gate_params->input_fc_params.weights_offset;
298 op_params.output_offset = gate_params->input_fc_params.output_offset;
299 op_params.output_multiplier = gate_params->input_fc_params.output_multiplier;
300 op_params.output_shift = gate_params->input_fc_params.output_shift;
301 op_params.quantized_activation_min = gate_params->input_fc_params.quantized_activation_min;
302 op_params.quantized_activation_max = gate_params->input_fc_params.quantized_activation_max;
303 op_params.float_activation_max = gate_params->input_fc_params.float_activation_max;
304 op_params.float_activation_min = gate_params->input_fc_params.float_activation_min;
306 int32_t input_weight_shape[luci_interpreter::kMaxSmallSize];
307 luci_interpreter::kernels::getTensorDims(input_weight, runtime_graph, input_weight_shape);
309 FullyConnected(op_params, step_info->inputShape().dimsData(),
310 input_data + step_info->inputOffset(), input_weight_shape,
311 luci_interpreter::kernels::getTensorData<WeightType>(
312 runtime_graph->getConstDataByTensor(input_weight)),
313 luci_interpreter::kernels::getTensorData<BiasType>(
314 runtime_graph->getConstDataByTensor(input_bias)),
315 gate_output_shape.dimsData(), gate_output);
320 FullyConnectedParams op_params{};
321 op_params.input_offset = gate_params->recurrent_fc_params.input_offset;
322 op_params.weights_offset = gate_params->recurrent_fc_params.weights_offset;
323 op_params.output_offset = gate_params->recurrent_fc_params.output_offset;
324 op_params.output_multiplier = gate_params->recurrent_fc_params.output_multiplier;
325 op_params.output_shift = gate_params->recurrent_fc_params.output_shift;
326 op_params.quantized_activation_min = gate_params->recurrent_fc_params.quantized_activation_min;
327 op_params.quantized_activation_max = gate_params->recurrent_fc_params.quantized_activation_max;
328 op_params.float_activation_max = gate_params->recurrent_fc_params.float_activation_max;
329 op_params.float_activation_min = gate_params->recurrent_fc_params.float_activation_min;
331 int32_t recurrent_weight_shape[luci_interpreter::kMaxSmallSize];
332 luci_interpreter::kernels::getTensorDims(recurrent_weight, runtime_graph,
333 recurrent_weight_shape);
335 FullyConnected(op_params, step_info->stateShape().dimsData(),
336 recurrent_data + step_info->hiddenStateOffset(), recurrent_weight_shape,
337 luci_interpreter::kernels::getTensorData<WeightType>(
338 runtime_graph->getConstDataByTensor(recurrent_weight)),
339 luci_interpreter::kernels::getTensorData<BiasType>(
340 runtime_graph->getConstDataByTensor(recurrent_bias)),
341 gate_output_shape.dimsData(), fc_output_buffer);
343 addElementWise(gate_output, fc_output_buffer, /*n_batch=*/gate_output_shape.dimsData()[0],
344 /*n_state=*/gate_output_shape.dimsData()[1], gate_output);
348 case FusedActivation::kTfLiteActSigmoid:
349 sigmoid(gate_output_shape, gate_output);
351 case FusedActivation::kTfLiteActTanh:
353 // Set the scale power to -12 to avoid shift
354 tanh(/*cell_state_scale_power=*/-12, gate_output_shape, gate_output, gate_output_shape,
359 // Only Sigmoid or Tanh is used.
360 assert(false && "Only Sigmoid or Tanh is used");
365 // Update the hidden state of the LSTM kernel using the following formula:
366 // updated_hidden_state = Tanh(updated_cell_state) * output_gate_output, * means
367 // element wise multiplication
368 template <typename CellType, typename ActivationType>
369 void updateLstmHidden(const LstmStepManager *step_info, CellType *cell_state_data_base,
370 ActivationType *hidden_state_data, const CellType *output_gate_output,
371 const ArithmeticParams *mul_params, int32_t cell_state_scale_power,
374 auto cell_state_shape = step_info->stateShape();
375 CellType *cell_state_data = cell_state_data_base + step_info->cellStateOffset();
377 tanh(cell_state_scale_power, cell_state_shape, cell_state_data, cell_state_shape, buffer);
378 // Update the hidden state
379 mul(cell_state_shape, mul_params, buffer, output_gate_output,
380 hidden_state_data + step_info->hiddenStateOffset());
383 // Update the cell state using the output from the forget gate, input gate, and
384 // cell gate Formula: updated_cell_state = forget_gate_output*cell_state +
385 // input_gate_output * cell_gate_output, where * denotes element wise
387 template <typename CellType>
388 void updateLstmCell(const LstmStepManager *step_info, CellType *cell_state_data,
390 CellType *forget_gate_output, const CellType *input_gate_output,
391 const CellType *cell_gate_output,
393 const ArithmeticParams &forget_cell_mul_params,
394 const ArithmeticParams &input_mul_params,
395 const luci_interpreter::lstm::CellStateInfo *cell_state_info, CellType *buffer)
397 auto cell_state_shape = step_info->stateShape();
398 // Forget Gate x Cell State
399 mul(cell_state_shape, &forget_cell_mul_params, forget_gate_output,
400 cell_state_data + step_info->cellStateOffset(),
401 cell_state_data + step_info->cellStateOffset());
402 // Input Gate x Cell Gate
403 mul(cell_state_shape, &input_mul_params, input_gate_output, cell_gate_output, buffer);
405 // Update the cell state
406 addElementWise(cell_state_data + step_info->cellStateOffset(), buffer,
407 /*n_batch=*/cell_state_shape.dimsData()[0],
408 /*n_state=*/cell_state_shape.dimsData()[1],
409 cell_state_data + step_info->cellStateOffset());
411 if (cell_state_info->cell_clip > 0)
413 clipping(cell_state_shape.flatSize(), cell_state_info,
414 cell_state_data + step_info->cellStateOffset());
418 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
419 void lstmStep(luci_interpreter::lstm::LSTMStruct *lstm_struct,
420 luci_interpreter::lstm::LSTMParameters *lstm_params, LstmStepManager *step_info,
421 luci_interpreter::lstm::CellStateInfo *cell_state_info,
422 ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
423 CellType *scratch1, CellType *scratch2, CellType *scratch3,
424 luci_interpreter::BaseRuntimeGraph *runtime_graph)
426 /*Step1: Calculate gate outputs to prepare cell state update*/
427 CellType *gate_internal_buffer = scratch3;
428 CellType *forget_gate_output = scratch0;
430 auto input_data = luci_interpreter::kernels::getTensorData<ActivationType>(
431 runtime_graph->getDataByTensor(lstm_struct->input()));
433 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
434 step_info, &lstm_params->forget_gate_parameters,
436 input_data, lstm_struct->input_to_forget_weights(), lstm_struct->forget_gate_bias(),
438 output_state_data, lstm_struct->recurrent_to_forget_weights(), nullptr,
440 forget_gate_output, gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
442 // Input Gate calculation;
443 CellType *input_gate_output = scratch1;
444 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
445 step_info, &lstm_params->input_gate_parameters,
447 input_data, lstm_struct->input_to_input_weights(), lstm_struct->input_gate_bias(),
449 output_state_data, lstm_struct->recurrent_to_input_weights(),
450 /*recurrent_bias*/ nullptr,
454 gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
456 // Cell Gate calculation
457 CellType *cell_gate_output = scratch2;
458 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
459 step_info, &lstm_params->cell_gate_parameters,
461 input_data, lstm_struct->input_to_cell_weights(), lstm_struct->cell_gate_bias(),
463 output_state_data, lstm_struct->recurrent_to_cell_weights(),
464 /*recurrent_bias*/ nullptr,
468 gate_internal_buffer, FusedActivation::kTfLiteActTanh, runtime_graph);
470 /*Step2: update the cell state */
472 // const InterGateParameters& inter_gate_params = op_data.inter_gate_parameters;
473 CellType *updated_input_buffer = scratch1; // reuse buffer
475 updateLstmCell<CellType>(
476 step_info, cell_state_data, forget_gate_output, input_gate_output, cell_gate_output,
477 lstm_params->inter_gate_parameters.forget_cell_mul_params,
478 lstm_params->inter_gate_parameters.input_mul_params, cell_state_info, updated_input_buffer);
482 /*Step3: update the hidden state */
483 CellType *output_gate_output = scratch1; // reuse buffer
484 calculateLstmGate<ActivationType, WeightType, CellType, BiasType>(
485 step_info, &lstm_params->output_gate_parameters,
487 input_data, lstm_struct->input_to_output_weights(), lstm_struct->output_gate_bias(),
489 output_state_data, lstm_struct->recurrent_to_output_weights(), nullptr,
493 gate_internal_buffer, FusedActivation::kTfLiteActSigmoid, runtime_graph);
494 CellType *tanh_activated_cell_buffer = scratch0; // reuse buffer
495 updateLstmHidden<CellType, ActivationType>(
496 step_info, cell_state_data, output_state_data, output_gate_output,
497 &lstm_params->inter_gate_parameters.output_mul_params,
498 cell_state_info->cell_state_scale_power, tanh_activated_cell_buffer);
500 ActivationType *output_ptr = luci_interpreter::kernels::getTensorData<ActivationType>(
501 runtime_graph->getDataByTensor(lstm_struct->output()));
502 std::memcpy(output_ptr + step_info->outputOffset(),
503 output_state_data + step_info->hiddenStateOffset(),
504 step_info->stateShape().flatSize() * sizeof(ActivationType));
508 } // namespace lstm_internal
510 // Evaluate the LSTM kernel with (potential) multi-steps and multi-batch input
511 template <typename ActivationType, typename WeightType, typename CellType, typename BiasType>
512 void evalLSTM(luci_interpreter::lstm::LSTMStruct *lstm_struct,
513 luci_interpreter::lstm::LSTMParameters *lstm_params,
514 luci_interpreter::lstm::CellStateInfo *cell_state_info,
515 ActivationType *output_state_data, CellType *cell_state_data, CellType *scratch0,
516 CellType *scratch1, CellType *scratch2, CellType *scratch3,
517 luci_interpreter::BaseRuntimeGraph *runtime_graph)
519 lstm_internal::LstmSizeInfo size_info;
521 size_info.time_major = lstm_struct->options->time_major();
522 size_info.batch_size = size_info.time_major
523 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 1)
524 : luci_interpreter::Tensor::dim(lstm_struct->input(), 0);
525 size_info.time_steps = size_info.time_major
526 ? luci_interpreter::Tensor::dim(lstm_struct->input(), 0)
527 : luci_interpreter::Tensor::dim(lstm_struct->input(), 1);
528 size_info.input_dimension = luci_interpreter::Tensor::dim(lstm_struct->input(), 2);
529 size_info.state_dimension = luci_interpreter::Tensor::dim(lstm_struct->output_state(), 1);
531 lstm_internal::LstmStepManager step_info(size_info);
533 // time is the first dimention, enable batch computation
534 if (size_info.time_major)
536 for (int t = 0; t < size_info.time_steps; t++)
538 lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
539 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
540 scratch0, scratch1, scratch2, scratch3, runtime_graph);
541 // prepare for the next time step
542 step_info.updateTime();
547 // batch first, unable to size the input data. single batch inference
548 for (int b = 0; b < size_info.batch_size; b++)
550 for (int t = 0; t < size_info.time_steps; t++)
552 lstm_internal::lstmStep<ActivationType, WeightType, CellType, BiasType>(
553 lstm_struct, lstm_params, &step_info, cell_state_info, output_state_data, cell_state_data,
554 scratch0, scratch1, scratch2, scratch3, runtime_graph);
555 // prepare for the next time step
556 step_info.updateTime();
558 // prepare for the next batch
559 step_info.updateBatch();
560 step_info.resetTime();
565 } // namespace luci_interpreter_pal
567 #endif // LUCI_INTERPRETER_PAL_UNIDIRECTIONAL_SEQUENCE_LSTM_COMMON_H