/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H
19 #define LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H
23 namespace luci_interpreter
30 LSTMStruct() = delete;
31 LSTMStruct(const LSTMStruct &) = delete;
33 explicit LSTMStruct(const circle::Operator *cur_op,
34 luci_interpreter::BaseRuntimeGraph *runtime_graph)
36 const auto input_index = cur_op->inputs()->operator[](0);
37 const auto input_to_input_weights_index = cur_op->inputs()->operator[](1);
38 const auto input_to_forget_weights_index = cur_op->inputs()->operator[](2);
39 const auto input_to_cell_weights_index = cur_op->inputs()->operator[](3);
40 const auto input_to_output_weights_index = cur_op->inputs()->operator[](4);
41 assert(input_index != -1);
42 // input_to_input_weights_index - optional
43 assert(input_to_forget_weights_index != -1);
44 assert(input_to_cell_weights_index != -1);
45 assert(input_to_output_weights_index != -1);
46 internal_tensors[0] = runtime_graph->getCircleTensorByIndex(input_index);
47 internal_tensors[1] = runtime_graph->getCircleTensorByIndex(input_to_input_weights_index);
48 internal_tensors[2] = runtime_graph->getCircleTensorByIndex(input_to_forget_weights_index);
49 internal_tensors[3] = runtime_graph->getCircleTensorByIndex(input_to_cell_weights_index);
50 internal_tensors[4] = runtime_graph->getCircleTensorByIndex(input_to_output_weights_index);
52 const auto recurrent_to_input_weights_index = cur_op->inputs()->operator[](5);
53 const auto recurrent_to_forget_weights_index = cur_op->inputs()->operator[](6);
54 const auto recurrent_to_cell_weights_index = cur_op->inputs()->operator[](7);
55 const auto recurrent_to_output_weights_index = cur_op->inputs()->operator[](8);
56 // recurrent_to_input_weights_index - optional
57 assert(recurrent_to_forget_weights_index != -1);
58 assert(recurrent_to_cell_weights_index != -1);
59 assert(recurrent_to_output_weights_index != -1);
60 internal_tensors[5] = runtime_graph->getCircleTensorByIndex(recurrent_to_input_weights_index);
61 internal_tensors[6] = runtime_graph->getCircleTensorByIndex(recurrent_to_forget_weights_index);
62 internal_tensors[7] = runtime_graph->getCircleTensorByIndex(recurrent_to_cell_weights_index);
63 internal_tensors[8] = runtime_graph->getCircleTensorByIndex(recurrent_to_output_weights_index);
65 const auto cell_to_input_weights_index = cur_op->inputs()->operator[](9);
66 const auto cell_to_forget_weights_index = cur_op->inputs()->operator[](10);
67 const auto cell_to_output_weights_index = cur_op->inputs()->operator[](11);
68 // optional cell_to_input_weights_index
69 // optional cell_to_forget_weights_index
70 // optional cell_to_output_weights_index
71 internal_tensors[9] = runtime_graph->getCircleTensorByIndex(cell_to_input_weights_index);
72 internal_tensors[10] = runtime_graph->getCircleTensorByIndex(cell_to_forget_weights_index);
73 internal_tensors[11] = runtime_graph->getCircleTensorByIndex(cell_to_output_weights_index);
75 const auto input_gate_bias_index = cur_op->inputs()->operator[](12);
76 const auto forget_gate_bias_index = cur_op->inputs()->operator[](13);
77 const auto cell_gate_bias_index = cur_op->inputs()->operator[](14);
78 const auto output_gate_bias_index = cur_op->inputs()->operator[](15);
79 // optional input_gate_bias_index
80 assert(forget_gate_bias_index != -1);
81 assert(cell_gate_bias_index != -1);
82 assert(output_gate_bias_index != -1);
83 internal_tensors[12] = runtime_graph->getCircleTensorByIndex(input_gate_bias_index);
84 internal_tensors[13] = runtime_graph->getCircleTensorByIndex(forget_gate_bias_index);
85 internal_tensors[14] = runtime_graph->getCircleTensorByIndex(cell_gate_bias_index);
86 internal_tensors[15] = runtime_graph->getCircleTensorByIndex(output_gate_bias_index);
88 const auto projection_weights_index = cur_op->inputs()->operator[](16);
89 const auto projection_bias_index = cur_op->inputs()->operator[](17);
90 // optional projection_weights_index
91 // optional projection_bias_index
92 internal_tensors[16] = runtime_graph->getCircleTensorByIndex(projection_weights_index);
93 internal_tensors[17] = runtime_graph->getCircleTensorByIndex(projection_bias_index);
95 const auto output_state_index = cur_op->inputs()->operator[](18);
96 const auto cell_state_index = cur_op->inputs()->operator[](19);
97 assert(output_state_index != -1);
98 assert(cell_state_index != -1);
99 internal_tensors[18] = runtime_graph->getCircleTensorByIndex(output_state_index);
100 internal_tensors[19] = runtime_graph->getCircleTensorByIndex(cell_state_index);
102 const auto input_layer_norm_coefficients_index = cur_op->inputs()->operator[](20);
103 const auto forget_layer_norm_coefficients_index = cur_op->inputs()->operator[](21);
104 const auto cell_layer_norm_coefficients_index = cur_op->inputs()->operator[](22);
105 const auto output_layer_norm_coefficients_index = cur_op->inputs()->operator[](23);
106 // optional input_layer_norm_coefficients_index
107 // optional forget_layer_norm_coefficients_index
108 // optional cell_layer_norm_coefficients_index
109 // optional output_layer_norm_coefficients_index
110 internal_tensors[20] =
111 runtime_graph->getCircleTensorByIndex(input_layer_norm_coefficients_index);
112 internal_tensors[21] =
113 runtime_graph->getCircleTensorByIndex(forget_layer_norm_coefficients_index);
114 internal_tensors[22] =
115 runtime_graph->getCircleTensorByIndex(cell_layer_norm_coefficients_index);
116 internal_tensors[23] =
117 runtime_graph->getCircleTensorByIndex(output_layer_norm_coefficients_index);
119 const auto output_index = cur_op->outputs()->operator[](0);
120 assert(output_index != -1);
121 output_internal = runtime_graph->getCircleTensorByIndex(output_index);
123 options = cur_op->builtin_options_as_UnidirectionalSequenceLSTMOptions();
126 void validateTensorTypes()
128 LUCI_INTERPRETER_CHECK(Tensor::element_type(input()) == Tensor::element_type(output_state()));
129 LUCI_INTERPRETER_CHECK(Tensor::element_type(output()) == Tensor::element_type(input()));
131 for (int32_t i = 1; i < 9; ++i)
133 LUCI_INTERPRETER_CHECK(internal_tensors[i] == nullptr or
134 Tensor::element_type(input_to_forget_weights()) ==
135 Tensor::element_type(internal_tensors[i]));
138 for (int32_t i = 12; i < 16; ++i)
140 LUCI_INTERPRETER_CHECK(internal_tensors[i] == nullptr or
141 Tensor::element_type(forget_gate_bias()) ==
142 Tensor::element_type(internal_tensors[i]));
146 const circle::Tensor *input() { return internal_tensors[0]; };
148 const circle::Tensor *input_to_input_weights() { return internal_tensors[1]; };
149 const circle::Tensor *input_to_forget_weights() { return internal_tensors[2]; };
150 const circle::Tensor *input_to_cell_weights() { return internal_tensors[3]; };
151 const circle::Tensor *input_to_output_weights() { return internal_tensors[4]; };
153 const circle::Tensor *recurrent_to_input_weights() { return internal_tensors[5]; };
154 const circle::Tensor *recurrent_to_forget_weights() { return internal_tensors[6]; };
155 const circle::Tensor *recurrent_to_cell_weights() { return internal_tensors[7]; };
156 const circle::Tensor *recurrent_to_output_weights() { return internal_tensors[8]; };
158 const circle::Tensor *cell_to_input_weights() { return internal_tensors[9]; };
159 const circle::Tensor *cell_to_forget_weights() { return internal_tensors[10]; };
160 const circle::Tensor *cell_to_output_weights() { return internal_tensors[11]; };
162 const circle::Tensor *input_gate_bias() { return internal_tensors[12]; };
163 const circle::Tensor *forget_gate_bias() { return internal_tensors[13]; };
164 const circle::Tensor *cell_gate_bias() { return internal_tensors[14]; };
165 const circle::Tensor *output_gate_bias() { return internal_tensors[15]; };
167 const circle::Tensor *projection_weights() { return internal_tensors[16]; };
168 const circle::Tensor *projection_bias() { return internal_tensors[17]; };
170 const circle::Tensor *output_state() { return internal_tensors[18]; };
171 const circle::Tensor *cell_state() { return internal_tensors[19]; };
173 const circle::Tensor *input_layer_norm_coefficients() { return internal_tensors[20]; };
174 const circle::Tensor *forget_layer_norm_coefficients() { return internal_tensors[21]; };
175 const circle::Tensor *cell_layer_norm_coefficients() { return internal_tensors[22]; };
176 const circle::Tensor *output_layer_norm_coefficients() { return internal_tensors[23]; };
177 const circle::Tensor *output() { return output_internal; };
179 const circle::UnidirectionalSequenceLSTMOptions *options;
181 const circle::Tensor *get_internal_tensor(int i) { return internal_tensors[i]; }
184 const circle::Tensor *output_internal;
185 const circle::Tensor *internal_tensors[24];
// Plain aggregate of numeric parameters for one fully-connected computation.
// GateParameters holds two of these (input_fc_params, recurrent_fc_params).
//
// This header does not establish the meaning of the individual fields; the
// names presumably follow the TFLite convention (zero-point offsets, a
// fixed-point output multiplier/shift pair, activation clamp bounds) -
// confirm against the kernel that fills them.
struct FullyConnectedParams
{
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // NOTE(review): named float_* but declared int32_t. Confirm whether `float`
  // was intended before storing real-valued bounds here; the type is kept
  // as-is so existing users of this struct are unaffected.
  int32_t float_activation_min;
  int32_t float_activation_max;
};
201 struct GateParameters
203 FullyConnectedParams input_fc_params;
204 FullyConnectedParams recurrent_fc_params;
// Plain aggregate of numeric parameters for one two-input arithmetic
// operation. InterGateParameters holds three of these, all named *_mul_params.
//
// This header does not establish the meaning of the individual fields; the
// names presumably follow the TFLite convention (per-input and output
// zero-point offsets, a fixed-point output multiplier/shift pair, activation
// clamp bounds) - confirm against the kernel that fills them.
struct ArithmeticParams
{
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  // NOTE(review): named float_* but declared int32_t. Confirm whether `float`
  // was intended before storing real-valued bounds here; the type is kept
  // as-is so existing users of this struct are unaffected.
  int32_t float_activation_min;
  int32_t float_activation_max;
};
220 struct InterGateParameters
222 ArithmeticParams forget_cell_mul_params;
223 ArithmeticParams input_mul_params;
224 ArithmeticParams output_mul_params;
// NOTE(review): the declaration that encloses the two fields below is not
// visible in this listing - confirm which struct they belong to.

// Clipping range for the cell state. Only a 16-bit cell is supported (could
// be generalized through templates).
int16_t quantized_cell_clip;
// 2^-cell_state_scale_power = cell state scale, required by integer tanh
int32_t cell_state_scale_power;
238 struct LSTMParameters
240 GateParameters forget_gate_parameters;
241 GateParameters input_gate_parameters;
242 GateParameters cell_gate_parameters;
243 GateParameters output_gate_parameters;
244 InterGateParameters inter_gate_parameters;
248 } // namespace luci_interpreter
250 #endif // LUCI_INTERPRETER_KERNELS_UNIDIRECTIONAL_SEQUENCE_LSTM_H