57cf9f2e135620b94cfc75053b9a7f80d2b21e80
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / UnidirectionalSequenceLSTM.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
#include "Builders.h"
#include "kernels/Utils.h"

#include "PALUnidirectionalSequenceLSTM.h"
#include "PALApplyActivationToVector.h"

#include <cassert>
23
24 namespace luci_interpreter
25 {
26 namespace
27 {
28
29 #ifndef DIS_QUANT
30
// Returns true when x is (within 1e-3 of) an exact power of two and stores
// round(log2(x)) into *log2_result. On a non-power-of-two input the rounded
// power is still written, but false is returned.
// Implementation note: std::log (natural log) scaled by 1/ln(2) is used
// instead of std::log2 to work around toolchains where log2 was missing
// (TensorFlow test toolchains as of May 2018).
bool checkedLog2(const float x, int *log2_result)
{
  const float log2_exact = std::log(x) * (1.0f / std::log(2.0f));
  const float log2_rounded = std::round(log2_exact);
  const float frac_part = log2_exact - log2_rounded;

  *log2_result = static_cast<int>(log2_rounded);
  return std::abs(frac_part) < 1e-3f;
}
43
44 // Create parameters for element wise multiplication that happens in a) cell
45 // state update ; b) hidden state update
46 // Note that all the output of gates are symmetrically quantized so only scales
47 // are required for input. However, during the hidden state update phase, the
48 // output is the updated hidden state, which is asymmetrically quantized. Thus
49 // output may require zero point
50 lstm::ArithmeticParams createInterGateParams(const float input1_scale, const float input2_scale,
51                                              const float output_scale, const DataType output_type,
52                                              const int output_zp)
53 {
54   lstm::ArithmeticParams op_params;
55   if (output_type == DataType::S16)
56   {
57     op_params.quantized_activation_min = std::numeric_limits<int16_t>::min();
58     op_params.quantized_activation_max = std::numeric_limits<int16_t>::max();
59   }
60   else if (output_type == DataType::S8)
61   {
62     op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
63     op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
64   }
65
66   op_params.input1_offset = 0; // symmetric
67   op_params.input2_offset = 0; // symmetric
68   op_params.output_offset = output_zp;
69
70   const double input_product_scale =
71     static_cast<double>(input1_scale) * static_cast<double>(input2_scale);
72   double effective_scale = input_product_scale / static_cast<double>(output_scale);
73   auto output_shift = static_cast<int>(op_params.output_shift);
74   kernels::quantizeMultiplier(effective_scale, &op_params.output_multiplier, &output_shift);
75   op_params.output_shift = output_shift;
76   return op_params;
77 }
78
79 void createGateParams(const circle::Tensor *input, const circle::Tensor *input_weight,
80                       const circle::Tensor *input_bias, const circle::Tensor *hidden_state,
81                       const circle::Tensor *hidden_state_weight,
82                       const float nonlinear_activation_input_scale, const DataType cell_type,
83                       lstm::GateParameters *gate_params)
84 {
85   // Input CalculateOpDataFullyConnected
86   {
87     lstm::FullyConnectedParams input_gate_params;
88     double real_multiplier = 0.0;
89     int output_shift;
90     int32_t output_activation_min;
91     int32_t output_activation_max;
92     int32_t output_multiplier;
93     real_multiplier = kernels::getQuantizedConvolutionMultipler(
94       Tensor::scale(input), Tensor::scale(input_weight), nonlinear_activation_input_scale);
95     kernels::quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
96     kernels::calculateActivationRangeQuantized(FusedActFunc::NONE, 0,
97                                                nonlinear_activation_input_scale, cell_type,
98                                                &output_activation_min, &output_activation_max);
99
100     input_gate_params.output_shift = output_shift;
101     input_gate_params.output_multiplier = output_multiplier;
102     input_gate_params.quantized_activation_max = output_activation_max;
103     input_gate_params.quantized_activation_min = output_activation_min;
104     input_gate_params.input_offset = -Tensor::zero_point(input);
105     input_gate_params.weights_offset = -Tensor::zero_point(input_weight);
106     input_gate_params.output_offset = 0;
107
108     gate_params->input_fc_params = input_gate_params;
109   }
110
111   // Recurrent CalculateOpDataFullyConnected
112   {
113     lstm::FullyConnectedParams recurrent_gate_params;
114     double real_multiplier = 0.0;
115     int output_shift;
116     int32_t output_activation_min;
117     int32_t output_activation_max;
118     int32_t output_multiplier;
119     real_multiplier = kernels::getQuantizedConvolutionMultipler(Tensor::scale(hidden_state),
120                                                                 Tensor::scale(hidden_state_weight),
121                                                                 nonlinear_activation_input_scale);
122     kernels::quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
123     kernels::calculateActivationRangeQuantized(FusedActFunc::NONE, 0,
124                                                nonlinear_activation_input_scale, cell_type,
125                                                &output_activation_min, &output_activation_max);
126
127     recurrent_gate_params.output_shift = output_shift;
128     recurrent_gate_params.output_multiplier = output_multiplier;
129     recurrent_gate_params.quantized_activation_max = output_activation_max;
130     recurrent_gate_params.quantized_activation_min = output_activation_min;
131     recurrent_gate_params.input_offset = -Tensor::zero_point(hidden_state);
132     recurrent_gate_params.weights_offset = -Tensor::zero_point(hidden_state_weight);
133     recurrent_gate_params.output_offset = 0;
134
135     gate_params->recurrent_fc_params = recurrent_gate_params;
136   }
137 }
138
// Fill quant_lstm_params with all quantized gate parameters (four gates)
// and the three inter-gate element-wise multiplication parameters for the
// integer LSTM path.
void prepareGateParamsInteger(lstm::LSTMStruct *lstm_struct,
                              lstm::LSTMParameters *quant_lstm_params)
{
  // Scale of the FC accumulator fed into the nonlinear activations.
  float nonlinear_input_scale = 0.00024414062; // 2^-12 Q3.12 -> Q0.15

  // Forget gate: input x input_to_forget_weights,
  // output_state x recurrent_to_forget_weights.
  createGateParams(lstm_struct->input(), lstm_struct->input_to_forget_weights(),
                   lstm_struct->forget_gate_bias(), lstm_struct->output_state(),
                   lstm_struct->recurrent_to_forget_weights(), nonlinear_input_scale, DataType::S16,
                   &quant_lstm_params->forget_gate_parameters);

  // Input gate: input x input_to_input_weights,
  // output_state x recurrent_to_input_weights.
  createGateParams(lstm_struct->input(), lstm_struct->input_to_input_weights(),
                   lstm_struct->input_gate_bias(), lstm_struct->output_state(),
                   lstm_struct->recurrent_to_input_weights(), nonlinear_input_scale, DataType::S16,
                   &quant_lstm_params->input_gate_parameters);

  // Cell gate: input x input_to_cell_weights,
  // output_state x recurrent_to_cell_weights.
  createGateParams(lstm_struct->input(), lstm_struct->input_to_cell_weights(),
                   lstm_struct->cell_gate_bias(), lstm_struct->output_state(),
                   lstm_struct->recurrent_to_cell_weights(), nonlinear_input_scale, DataType::S16,
                   &quant_lstm_params->cell_gate_parameters);

  // Output gate: input x input_to_output_weights,
  // output_state x recurrent_to_output_weights.
  createGateParams(lstm_struct->input(), lstm_struct->input_to_output_weights(),
                   lstm_struct->output_gate_bias(), lstm_struct->output_state(),
                   lstm_struct->recurrent_to_output_weights(), nonlinear_input_scale, DataType::S16,
                   &quant_lstm_params->output_gate_parameters);

  // Inter gate multiplication parameters
  // Scale of the nonlinear activation outputs (sigmoid/tanh).
  float nonlinear_output_scale = 0.00003051757; // 2^-15 Q3.12 -> Q0.15
  float cell_state_scale =
    Tensor::scale(lstm_struct->cell_state()); // quantization scale of the cell state tensor
  // forget gate output (nonlinear output) x cell state -> cell state
  quant_lstm_params->inter_gate_parameters.forget_cell_mul_params = createInterGateParams(
    nonlinear_output_scale, cell_state_scale, cell_state_scale, DataType::S16, 0);

  // input gate output x cell gate output -> cell state
  quant_lstm_params->inter_gate_parameters.input_mul_params = createInterGateParams(
    nonlinear_output_scale, nonlinear_output_scale, cell_state_scale, DataType::S16, 0);

  // tanh output x output gate output -> hidden state (potentially asymmetric,
  // hence the output_state's own type and zero point are passed through)
  quant_lstm_params->inter_gate_parameters.output_mul_params = createInterGateParams(
    nonlinear_output_scale, nonlinear_output_scale, Tensor::scale(lstm_struct->output_state()),
    Tensor::element_type(lstm_struct->output_state()),
    Tensor::zero_point(lstm_struct->output_state()));
}
184
185 // Create the additional information about the cell state, which include:
186 // cell_state_scale_power: used in integer nonlinear function (e.g., tanh)
187 // quantized_cell_clip: quantized cell clip range
188 lstm::CellStateInfo createLstmCellStateInfo(const float cell_state_scale, const float cell_clip)
189 {
190   lstm::CellStateInfo cell_state_info;
191   // cell_state_scale_power: 2^-cell_state_scale_power = cell state scale
192   int buffer;
193   checkedLog2(cell_state_scale, &buffer);
194   cell_state_info.cell_state_scale_power = buffer;
195   // Cell state specifics
196   cell_state_info.cell_clip = cell_clip;
197   cell_state_info.quantized_cell_clip = static_cast<int16_t>(std::min(
198     std::max(static_cast<double>(cell_clip) / static_cast<double>(cell_state_scale), -32768.0),
199     32767.0));
200   return cell_state_info;
201 }
202
// Run the integer (int8 activations / int16 cell state) LSTM evaluation.
// The unnamed bool parameter mirrors the dispatch signature (in_place) and
// is unused here.
void evalInt8(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph, bool)
{
  lstm::LSTMStruct lstm_struct(cur_op, runtime_graph);

  // Quantized gate and inter-gate parameters derived from tensor scales.
  lstm::LSTMParameters quant_lstm_params;
  prepareGateParamsInteger(&lstm_struct, &quant_lstm_params);

  lstm::CellStateInfo cell_state_info = createLstmCellStateInfo(
    luci_interpreter::Tensor::scale(lstm_struct.cell_state()), lstm_struct.options->cell_clip());

  // Batch dim is axis 1 when time_major ([time, batch, ...]), else axis 0.
  const bool time_major = lstm_struct.options->time_major();
  const auto batch_size =
    time_major ? Tensor::dim(lstm_struct.input(), 1) : Tensor::dim(lstm_struct.input(), 0);
  const auto state_dimension = Tensor::dim(lstm_struct.output_state(), 1);
  const auto cell_state_type_size = getDataTypeSize(Tensor::element_type(lstm_struct.cell_state()));

  // Four scratch buffers, each one [batch, state] plane sized in bytes of
  // the cell state element type (reinterpreted as int16 below).
  auto scratch_0_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_1_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_2_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_3_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);

  // Create and fill with 0 output state tensor (int8, zero-initialized; the
  // graph's stored output state tensor data is not read here)
  auto output_state_data =
    std::make_unique<int8_t[]>(Tensor::num_elements(lstm_struct.output_state()));
  std::fill_n(output_state_data.get(), Tensor::num_elements(lstm_struct.output_state()), 0);

  // Create and fill with 0 cell state tensor (int16, zero-initialized)
  auto cell_state_data =
    std::make_unique<int16_t[]>(Tensor::num_elements(lstm_struct.cell_state()));
  std::fill_n(cell_state_data.get(), Tensor::num_elements(lstm_struct.cell_state()), 0);

  // Template params: <activation, weight, cell, bias> element types.
  luci_interpreter_pal::evalLSTM<int8_t, int8_t, int16_t, int32_t>(
    &lstm_struct, &quant_lstm_params, &cell_state_info, output_state_data.get(),
    cell_state_data.get(), kernels::getTensorData<int16_t>(scratch_0_data.get()),
    kernels::getTensorData<int16_t>(scratch_1_data.get()),
    kernels::getTensorData<int16_t>(scratch_2_data.get()),
    kernels::getTensorData<int16_t>(scratch_3_data.get()), runtime_graph);
}
245
246 #endif // DIS_QUANT
247
248 #ifndef DIS_FLOAT
249 lstm::FullyConnectedParams createFcParamsFloat()
250 {
251   lstm::FullyConnectedParams op_params;
252   kernels::calculateActivationRange(FusedActFunc::NONE, &op_params.float_activation_min,
253                                     &op_params.float_activation_max);
254   return op_params;
255 }
256
257 lstm::GateParameters createGateParamsFloat()
258 {
259   lstm::GateParameters gate_params;
260
261   gate_params.input_fc_params = createFcParamsFloat();
262   gate_params.recurrent_fc_params = createFcParamsFloat();
263
264   return gate_params;
265 }
266
267 lstm::CellStateInfo createLstmCellStateInfoFloat(const float cell_clip)
268 {
269   lstm::CellStateInfo cell_state_info;
270   cell_state_info.cell_clip = cell_clip;
271   cell_state_info.cell_state_scale_power = 0; // no quantization
272   cell_state_info.quantized_cell_clip = 0;    // no quantization
273   return cell_state_info;
274 }
275
276 void prepareGateParamsFloat(lstm::LSTMParameters *float_lstm_params)
277 {
278   // Gate Parameters
279   float_lstm_params->forget_gate_parameters = createGateParamsFloat();
280   float_lstm_params->input_gate_parameters = createGateParamsFloat();
281   float_lstm_params->cell_gate_parameters = createGateParamsFloat();
282   float_lstm_params->output_gate_parameters = createGateParamsFloat();
283
284   // Inter gate multiplication parameters
285   lstm::ArithmeticParams op_params;
286   kernels::calculateActivationRange(FusedActFunc::NONE, &op_params.float_activation_min,
287                                     &op_params.float_activation_max);
288   float_lstm_params->inter_gate_parameters.forget_cell_mul_params = op_params;
289   float_lstm_params->inter_gate_parameters.input_mul_params = op_params;
290   float_lstm_params->inter_gate_parameters.output_mul_params = op_params;
291 }
292
// Run the float32 LSTM evaluation. The unnamed bool parameter mirrors the
// dispatch signature (in_place) and is unused here.
void evalFloat(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph, bool)
{
  lstm::LSTMStruct lstm_struct(cur_op, runtime_graph);

  lstm::CellStateInfo cell_state_info =
    createLstmCellStateInfoFloat(lstm_struct.options->cell_clip());

  lstm::LSTMParameters lstm_params;
  prepareGateParamsFloat(&lstm_params);

  // Batch dim is axis 1 when time_major ([time, batch, ...]), else axis 0.
  const bool time_major = lstm_struct.options->time_major();
  const auto batch_size =
    time_major ? Tensor::dim(lstm_struct.input(), 1) : Tensor::dim(lstm_struct.input(), 0);
  const auto state_dimension = Tensor::dim(lstm_struct.output_state(), 1);
  const auto cell_state_type_size = getDataTypeSize(Tensor::element_type(lstm_struct.cell_state()));

  // Four scratch buffers, each one [batch, state] plane sized in bytes of
  // the cell state element type (reinterpreted as float below).
  auto scratch_0_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_1_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_2_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);
  auto scratch_3_data =
    std::make_unique<uint8_t[]>(batch_size * state_dimension * cell_state_type_size);

  // Create and fill with 0 output state tensor (the graph's stored output
  // state tensor data is not read here)
  auto output_state_data =
    std::make_unique<float[]>(Tensor::num_elements(lstm_struct.output_state()));
  std::fill_n(output_state_data.get(), Tensor::num_elements(lstm_struct.output_state()), 0);

  // Create and fill with 0 cell state tensor
  auto cell_state_data = std::make_unique<float[]>(Tensor::num_elements(lstm_struct.cell_state()));
  std::fill_n(cell_state_data.get(), Tensor::num_elements(lstm_struct.cell_state()), 0);

  // Template params: <activation, weight, cell, bias> element types.
  luci_interpreter_pal::evalLSTM<float, float, float, float>(
    &lstm_struct, &lstm_params, &cell_state_info, output_state_data.get(), cell_state_data.get(),
    kernels::getTensorData<float>(scratch_0_data.get()),
    kernels::getTensorData<float>(scratch_1_data.get()),
    kernels::getTensorData<float>(scratch_2_data.get()),
    kernels::getTensorData<float>(scratch_3_data.get()), runtime_graph);
}
334 #endif // DIS_FLOAT
335
336 void validateWeightTensorSize(const circle::Tensor *weight_tensor, int dim1_size, int dim2_size)
337 {
338   LUCI_INTERPRETER_CHECK(Tensor::num_dims(weight_tensor) == 2);
339   LUCI_INTERPRETER_CHECK(Tensor::dim(weight_tensor, 0) == dim1_size);
340   LUCI_INTERPRETER_CHECK(Tensor::dim(weight_tensor, 1) == dim2_size);
341 }
342
// Validate the shapes of all LSTM tensors against each other.
// Internal tensor indices follow the circle/TFLite LSTM operator layout:
// 1..4 input-to-gate weights, 5..8 recurrent-to-gate weights,
// 12..15 gate biases.
void validateTensorsSize(lstm::LSTMStruct *lstm_struct, const bool time_major)
{
  // Batch dim is axis 1 when time_major ([time, batch, feature]), else axis 0.
  const auto batch_size =
    time_major ? Tensor::dim(lstm_struct->input(), 1) : Tensor::dim(lstm_struct->input(), 0);

  const auto input_dimension = Tensor::dim(lstm_struct->input(), 2);
  const auto state_dimension = Tensor::dim(lstm_struct->output_state(), 1);

  // Input FC weights (internal tensors 1..4): [state, input]
  for (int32_t i = 1; i < 5; i++)
  {
    validateWeightTensorSize(lstm_struct->get_internal_tensor(i), state_dimension, input_dimension);
  }

  // Recurrent FC weights (internal tensors 5..8): [state, state]
  for (int32_t i = 5; i < 9; i++)
  {
    validateWeightTensorSize(lstm_struct->get_internal_tensor(i), state_dimension, state_dimension);
  }

  // Biases (internal tensors 12..15): 1-D of length state_dimension
  for (int32_t i = 12; i < 16; i++)
  {
    LUCI_INTERPRETER_CHECK(Tensor::num_dims(lstm_struct->get_internal_tensor(i)) == 1);
    LUCI_INTERPRETER_CHECK(Tensor::dim(lstm_struct->get_internal_tensor(i), 0) == state_dimension);
  }

  // Check the shape of input state tensors.
  // These tensor may be 1D or 2D. It's fine as long as the total size is
  // correct.
  LUCI_INTERPRETER_CHECK(Tensor::num_elements(lstm_struct->output_state()) ==
                         batch_size * state_dimension);
  LUCI_INTERPRETER_CHECK(Tensor::num_elements(lstm_struct->cell_state()) ==
                         batch_size * state_dimension);

  // Check the shape of output tensor against that of input tensor:
  // same leading [*, *] dims, trailing dim equals the state size.
  LUCI_INTERPRETER_CHECK(Tensor::num_dims(lstm_struct->output()) == 3);
  LUCI_INTERPRETER_CHECK(Tensor::dim(lstm_struct->input(), 0) ==
                         Tensor::dim(lstm_struct->output(), 0));
  LUCI_INTERPRETER_CHECK(Tensor::dim(lstm_struct->input(), 1) ==
                         Tensor::dim(lstm_struct->output(), 1));
  LUCI_INTERPRETER_CHECK(Tensor::dim(lstm_struct->output(), 2) == state_dimension);
}
386
387 } // namespace
388
// Configure-time validation: check tensor types and shapes, and reject
// LSTM variants this kernel does not implement (peephole, projection,
// internal layer norm).
void configure_kernel_CircleUnidirectionalSequenceLSTM(const circle::Operator *cur_op,
                                                       BaseRuntimeGraph *runtime_graph)
{
  lstm::LSTMStruct lstm_struct(cur_op, runtime_graph);

  // Only float32 and int8 inputs are supported (matches the execute kernel).
  LUCI_INTERPRETER_CHECK(Tensor::element_type(lstm_struct.input()) == DataType::FLOAT32 or
                         Tensor::element_type(lstm_struct.input()) == DataType::S8);

  lstm_struct.validateTensorTypes();

  const bool time_major = lstm_struct.options->time_major();

  validateTensorsSize(&lstm_struct, time_major);

  // No peephole: internal tensors 9..11 must be absent
  for (int32_t i = 9; i < 12; ++i)
    LUCI_INTERPRETER_CHECK(lstm_struct.get_internal_tensor(i) == nullptr);

  // No projection: internal tensors 16..17 must be absent
  for (int32_t i = 16; i < 18; ++i)
    LUCI_INTERPRETER_CHECK(lstm_struct.get_internal_tensor(i) == nullptr);

  // No internal layer norm: internal tensors 20..23 must be absent
  for (int32_t i = 20; i < 24; ++i)
    LUCI_INTERPRETER_CHECK(lstm_struct.get_internal_tensor(i) == nullptr);
}
415
416 void execute_kernel_CircleUnidirectionalSequenceLSTM(const circle::Operator *cur_op,
417                                                      BaseRuntimeGraph *runtime_graph, bool in_place)
418 {
419   const auto input_index = cur_op->inputs()->operator[](0);
420   assert(input_index != -1);
421
422   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
423
424   switch (Tensor::element_type(input))
425   {
426 #ifndef DIS_FLOAT
427     case DataType::FLOAT32:
428       evalFloat(cur_op, runtime_graph, in_place);
429       break;
430 #endif // DIS_FLOAT
431 #ifndef DIS_QUANT
432     case DataType::S8:
433       evalInt8(cur_op, runtime_graph, in_place);
434       break;
435 #endif // DIS_QUANT
436     default:
437       assert(false && "Unsupported type.");
438   }
439 }
440
441 } // namespace luci_interpreter