Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / LSTMLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "LSTMLayer.h"
19
20 #include "OperationUtils.h"
21
22 #include <cker/operation/LSTM.h>
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace cpu
29 {
30 namespace ops
31 {
32
33 namespace
34 {
35 template <typename T>
36 T *getOptionalOutputBuffer(onert::backend::IPortableTensor *tensor, std::vector<uint8_t> *temp_vec,
37                            size_t total_size)
38 {
39   if (tensor == nullptr)
40   {
41     temp_vec->reserve(total_size);
42     return reinterpret_cast<T *>(temp_vec->data());
43   }
44   else
45   {
46     assert(tensor->total_size() == total_size);
47     return getBuffer<T>(tensor);
48   }
49 }
50
51 inline void initializeStateBuffer(const onert::backend::IPortableTensor *tensor_in, void *buffer,
52                                   bool needs_memcpy)
53 {
54   assert(tensor_in != nullptr);
55   assert(buffer != nullptr);
56   if (needs_memcpy)
57     memcpy(buffer, tensor_in->buffer(), tensor_in->total_size());
58   else
59     memset(buffer, 0, tensor_in->total_size());
60 }
61 } // namespace
62
// Runs the float32 (unidirectional) LSTM over the whole sequence by calling
// nnfw::cker::LstmStepFloat once per time step. Handles both time-major
// ([max_time, batch, input]) and batch-major ([batch, max_time, input])
// layouts, rank-2 inputs (treated as a single time step), optional CIFG
// (no input gate), layer normalization, projection, and an optional
// auxiliary input.
void LSTMLayer::LSTMFloat()
{
  auto in_shape = _input->getShape();
  assert(in_shape.rank() >= 2 && in_shape.rank() <= 3);
  int max_time, n_batch;
  if (in_shape.rank() == 3)
  {
    // Rank-3 input: which axis is time depends on the layout flag.
    max_time = (_time_major) ? in_shape.dim(0) : in_shape.dim(1);
    n_batch = (_time_major) ? in_shape.dim(1) : in_shape.dim(0);
  }
  else
  {
    // Rank-2 input is a single step of shape [batch, input].
    max_time = 1;
    n_batch = in_shape.dim(0);
  }
  const int n_input = in_shape.dim(_input->getShape().rank() - 1);
  const int aux_input_size = 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = _input_to_output_weights->getShape().dim(0);
  const int n_output = _recurrent_to_output_weights->getShape().dim(1);

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to the get the condition.
  const bool use_cifg = (_input_to_input_weights == nullptr);

  // Optional outputs: when the graph does not provide output-state/cell-state
  // tensors, back them with member-owned temporary vectors instead.
  float *output_state_buf = getOptionalOutputBuffer<float>(_output_state, &_output_state_vec,
                                                           _output_state_in->total_size());
  float *cell_state_buf =
    getOptionalOutputBuffer<float>(_cell_state, &_cell_state_vec, _cell_state_in->total_size());

  // Seed the recurrent state from the *_in tensors (or zeros when no data).
  initializeStateBuffer(_output_state_in, output_state_buf, _has_output_state_data);
  initializeStateBuffer(_cell_state_in, cell_state_buf, _has_cell_state_data);

  // Index the scratch buffers pointers to the global scratch buffer.
  // CIFG needs only 3 gate scratch areas (no input gate), full LSTM needs 4.
  float *scratch_buffer_buf = getOptionalOutputBuffer<float>(
    _scratch_buffer, &_scratch_vec, n_batch * n_cell * (use_cifg ? 3 : 4) * sizeof(float));
  float *input_gate_scratch = nullptr;
  float *cell_gate_scratch = nullptr;
  float *forget_gate_scratch = nullptr;
  float *output_gate_scratch = nullptr;
  if (use_cifg)
  {
    cell_gate_scratch = scratch_buffer_buf;
    forget_gate_scratch = scratch_buffer_buf + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_buf + 2 * n_cell * n_batch;
  }
  else
  {
    input_gate_scratch = scratch_buffer_buf;
    cell_gate_scratch = scratch_buffer_buf + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_buf + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_buf + 3 * n_cell * n_batch;
  }

  auto optional_tensor_ptr = [](const IPortableTensor *tensor) {
    // If tensor is not given or the tensor size is 0, consider it was not given
    return (tensor && tensor->total_size() > 0) ? getBuffer<float>(tensor) : nullptr;
  };
  // Optional inputs
  const float *input_to_input_weights_ptr = optional_tensor_ptr(_input_to_input_weights);
  const float *recurrent_to_input_weights_ptr = optional_tensor_ptr(_recurrent_to_input_weights);
  const float *cell_to_input_weights_ptr = optional_tensor_ptr(_cell_to_input_weights);
  const float *cell_to_forget_weights_ptr = optional_tensor_ptr(_cell_to_forget_weights);
  const float *cell_to_output_weights_ptr = optional_tensor_ptr(_cell_to_output_weights);
  const float *input_gate_bias_ptr = optional_tensor_ptr(_input_gate_bias);
  const float *projection_weights_ptr = optional_tensor_ptr(_projection_weights);
  const float *projection_bias_ptr = optional_tensor_ptr(_projection_bias);
  const float *input_layer_norm_coefficients_ptr =
    optional_tensor_ptr(_input_layer_norm_coefficients);
  const float *forget_layer_norm_coefficients_ptr =
    optional_tensor_ptr(_forget_layer_norm_coefficients);
  const float *cell_layer_norm_coefficients_ptr =
    optional_tensor_ptr(_cell_layer_norm_coefficients);
  const float *output_layer_norm_coefficients_ptr =
    optional_tensor_ptr(_output_layer_norm_coefficients);

  // Copy out the LSTM specific params so they can be passed in the function.
  nnfw::cker::LSTMParams lstm_params;
  lstm_params.activation = convertActivationType(_params.activation);
  lstm_params.cell_clip = _params.cell_threshold;
  lstm_params.proj_clip = _params.projection_threshold;

  auto out_shape = _output->getShape();
  const int output_batch_leading_dim = out_shape.dim(out_shape.rank() - 1);
  if (_time_major)
  {
    // Time-major: one LstmStepFloat call per step processes all batches at once.
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++)
    {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = _forward_sequence ? t : max_time - t - 1;
      const float *input_ptr = getBuffer<float>(_input) + t_rel * input_step;
      const float *aux_input_ptr = nullptr;
      if (_aux_input)
      {
        aux_input_ptr = getBuffer<float>(_aux_input) + t_rel * input_step;
      }
      // _output_offset shifts the write position (used by bidirectional LSTM
      // to interleave forward/backward results in one output tensor).
      float *output_ptr = getBuffer<float>(_output) + t_rel * output_step + _output_offset;

      LstmStepFloat(
        input_ptr, input_to_input_weights_ptr, getBuffer<float>(_input_to_forget_weights),
        getBuffer<float>(_input_to_cell_weights), getBuffer<float>(_input_to_output_weights),
        aux_input_ptr,
        /*aux_input_to_input_weights=*/nullptr,
        /*aux_input_to_forget_weights=*/nullptr,
        /*aux_input_to_cell_weights=*/nullptr,
        /*aux_input_to_output_weights=*/nullptr, recurrent_to_input_weights_ptr,
        getBuffer<float>(_recurrent_to_forget_weights),
        getBuffer<float>(_recurrent_to_cell_weights),
        getBuffer<float>(_recurrent_to_output_weights), cell_to_input_weights_ptr,
        cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_layer_norm_coefficients_ptr,
        forget_layer_norm_coefficients_ptr, cell_layer_norm_coefficients_ptr,
        output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
        getBuffer<float>(_forget_gate_bias), getBuffer<float>(_cell_gate_bias),
        getBuffer<float>(_output_gate_bias), projection_weights_ptr, projection_bias_ptr,
        &lstm_params, n_batch, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
        output_state_buf, cell_state_buf, input_gate_scratch, forget_gate_scratch,
        cell_gate_scratch, output_gate_scratch, output_ptr);
    }
  }
  else
  {
    // Batch-major: process each batch independently, one step at a time,
    // passing n_batch==1 to the kernel with per-batch state/scratch offsets.
    for (int b = 0; b < n_batch; b++)
    {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++)
      {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = _forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float *input_ptr = getBuffer<float>(_input) + time_offset * input_step;
        const float *aux_input_ptr = nullptr;
        if (_aux_input)
        {
          aux_input_ptr = getBuffer<float>(_aux_input) + time_offset * input_step;
        }
        float *output_ptr = getBuffer<float>(_output) + time_offset * output_step + _output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float *output_state_ptr = output_state_buf + b * output_batch_leading_dim;
        float *cell_state_ptr = cell_state_buf + b * n_cell;
        // Offset the scratch pointers to the right batch.
        // input_gate_scratch is null under CIFG, so guard the offset.
        float *input_gate_scratch_ptr =
          input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepFloat(
          input_ptr, input_to_input_weights_ptr, getBuffer<float>(_input_to_forget_weights),
          getBuffer<float>(_input_to_cell_weights), getBuffer<float>(_input_to_output_weights),
          aux_input_ptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, recurrent_to_input_weights_ptr,
          getBuffer<float>(_recurrent_to_forget_weights),
          getBuffer<float>(_recurrent_to_cell_weights),
          getBuffer<float>(_recurrent_to_output_weights), cell_to_input_weights_ptr,
          cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_layer_norm_coefficients_ptr,
          forget_layer_norm_coefficients_ptr, cell_layer_norm_coefficients_ptr,
          output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
          getBuffer<float>(_forget_gate_bias), getBuffer<float>(_cell_gate_bias),
          getBuffer<float>(_output_gate_bias), projection_weights_ptr, projection_bias_ptr,
          &lstm_params, /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
          output_batch_leading_dim, output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
          forget_gate_scratch_ptr, cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
      }
    }
  }
}
241
242 void LSTMLayer::configure(
243   const IPortableTensor *input, const IPortableTensor *input_to_input_weights,
244   const IPortableTensor *input_to_forget_weights, const IPortableTensor *input_to_cell_weights,
245   const IPortableTensor *input_to_output_weights, const IPortableTensor *recurrent_to_input_weights,
246   const IPortableTensor *recurrent_to_forget_weights,
247   const IPortableTensor *recurrent_to_cell_weights,
248   const IPortableTensor *recurrent_to_output_weights, const IPortableTensor *cell_to_input_weights,
249   const IPortableTensor *cell_to_forget_weights, const IPortableTensor *cell_to_output_weights,
250   const IPortableTensor *input_layer_norm_weights, const IPortableTensor *forget_layer_norm_weights,
251   const IPortableTensor *cell_layer_norm_weights, const IPortableTensor *output_layer_norm_weights,
252   const IPortableTensor *aux_input, const IPortableTensor *aux_input_to_input_weights,
253   const IPortableTensor *aux_input_to_forget_weights,
254   const IPortableTensor *aux_input_to_cell_weights,
255   const IPortableTensor *aux_input_to_output_weights, const IPortableTensor *input_gate_bias,
256   const IPortableTensor *forget_gate_bias, const IPortableTensor *cell_gate_bias,
257   const IPortableTensor *output_gate_bias, const IPortableTensor *projection_weights,
258   const IPortableTensor *projection_bias, const IPortableTensor *output_state_in,
259   const IPortableTensor *cell_state_in, const ir::operation::LSTM::Param &params,
260   bool forward_sequence, bool time_major, int output_offset, IPortableTensor *scratch_buffer,
261   IPortableTensor *output_state, IPortableTensor *cell_state, IPortableTensor *output,
262   bool has_output_state_data, bool has_cell_state_data)
263 {
264   _input = input;
265   _input_to_input_weights = input_to_input_weights;
266   _input_to_forget_weights = input_to_forget_weights;
267   _input_to_cell_weights = input_to_cell_weights;
268   _input_to_output_weights = input_to_output_weights;
269   _recurrent_to_input_weights = recurrent_to_input_weights;
270   _recurrent_to_forget_weights = recurrent_to_forget_weights;
271   _recurrent_to_cell_weights = recurrent_to_cell_weights;
272   _recurrent_to_output_weights = recurrent_to_output_weights;
273   _cell_to_input_weights = cell_to_input_weights;
274   _cell_to_forget_weights = cell_to_forget_weights;
275   _cell_to_output_weights = cell_to_output_weights;
276   _input_layer_norm_coefficients = input_layer_norm_weights;
277   _forget_layer_norm_coefficients = forget_layer_norm_weights;
278   _cell_layer_norm_coefficients = cell_layer_norm_weights;
279   _output_layer_norm_coefficients = output_layer_norm_weights;
280   _aux_input = aux_input, _aux_input_to_input_weights = aux_input_to_input_weights,
281   _aux_input_to_forget_weights = aux_input_to_forget_weights,
282   _aux_input_to_cell_weights = aux_input_to_cell_weights,
283   _aux_input_to_output_weights = aux_input_to_output_weights, _input_gate_bias = input_gate_bias;
284   _forget_gate_bias = forget_gate_bias;
285   _cell_gate_bias = cell_gate_bias;
286   _output_gate_bias = output_gate_bias;
287   _projection_weights = projection_weights;
288   _projection_bias = projection_bias;
289   _output_state_in = output_state_in;
290   _cell_state_in = cell_state_in;
291   _params = params;
292   _forward_sequence = forward_sequence;
293   _time_major = time_major;
294   _output_offset = output_offset;
295   _scratch_buffer = scratch_buffer;
296   _output_state = output_state;
297   _cell_state = cell_state;
298   _output = output;
299   _has_output_state_data = has_output_state_data;
300   _has_cell_state_data = has_cell_state_data;
301 }
302
303 void LSTMLayer::run()
304 {
305
306   if (_input->data_type() == OperandType::FLOAT32)
307   {
308     LSTMFloat();
309   }
310   else
311   {
312     throw std::runtime_error{"LSTMLayer: unsupported data type"};
313   }
314 }
315
316 } // namespace ops
317 } // namespace cpu
318 } // namespace backend
319 } // namespace onert