/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "LSTMLayer.h"

#include "OperationUtils.h"

#include <cker/operation/LSTM.h>

#include <cassert>
#include <cstring>
36 T *getOptionalOutputBuffer(onert::backend::IPortableTensor *tensor, std::vector<uint8_t> *temp_vec,
39 if (tensor == nullptr)
41 temp_vec->reserve(total_size);
42 return reinterpret_cast<T *>(temp_vec->data());
46 assert(tensor->total_size() == total_size);
47 return getBuffer<T>(tensor);
51 inline void initializeStateBuffer(const onert::backend::IPortableTensor *tensor_in, void *buffer,
54 assert(tensor_in != nullptr);
55 assert(buffer != nullptr);
57 memcpy(buffer, tensor_in->buffer(), tensor_in->total_size());
59 memset(buffer, 0, tensor_in->total_size());
63 void LSTMLayer::LSTMFloat()
65 auto in_shape = _input->getShape();
66 assert(in_shape.rank() >= 2 && in_shape.rank() <= 3);
67 int max_time, n_batch;
68 if (in_shape.rank() == 3)
70 max_time = (_time_major) ? in_shape.dim(0) : in_shape.dim(1);
71 n_batch = (_time_major) ? in_shape.dim(1) : in_shape.dim(0);
76 n_batch = in_shape.dim(0);
78 const int n_input = in_shape.dim(_input->getShape().rank() - 1);
79 const int aux_input_size = 0;
81 // n_cell and n_output will be the same size when there is no projection.
82 const int n_cell = _input_to_output_weights->getShape().dim(0);
83 const int n_output = _recurrent_to_output_weights->getShape().dim(1);
85 // Since we have already checked that weights are all there or none, we can
86 // check the existence of only one to the get the condition.
87 const bool use_cifg = (_input_to_input_weights == nullptr);
90 float *output_state_buf = getOptionalOutputBuffer<float>(_output_state, &_output_state_vec,
91 _output_state_in->total_size());
92 float *cell_state_buf =
93 getOptionalOutputBuffer<float>(_cell_state, &_cell_state_vec, _cell_state_in->total_size());
95 initializeStateBuffer(_output_state_in, output_state_buf, _has_output_state_data);
96 initializeStateBuffer(_cell_state_in, cell_state_buf, _has_cell_state_data);
98 // Index the scratch buffers pointers to the global scratch buffer.
99 float *scratch_buffer_buf = getOptionalOutputBuffer<float>(
100 _scratch_buffer, &_scratch_vec, n_batch * n_cell * (use_cifg ? 3 : 4) * sizeof(float));
101 float *input_gate_scratch = nullptr;
102 float *cell_gate_scratch = nullptr;
103 float *forget_gate_scratch = nullptr;
104 float *output_gate_scratch = nullptr;
107 cell_gate_scratch = scratch_buffer_buf;
108 forget_gate_scratch = scratch_buffer_buf + n_cell * n_batch;
109 output_gate_scratch = scratch_buffer_buf + 2 * n_cell * n_batch;
113 input_gate_scratch = scratch_buffer_buf;
114 cell_gate_scratch = scratch_buffer_buf + n_cell * n_batch;
115 forget_gate_scratch = scratch_buffer_buf + 2 * n_cell * n_batch;
116 output_gate_scratch = scratch_buffer_buf + 3 * n_cell * n_batch;
119 auto optional_tensor_ptr = [](const IPortableTensor *tensor) {
120 // If tensor is not given or the tensor size is 0, consider it was not given
121 return (tensor && tensor->total_size() > 0) ? getBuffer<float>(tensor) : nullptr;
124 const float *input_to_input_weights_ptr = optional_tensor_ptr(_input_to_input_weights);
125 const float *recurrent_to_input_weights_ptr = optional_tensor_ptr(_recurrent_to_input_weights);
126 const float *cell_to_input_weights_ptr = optional_tensor_ptr(_cell_to_input_weights);
127 const float *cell_to_forget_weights_ptr = optional_tensor_ptr(_cell_to_forget_weights);
128 const float *cell_to_output_weights_ptr = optional_tensor_ptr(_cell_to_output_weights);
129 const float *input_gate_bias_ptr = optional_tensor_ptr(_input_gate_bias);
130 const float *projection_weights_ptr = optional_tensor_ptr(_projection_weights);
131 const float *projection_bias_ptr = optional_tensor_ptr(_projection_bias);
132 const float *input_layer_norm_coefficients_ptr =
133 optional_tensor_ptr(_input_layer_norm_coefficients);
134 const float *forget_layer_norm_coefficients_ptr =
135 optional_tensor_ptr(_forget_layer_norm_coefficients);
136 const float *cell_layer_norm_coefficients_ptr =
137 optional_tensor_ptr(_cell_layer_norm_coefficients);
138 const float *output_layer_norm_coefficients_ptr =
139 optional_tensor_ptr(_output_layer_norm_coefficients);
141 // Copy out the LSTM specific params so they can be passed in the function.
142 nnfw::cker::LSTMParams lstm_params;
143 lstm_params.activation = convertActivationType(_params.activation);
144 lstm_params.cell_clip = _params.cell_threshold;
145 lstm_params.proj_clip = _params.projection_threshold;
147 auto out_shape = _output->getShape();
148 const int output_batch_leading_dim = out_shape.dim(out_shape.rank() - 1);
151 // Loop through the sequence.
152 const int input_step = n_batch * n_input;
153 const int output_step = n_batch * output_batch_leading_dim;
154 for (int t = 0; t < max_time; t++)
156 // If this is the forward_sequence, step forward, otherwise step
158 const int t_rel = _forward_sequence ? t : max_time - t - 1;
159 const float *input_ptr = getBuffer<float>(_input) + t_rel * input_step;
160 const float *aux_input_ptr = nullptr;
163 aux_input_ptr = getBuffer<float>(_aux_input) + t_rel * input_step;
165 float *output_ptr = getBuffer<float>(_output) + t_rel * output_step + _output_offset;
168 input_ptr, input_to_input_weights_ptr, getBuffer<float>(_input_to_forget_weights),
169 getBuffer<float>(_input_to_cell_weights), getBuffer<float>(_input_to_output_weights),
171 /*aux_input_to_input_weights=*/nullptr,
172 /*aux_input_to_forget_weights=*/nullptr,
173 /*aux_input_to_cell_weights=*/nullptr,
174 /*aux_input_to_output_weights=*/nullptr, recurrent_to_input_weights_ptr,
175 getBuffer<float>(_recurrent_to_forget_weights),
176 getBuffer<float>(_recurrent_to_cell_weights),
177 getBuffer<float>(_recurrent_to_output_weights), cell_to_input_weights_ptr,
178 cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_layer_norm_coefficients_ptr,
179 forget_layer_norm_coefficients_ptr, cell_layer_norm_coefficients_ptr,
180 output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
181 getBuffer<float>(_forget_gate_bias), getBuffer<float>(_cell_gate_bias),
182 getBuffer<float>(_output_gate_bias), projection_weights_ptr, projection_bias_ptr,
183 &lstm_params, n_batch, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
184 output_state_buf, cell_state_buf, input_gate_scratch, forget_gate_scratch,
185 cell_gate_scratch, output_gate_scratch, output_ptr);
190 for (int b = 0; b < n_batch; b++)
192 const int input_step = n_input;
193 const int output_step = output_batch_leading_dim;
194 for (int t = 0; t < max_time; t++)
196 // If this is the forward_sequence, step forward, otherwise step
198 const int t_rel = _forward_sequence ? t : max_time - t - 1;
199 const int time_offset = b * max_time + t_rel;
200 const float *input_ptr = getBuffer<float>(_input) + time_offset * input_step;
201 const float *aux_input_ptr = nullptr;
204 aux_input_ptr = getBuffer<float>(_aux_input) + time_offset * input_step;
206 float *output_ptr = getBuffer<float>(_output) + time_offset * output_step + _output_offset;
208 // Offset the {output,cell}_state pointers to the right batch.
209 float *output_state_ptr = output_state_buf + b * output_batch_leading_dim;
210 float *cell_state_ptr = cell_state_buf + b * n_cell;
211 // Offset the scratch pointers to the right batch.
212 float *input_gate_scratch_ptr =
213 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
214 float *forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
215 float *cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
216 float *output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
219 input_ptr, input_to_input_weights_ptr, getBuffer<float>(_input_to_forget_weights),
220 getBuffer<float>(_input_to_cell_weights), getBuffer<float>(_input_to_output_weights),
222 /*aux_input_to_input_weights=*/nullptr,
223 /*aux_input_to_forget_weights=*/nullptr,
224 /*aux_input_to_cell_weights=*/nullptr,
225 /*aux_input_to_output_weights=*/nullptr, recurrent_to_input_weights_ptr,
226 getBuffer<float>(_recurrent_to_forget_weights),
227 getBuffer<float>(_recurrent_to_cell_weights),
228 getBuffer<float>(_recurrent_to_output_weights), cell_to_input_weights_ptr,
229 cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_layer_norm_coefficients_ptr,
230 forget_layer_norm_coefficients_ptr, cell_layer_norm_coefficients_ptr,
231 output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
232 getBuffer<float>(_forget_gate_bias), getBuffer<float>(_cell_gate_bias),
233 getBuffer<float>(_output_gate_bias), projection_weights_ptr, projection_bias_ptr,
234 &lstm_params, /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
235 output_batch_leading_dim, output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
236 forget_gate_scratch_ptr, cell_gate_scratch_ptr, output_gate_scratch_ptr, output_ptr);
242 void LSTMLayer::configure(
243 const IPortableTensor *input, const IPortableTensor *input_to_input_weights,
244 const IPortableTensor *input_to_forget_weights, const IPortableTensor *input_to_cell_weights,
245 const IPortableTensor *input_to_output_weights, const IPortableTensor *recurrent_to_input_weights,
246 const IPortableTensor *recurrent_to_forget_weights,
247 const IPortableTensor *recurrent_to_cell_weights,
248 const IPortableTensor *recurrent_to_output_weights, const IPortableTensor *cell_to_input_weights,
249 const IPortableTensor *cell_to_forget_weights, const IPortableTensor *cell_to_output_weights,
250 const IPortableTensor *input_layer_norm_weights, const IPortableTensor *forget_layer_norm_weights,
251 const IPortableTensor *cell_layer_norm_weights, const IPortableTensor *output_layer_norm_weights,
252 const IPortableTensor *aux_input, const IPortableTensor *aux_input_to_input_weights,
253 const IPortableTensor *aux_input_to_forget_weights,
254 const IPortableTensor *aux_input_to_cell_weights,
255 const IPortableTensor *aux_input_to_output_weights, const IPortableTensor *input_gate_bias,
256 const IPortableTensor *forget_gate_bias, const IPortableTensor *cell_gate_bias,
257 const IPortableTensor *output_gate_bias, const IPortableTensor *projection_weights,
258 const IPortableTensor *projection_bias, const IPortableTensor *output_state_in,
259 const IPortableTensor *cell_state_in, const ir::operation::LSTM::Param ¶ms,
260 bool forward_sequence, bool time_major, int output_offset, IPortableTensor *scratch_buffer,
261 IPortableTensor *output_state, IPortableTensor *cell_state, IPortableTensor *output,
262 bool has_output_state_data, bool has_cell_state_data)
265 _input_to_input_weights = input_to_input_weights;
266 _input_to_forget_weights = input_to_forget_weights;
267 _input_to_cell_weights = input_to_cell_weights;
268 _input_to_output_weights = input_to_output_weights;
269 _recurrent_to_input_weights = recurrent_to_input_weights;
270 _recurrent_to_forget_weights = recurrent_to_forget_weights;
271 _recurrent_to_cell_weights = recurrent_to_cell_weights;
272 _recurrent_to_output_weights = recurrent_to_output_weights;
273 _cell_to_input_weights = cell_to_input_weights;
274 _cell_to_forget_weights = cell_to_forget_weights;
275 _cell_to_output_weights = cell_to_output_weights;
276 _input_layer_norm_coefficients = input_layer_norm_weights;
277 _forget_layer_norm_coefficients = forget_layer_norm_weights;
278 _cell_layer_norm_coefficients = cell_layer_norm_weights;
279 _output_layer_norm_coefficients = output_layer_norm_weights;
280 _aux_input = aux_input, _aux_input_to_input_weights = aux_input_to_input_weights,
281 _aux_input_to_forget_weights = aux_input_to_forget_weights,
282 _aux_input_to_cell_weights = aux_input_to_cell_weights,
283 _aux_input_to_output_weights = aux_input_to_output_weights, _input_gate_bias = input_gate_bias;
284 _forget_gate_bias = forget_gate_bias;
285 _cell_gate_bias = cell_gate_bias;
286 _output_gate_bias = output_gate_bias;
287 _projection_weights = projection_weights;
288 _projection_bias = projection_bias;
289 _output_state_in = output_state_in;
290 _cell_state_in = cell_state_in;
292 _forward_sequence = forward_sequence;
293 _time_major = time_major;
294 _output_offset = output_offset;
295 _scratch_buffer = scratch_buffer;
296 _output_state = output_state;
297 _cell_state = cell_state;
299 _has_output_state_data = has_output_state_data;
300 _has_cell_state_data = has_cell_state_data;
303 void LSTMLayer::run()
306 if (_input->data_type() == OperandType::FLOAT32)
312 throw std::runtime_error{"LSTMLayer: unsupported data type"};
318 } // namespace backend