8cdccc10a75c62bb81f8cdd97cdfd884f75695ae
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / UnidirectionalSequenceLSTM.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19 #define LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H
20
21 #include "Utils.h"
22
23 namespace luci_interpreter
24 {
25 namespace lstm
26 {
27
28 struct LSTMStruct
29 {
30   LSTMStruct() = delete;
31   LSTMStruct(const LSTMStruct &) = delete;
32
33   explicit LSTMStruct(const circle::Operator *cur_op,
34                       luci_interpreter::BaseRuntimeGraph *runtime_graph)
35   {
36     const auto input_index = cur_op->inputs()->operator[](0);
37     const auto input_to_input_weights_index = cur_op->inputs()->operator[](1);
38     const auto input_to_forget_weights_index = cur_op->inputs()->operator[](2);
39     const auto input_to_cell_weights_index = cur_op->inputs()->operator[](3);
40     const auto input_to_output_weights_index = cur_op->inputs()->operator[](4);
41     assert(input_index != -1);
42     // input_to_input_weights_index - optional
43     assert(input_to_forget_weights_index != -1);
44     assert(input_to_cell_weights_index != -1);
45     assert(input_to_output_weights_index != -1);
46     internal_tensors[0] = runtime_graph->getCircleTensorByIndex(input_index);
47     internal_tensors[1] = runtime_graph->getCircleTensorByIndex(input_to_input_weights_index);
48     internal_tensors[2] = runtime_graph->getCircleTensorByIndex(input_to_forget_weights_index);
49     internal_tensors[3] = runtime_graph->getCircleTensorByIndex(input_to_cell_weights_index);
50     internal_tensors[4] = runtime_graph->getCircleTensorByIndex(input_to_output_weights_index);
51
52     const auto recurrent_to_input_weights_index = cur_op->inputs()->operator[](5);
53     const auto recurrent_to_forget_weights_index = cur_op->inputs()->operator[](6);
54     const auto recurrent_to_cell_weights_index = cur_op->inputs()->operator[](7);
55     const auto recurrent_to_output_weights_index = cur_op->inputs()->operator[](8);
56     // recurrent_to_input_weights_index - optional
57     assert(recurrent_to_forget_weights_index != -1);
58     assert(recurrent_to_cell_weights_index != -1);
59     assert(recurrent_to_output_weights_index != -1);
60     internal_tensors[5] = runtime_graph->getCircleTensorByIndex(recurrent_to_input_weights_index);
61     internal_tensors[6] = runtime_graph->getCircleTensorByIndex(recurrent_to_forget_weights_index);
62     internal_tensors[7] = runtime_graph->getCircleTensorByIndex(recurrent_to_cell_weights_index);
63     internal_tensors[8] = runtime_graph->getCircleTensorByIndex(recurrent_to_output_weights_index);
64
65     const auto cell_to_input_weights_index = cur_op->inputs()->operator[](9);
66     const auto cell_to_forget_weights_index = cur_op->inputs()->operator[](10);
67     const auto cell_to_output_weights_index = cur_op->inputs()->operator[](11);
68     // optional cell_to_input_weights_index
69     // optional cell_to_forget_weights_index
70     // optional cell_to_output_weights_index
71     internal_tensors[9] = runtime_graph->getCircleTensorByIndex(cell_to_input_weights_index);
72     internal_tensors[10] = runtime_graph->getCircleTensorByIndex(cell_to_forget_weights_index);
73     internal_tensors[11] = runtime_graph->getCircleTensorByIndex(cell_to_output_weights_index);
74
75     const auto input_gate_bias_index = cur_op->inputs()->operator[](12);
76     const auto forget_gate_bias_index = cur_op->inputs()->operator[](13);
77     const auto cell_gate_bias_index = cur_op->inputs()->operator[](14);
78     const auto output_gate_bias_index = cur_op->inputs()->operator[](15);
79     // optional input_gate_bias_index
80     assert(forget_gate_bias_index != -1);
81     assert(cell_gate_bias_index != -1);
82     assert(output_gate_bias_index != -1);
83     internal_tensors[12] = runtime_graph->getCircleTensorByIndex(input_gate_bias_index);
84     internal_tensors[13] = runtime_graph->getCircleTensorByIndex(forget_gate_bias_index);
85     internal_tensors[14] = runtime_graph->getCircleTensorByIndex(cell_gate_bias_index);
86     internal_tensors[15] = runtime_graph->getCircleTensorByIndex(output_gate_bias_index);
87
88     const auto projection_weights_index = cur_op->inputs()->operator[](16);
89     const auto projection_bias_index = cur_op->inputs()->operator[](17);
90     // optional projection_weights_index
91     // optional projection_bias_index
92     internal_tensors[16] = runtime_graph->getCircleTensorByIndex(projection_weights_index);
93     internal_tensors[17] = runtime_graph->getCircleTensorByIndex(projection_bias_index);
94
95     const auto output_state_index = cur_op->inputs()->operator[](18);
96     const auto cell_state_index = cur_op->inputs()->operator[](19);
97     assert(output_state_index != -1);
98     assert(cell_state_index != -1);
99     internal_tensors[18] = runtime_graph->getCircleTensorByIndex(output_state_index);
100     internal_tensors[19] = runtime_graph->getCircleTensorByIndex(cell_state_index);
101
102     const auto input_layer_norm_coefficients_index = cur_op->inputs()->operator[](20);
103     const auto forget_layer_norm_coefficients_index = cur_op->inputs()->operator[](21);
104     const auto cell_layer_norm_coefficients_index = cur_op->inputs()->operator[](22);
105     const auto output_layer_norm_coefficients_index = cur_op->inputs()->operator[](23);
106     // optional input_layer_norm_coefficients_index
107     // optional forget_layer_norm_coefficients_index
108     // optional cell_layer_norm_coefficients_index
109     // optional output_layer_norm_coefficients_index
110     internal_tensors[20] =
111       runtime_graph->getCircleTensorByIndex(input_layer_norm_coefficients_index);
112     internal_tensors[21] =
113       runtime_graph->getCircleTensorByIndex(forget_layer_norm_coefficients_index);
114     internal_tensors[22] =
115       runtime_graph->getCircleTensorByIndex(cell_layer_norm_coefficients_index);
116     internal_tensors[23] =
117       runtime_graph->getCircleTensorByIndex(output_layer_norm_coefficients_index);
118
119     const auto output_index = cur_op->outputs()->operator[](0);
120     assert(output_index != -1);
121     output_internal = runtime_graph->getCircleTensorByIndex(output_index);
122
123     options = cur_op->builtin_options_as_UnidirectionalSequenceLSTMOptions();
124   }
125
126   void validateTensorTypes()
127   {
128     LUCI_INTERPRETER_CHECK(Tensor::element_type(input()) == Tensor::element_type(output_state()));
129     LUCI_INTERPRETER_CHECK(Tensor::element_type(output()) == Tensor::element_type(input()));
130
131     for (int32_t i = 1; i < 9; ++i)
132     {
133       LUCI_INTERPRETER_CHECK(internal_tensors[i] == nullptr or
134                              Tensor::element_type(input_to_forget_weights()) ==
135                                Tensor::element_type(internal_tensors[i]));
136     }
137
138     for (int32_t i = 12; i < 16; ++i)
139     {
140       LUCI_INTERPRETER_CHECK(internal_tensors[i] == nullptr or
141                              Tensor::element_type(forget_gate_bias()) ==
142                                Tensor::element_type(internal_tensors[i]));
143     }
144   }
145
146   const circle::Tensor *input() { return internal_tensors[0]; };
147
148   const circle::Tensor *input_to_input_weights() { return internal_tensors[1]; };
149   const circle::Tensor *input_to_forget_weights() { return internal_tensors[2]; };
150   const circle::Tensor *input_to_cell_weights() { return internal_tensors[3]; };
151   const circle::Tensor *input_to_output_weights() { return internal_tensors[4]; };
152
153   const circle::Tensor *recurrent_to_input_weights() { return internal_tensors[5]; };
154   const circle::Tensor *recurrent_to_forget_weights() { return internal_tensors[6]; };
155   const circle::Tensor *recurrent_to_cell_weights() { return internal_tensors[7]; };
156   const circle::Tensor *recurrent_to_output_weights() { return internal_tensors[8]; };
157
158   const circle::Tensor *cell_to_input_weights() { return internal_tensors[9]; };
159   const circle::Tensor *cell_to_forget_weights() { return internal_tensors[10]; };
160   const circle::Tensor *cell_to_output_weights() { return internal_tensors[11]; };
161
162   const circle::Tensor *input_gate_bias() { return internal_tensors[12]; };
163   const circle::Tensor *forget_gate_bias() { return internal_tensors[13]; };
164   const circle::Tensor *cell_gate_bias() { return internal_tensors[14]; };
165   const circle::Tensor *output_gate_bias() { return internal_tensors[15]; };
166
167   const circle::Tensor *projection_weights() { return internal_tensors[16]; };
168   const circle::Tensor *projection_bias() { return internal_tensors[17]; };
169
170   const circle::Tensor *output_state() { return internal_tensors[18]; };
171   const circle::Tensor *cell_state() { return internal_tensors[19]; };
172
173   const circle::Tensor *input_layer_norm_coefficients() { return internal_tensors[20]; };
174   const circle::Tensor *forget_layer_norm_coefficients() { return internal_tensors[21]; };
175   const circle::Tensor *cell_layer_norm_coefficients() { return internal_tensors[22]; };
176   const circle::Tensor *output_layer_norm_coefficients() { return internal_tensors[23]; };
177   const circle::Tensor *output() { return output_internal; };
178
179   const circle::UnidirectionalSequenceLSTMOptions *options;
180
181   const circle::Tensor *get_internal_tensor(int i) { return internal_tensors[i]; }
182
183 private:
184   const circle::Tensor *output_internal;
185   const circle::Tensor *internal_tensors[24];
186 };
187
/**
 * @brief Parameters of one fully-connected stage of an LSTM gate.
 *
 * The offset / multiplier / shift / quantized_activation_* fields presumably drive the
 * integer kernels and float_activation_* the float kernels (confirm in the kernel sources).
 * Every field has a zero default, matching what `FullyConnectedParams{}` already produced.
 * A plain `FullyConnectedParams p;` therefore no longer leaves indeterminate values behind.
 */
struct FullyConnectedParams
{
  int32_t input_offset = 0;
  int32_t weights_offset = 0;
  int32_t output_offset = 0;
  int32_t output_multiplier = 0;
  int32_t output_shift = 0;
  int32_t quantized_activation_min = 0;
  int32_t quantized_activation_max = 0;
  // Float clamp bounds are stored as float, not int32_t. An integer field cannot hold
  // fractional bounds or std::numeric_limits<float>::lowest()/max(). An "unbounded" float
  // range would otherwise be silently narrowed to the int32 range.
  float float_activation_min = 0.0f;
  float float_activation_max = 0.0f;
};
200
201 struct GateParameters
202 {
203   FullyConnectedParams input_fc_params;
204   FullyConnectedParams recurrent_fc_params;
205 };
206
/**
 * @brief Parameters of an element-wise arithmetic step between LSTM gate outputs.
 *
 * The offset / multiplier / shift / quantized_activation_* fields presumably drive the
 * integer kernels and float_activation_* the float kernels (confirm in the kernel sources).
 * Every field has a zero default, matching what `ArithmeticParams{}` already produced.
 * A plain `ArithmeticParams p;` therefore no longer leaves indeterminate values behind.
 */
struct ArithmeticParams
{
  int32_t input1_offset = 0;
  int32_t input2_offset = 0;
  int32_t quantized_activation_min = 0;
  int32_t quantized_activation_max = 0;
  int32_t output_offset = 0;
  int32_t output_multiplier = 0;
  int32_t output_shift = 0;
  // Float clamp bounds are stored as float, not int32_t. An integer field cannot hold
  // fractional bounds or std::numeric_limits<float>::lowest()/max(). An "unbounded" float
  // range would otherwise be silently narrowed to the int32 range.
  float float_activation_min = 0.0f;
  float float_activation_max = 0.0f;
};
219
220 struct InterGateParameters
221 {
222   ArithmeticParams forget_cell_mul_params;
223   ArithmeticParams input_mul_params;
224   ArithmeticParams output_mul_params;
225 };
226
/**
 * @brief Cell-state clipping and scaling information.
 *
 * All fields default to zero, matching what `CellStateInfo{}` already produced. A plain
 * `CellStateInfo info;` therefore no longer leaves indeterminate values behind.
 */
struct CellStateInfo
{
  // Clipping bound for the float cell state. NOTE(review): presumably 0 disables clipping,
  // as in TFLite — confirm.
  float cell_clip = 0.0f;
  // Clipping bound for the quantized cell state. Only a 16-bit cell state is supported,
  // hence int16_t; this could be generalized with templates.
  int16_t quantized_cell_clip = 0;
  // The cell state scale equals 2^-cell_state_scale_power. The integer tanh computation
  // requires the scale to be a power of two.
  int32_t cell_state_scale_power = 0;
};
237
238 struct LSTMParameters
239 {
240   GateParameters forget_gate_parameters;
241   GateParameters input_gate_parameters;
242   GateParameters cell_gate_parameters;
243   GateParameters output_gate_parameters;
244   InterGateParameters inter_gate_parameters;
245 };
246
247 } // namespace lstm
248 } // namespace luci_interpreter
249
250 #endif // LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H