platform/core/ml/nnfw.git: onert-micro/luci-interpreter/src/loader/nodes/UnidirectionalSequenceLSTM.cpp
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Builders.h"

#include "kernels/UnidirectionalSequenceLSTM.h"

namespace luci_interpreter
{

std::unique_ptr<Kernel>
build_kernel_CircleUnidirectionalSequenceLSTM(std::vector<const Tensor *> &&inputs,
                                              std::vector<Tensor *> &&outputs,
                                              const uint32_t op_index, KernelBuilder &builder)
{
  assert(inputs.size() == 24);
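  // The Circle UnidirectionalSequenceLSTM operator carries 24 inputs: the sequence input, the
  // input-to-gate and recurrent-to-gate weights, the peephole (cell-to-gate) weights, the gate
  // biases, the optional projection tensors, the two variable state tensors, and the optional
  // layer-normalization coefficients, unpacked in that order below.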
  const Tensor *input = inputs.at(0);
  const Tensor *input_to_input_weights = inputs.at(1);
  const Tensor *input_to_forget_weights = inputs.at(2);
  const Tensor *input_to_cell_weights = inputs.at(3);
  const Tensor *input_to_output_weights = inputs.at(4);

  const Tensor *recurrent_to_input_weights = inputs.at(5);
  const Tensor *recurrent_to_forget_weights = inputs.at(6);
  const Tensor *recurrent_to_cell_weights = inputs.at(7);
  const Tensor *recurrent_to_output_weights = inputs.at(8);

  const Tensor *cell_to_input_weights = inputs.at(9);
  const Tensor *cell_to_forget_weights = inputs.at(10);
  const Tensor *cell_to_output_weights = inputs.at(11);

  const Tensor *input_gate_bias = inputs.at(12);
  const Tensor *forget_gate_bias = inputs.at(13);
  const Tensor *cell_gate_bias = inputs.at(14);
  const Tensor *output_gate_bias = inputs.at(15);

  const Tensor *projection_weights = inputs.at(16);
  const Tensor *projection_bias = inputs.at(17);

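  // Inputs 18 and 19 are the variable output_state and cell_state tensors. The kernel updates
  // them in place while stepping through the sequence, so drop const here and append them to
  // `outputs` below.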
  Tensor *output_state = const_cast<Tensor *>(inputs.at(18));
  Tensor *cell_state = const_cast<Tensor *>(inputs.at(19));

  const Tensor *input_layer_norm_coefficients = inputs.at(20);
  const Tensor *forget_layer_norm_coefficients = inputs.at(21);
  const Tensor *cell_layer_norm_coefficients = inputs.at(22);
  const Tensor *output_layer_norm_coefficients = inputs.at(23);
  Tensor *output = outputs.at(0);

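  // Unpack the operator from the circle flatbuffer to read its UnidirectionalSequenceLSTM
  // options.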
  circle::OperatorT oper_t;
  builder.get_circle_reader()->operators()[op_index]->UnPackTo(&oper_t);
  const auto *options = oper_t.builtin_options.AsUnidirectionalSequenceLSTMOptions();

  UnidirectionalSequenceLSTMParams params{};
  params.activation = luci_actfunc(options->fused_activation_function);
  params.cell_clip = options->cell_clip;
  params.proj_clip = options->proj_clip;
  params.time_major = options->time_major;
  params.asymmetric_quantize_inputs = options->asymmetric_quantize_inputs;

  // Scratch-pad tensors: temporary buffers the kernel needs at execution time are registered
  // with the runtime graph and appended to `outputs`; which buffers are created depends on
  // whether the integer (S8) or the float path is taken.
  const bool is_integer = input->element_type() == DataType::S8;
  bool use_layer_norm = (forget_layer_norm_coefficients != nullptr);

  if (is_integer)
  {
    if (not use_layer_norm)
    {
      params.intermediate_affine_quant =
        builder.get_runtime_graph()->getIntermediateAffineQuantizations();

      // The integer LSTM needs four 16-bit scratch buffers of size n_batch * n_cell
      // and one 8-bit scratch buffer of size n_batch * n_cell.
      auto tmp_1 = std::make_unique<Tensor>(DataType::S16, Shape({}), nullptr);
      tmp_1->set_data_buffer(nullptr);
      outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(tmp_1)));

      auto tmp_2 = std::make_unique<Tensor>(DataType::S16, Shape({}), nullptr);
      tmp_2->set_data_buffer(nullptr);
      outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(tmp_2)));

      auto tmp_3 = std::make_unique<Tensor>(DataType::S16, Shape({}), nullptr);
      tmp_3->set_data_buffer(nullptr);
      outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(tmp_3)));

      auto tmp_4 = std::make_unique<Tensor>(DataType::S16, Shape({}), nullptr);
      tmp_4->set_data_buffer(nullptr);
      outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(tmp_4)));

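      // The 8-bit scratch buffer takes its quantization parameters from the first of the
      // graph's intermediate affine quantizations.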
      auto tmp_5 = std::make_unique<Tensor>(
        DataType::S8, Shape({}),
        builder.get_runtime_graph()->getIntermediateAffineQuantizations()[0]);
      tmp_5->set_data_buffer(nullptr);
      outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(tmp_5)));
    }
    else
    {
      // TODO Support integer LSTM with layer normalization
      assert(false && "Not supported yet");
    }
  }
  else
  {
    // NOTE Provide more scratch pads here if hybrid or integer kernels need them.
    auto sp_output_state =
      std::make_unique<Tensor>(output_state->element_type(), Shape({}), nullptr);
    sp_output_state->set_data_buffer(nullptr);
    outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(sp_output_state)));

    auto sp_cell_state = std::make_unique<Tensor>(cell_state->element_type(), Shape({}), nullptr);
    sp_cell_state->set_data_buffer(nullptr);
    outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(sp_cell_state)));

    auto sp_3 = std::make_unique<Tensor>(input->element_type(), Shape({}), nullptr);
    sp_3->set_data_buffer(nullptr);
    outputs.push_back(builder.get_runtime_graph()->addTensor(std::move(sp_3)));
  }

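  // Append the variable state tensors last so the kernel can read and update them during
  // execution.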
  outputs.push_back(output_state);
  outputs.push_back(cell_state);

  return std::make_unique<kernels::UnidirectionalSequenceLSTM>(
    input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights,
    input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights,
    recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights,
    cell_to_forget_weights, cell_to_output_weights, input_gate_bias, forget_gate_bias,
    cell_gate_bias, output_gate_bias, projection_weights, projection_bias,
    input_layer_norm_coefficients, forget_layer_norm_coefficients, cell_layer_norm_coefficients,
    output_layer_norm_coefficients, std::move(outputs), params);
}

} // namespace luci_interpreter