/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
#define __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__

#include "cker/TensorUtils.h"
#include "cker/Types.h"

#include <algorithm>
30 // Calculates a single LSTM gate.
32 // Implements the following formula: (* is matrix multiply)
33 // gate = activate(W_input * input + W_aux * aux_input +
34 // W_peephole * cell + W_recurrent * prev_output + bias)
36 // gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
38 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
41 // Input vectors (to LSTM): | Size: | Optional?
43 // aux_input | n_aux_input | y (bidir LSTM)
44 // Input vectors (persistent states):
45 // output_state | n_output |
46 // cell_state | n_cell |
48 // input_to_gate_weights | n_cell * n_input |
49 // aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
50 // recurrent_to_gate_weights | n_cell * n_output |
51 // cell_to_gate_weights | n_cell | y (peephole)
52 // gate_bias | n_cell |
53 // layer_norm_coefficients | n_cell | y (layer norm)
57 // n_batch - batch size / number of vectors
58 // n_input, n_aux_input, n_output, n_cell - size of vectors.
59 // activation - activation to use.
60 // is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
61 // use_layer_norm - if doing layer norm LSTM.
62 inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
63 const float *aux_input, const float *aux_input_to_gate_weights,
64 const float *output_state,
65 const float *recurrent_to_gate_weights, const float *cell_state,
66 const float *cell_to_gate_weights,
67 const float *layer_norm_coefficients, const float *gate_bias,
68 const int n_batch, const int n_input, const int n_aux_input,
69 const int n_output, const int n_cell,
70 const FusedActivationFunctionType activation, float *gate,
71 const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
73 const bool use_peephole = (cell_to_gate_weights != nullptr);
74 const bool use_layer_norm = (layer_norm_coefficients != nullptr);
76 // Initialize scratch buffers with bias for regular lstm or initialize with
77 // zero for layer norm lstm.
80 std::fill_n(gate, n_cell * n_batch, 0.0f);
84 VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
86 // For each batch and cell: compute input_weight * input.
87 // Skip if input is all zeros.
88 if (!is_input_all_zeros)
90 MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input, n_batch,
91 gate, /*result_stride=*/1);
93 // For each batch and cell: compute aux_input_weight * aux_input.
94 // Skip if auxiliary input is not available or all zeros.
95 if (!is_aux_input_all_zeros)
97 MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
98 n_batch, gate, /*result_stride=*/1);
100 // For each batch and cell: compute recurrent_weight * output_state.
101 MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output, output_state,
102 n_batch, gate, /*result_stride=*/1);
103 // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
106 VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state, n_batch,
109 // Do layer normalization (if layer norm LSTM)
112 MeanStddevNormalization(gate, gate, n_cell, n_batch);
113 VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch, gate);
114 VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
117 ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
120 // Updates the LSTM cell state, used by both float and hybrid LSTM versions.
122 // Implements the following formula:
123 // cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
125 // With CIFG LSTM, input gate is replaced by (1-forget_gate).
128 // - n_batch, n_cell: sizes of vectors
129 // - cell_state: input/output vector, size n_batch*n_cell
130 // - input_gate: input vector, size n_batch*n_cell.
131 // - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
132 // - cell_gate: input vector, size n_batch*n_cell.
133 // - use_cifg: use 1-forget_gate instead of input_gate.
134 // - clip: if > 0, clip the resulting cell state to [-clip, +clip].
135 void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
136 float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
138 // Define variable for 4th argument to avoid warning
139 // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
140 const float *cwise_product_rhs = cell_state;
141 VectorVectorCwiseProduct(forget_gate, cwise_product_rhs, n_batch * n_cell, cell_state);
145 // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
146 // scratch, as input_gate array is not allocated in this case. (Be careful
147 // not to write to the scratch before reading the forget gate data.)
148 float *scratch = forget_gate;
149 Sub1Vector(forget_gate, n_batch * n_cell, scratch);
150 VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell, cell_state);
154 VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell, cell_state);
158 CwiseClipping(cell_state, n_batch * n_cell, clip);
162 // Calculates the output state tensor of an LSTM step.
164 // Implements the following formula:
165 // output_no_projection = output_gate .* activate(cell_state)
166 // (elementwise vector product)
167 // If no projection is used:
168 // output = output_state = output_no_projection
170 // output = output_state = clip(W*output_no_projection + bias)
172 // Output might not have a different 'stride' than n_batch, so we need to copy.
175 // - n_batch: batches: the number of distinct vectors in each array.
176 // - n_cell, n_output: sizes of vectors.
177 // - cell_state, output_gate: input vectors, size n_batch*n_cell.
178 // - projection_weights, projection_weights_scale, projection_bias:
179 // constant inputs, describing projection matrix and bias.
180 // - proj_clip: if > 0, clip the output of the projection.
181 // - output_state: output vector, size n_batch*n_output. Must be contigous.
182 // - scratch: scratch area, size n_batch*n_cell.
183 void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
184 const float *output_gate, FusedActivationFunctionType activation,
185 const float *projection_weights, const float *projection_bias,
186 const float proj_clip, float *output_state, float *scratch)
188 ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
190 // Define variable for 4th argument to avoid warning
191 // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
192 const float *cwise_product_rhs = scratch;
193 VectorVectorCwiseProduct(output_gate, cwise_product_rhs, n_batch * n_cell, scratch);
195 const bool use_projection = (projection_weights != nullptr);
196 const bool use_projection_bias = (projection_bias != nullptr);
200 if (use_projection_bias)
202 VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
206 std::fill_n(output_state, n_batch * n_output, 0.0f);
208 MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch, n_batch,
209 output_state, /*result_stride=*/1);
210 if (proj_clip > 0.0f)
212 CwiseClipping(output_state, n_batch * n_output, proj_clip);
217 std::copy_n(scratch, n_batch * n_output, output_state);
221 // Performs an LSTM batch inference step for input specified by input_ptr.
222 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
223 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
225 // - params: various LSTM params including activation, clipping, etc.,
226 // - n_batch: size of batch,
227 // - n_cell: number of cells (or units),
228 // - n_input: the input size,
229 // - n_aux_input: the auxiliary input size.
230 // - n_output: the output size.
231 // - output_batch_leading_dim: the leading dimension of the output buffer.
233 // Input of size 'n_batch * n_input':
235 // Input of size 'n_batch * n_aux_input':
236 // aux_input_ptr - optional (can be nullptr)
239 // Input weights of size 'n_cell * n_input':
240 // input_to_input_weights - optional
241 // input_to_forget_weights
242 // input_to_cell_weights
243 // input_to_output_weights
244 // Auxiliary input weights of size 'n_cell * n_aux_input':
245 // aux_input_to_input_weights - optional
246 // aux_input_to_forget_weights - optional
247 // aux_input_to_cell_weights - optional
248 // aux_input_to_output_weights - optional
249 // Recurrent weights of size 'n_cell * n_output':
250 // recurrent_to_input_weights - optional
251 // recurrent_to_forget_weights
252 // recurrent_to_cell_weights
253 // recurrent_to_input_weights
254 // Peephole weights of size 'n_cell', representing diagonal matrices.
255 // cell_to_input_weights - optional
256 // cell_to_cell_weights - optional
257 // cell_to_output_weights - optional
258 // Projection weights of size 'n_output * n_cell'
259 // projection_weights_ptr - optional
260 // Gate biases of size 'n_cell':
261 // input_gate_bias_ptr - optional
262 // forget_gate_bias_ptr
263 // cell_gate_bias_ptr
264 // output_gate_bias_ptr
266 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
267 // input_layer_norm_coefficients_ptr - optional
268 // forget_layer_norm_coefficients_ptr - optional
269 // cell_layer_norm_coefficients_ptr - optional
270 // output_layer_norm_coefficients_ptr - optional
272 // The pointers to the cell and output state and the output are updated.
274 // The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
275 // in batch_major order, and each step processes batch_size many inputs from
276 // input_ptr, and updates batch_size many cell and output states.
278 // The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
279 // output tensor, and in most cases will be equal to n_output. It is usually not
280 // when we want to store the LSTM output into a slice of the output tensor, e.g.
281 // for bidirectional LSTMs with merge_outputs. In this case, the batched
282 // operations cannot be used since they assume that the batched outputs are
283 // contiguous, and we manually loop over the batched outputs.
285 inline void LstmStepFloat(
286 const float *input_ptr, const float *input_to_input_weights_ptr,
287 const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
288 const float *input_to_output_weights_ptr, const float *aux_input_ptr,
289 const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
290 const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
291 const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
292 const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
293 const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
294 const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
295 const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
296 const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
297 const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
298 const float *output_gate_bias_ptr, const float *projection_weights_ptr,
299 const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input,
300 int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
301 float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
304 // Since we have already checked that weights are all there or none, we can
305 // check the existence of only one to the get the condition.
306 const bool use_cifg = (input_to_input_weights_ptr == nullptr);
308 // Make named scratch buffers.
309 float *input_gate_scratch = scratch0;
310 float *forget_gate_scratch = scratch1;
311 float *cell_gate_scratch = scratch2;
312 float *output_gate_scratch = scratch3;
314 // Check if inputs are all zeros so we can skip some computations.
315 const bool is_input_all_zeros = IsZeroVector(input_ptr, n_batch * n_input);
316 const bool is_aux_input_all_zeros =
317 (aux_input_ptr == nullptr || IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
320 // Calculate the input gate. (If not CIFG.)
321 CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
322 aux_input_to_input_weights_ptr, output_state_ptr,
323 recurrent_to_input_weights_ptr, cell_state_ptr,
324 cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
325 input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
326 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
327 input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
329 // Calculate the forget gate.
330 CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
331 aux_input_to_forget_weights_ptr, output_state_ptr,
332 recurrent_to_forget_weights_ptr, cell_state_ptr,
333 cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
334 forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
335 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
336 forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
337 // Calculate the cell update gate.
338 CalculateLstmGateFloat(
339 input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
340 output_state_ptr, recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
341 /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
342 n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
343 is_input_all_zeros, is_aux_input_all_zeros);
344 // Update the cell state.
345 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
346 cell_gate_scratch, use_cifg, params->cell_clip);
347 // Calculate output gate.
348 CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
349 aux_input_to_output_weights_ptr, output_state_ptr,
350 recurrent_to_output_weights_ptr, cell_state_ptr,
351 cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
352 output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
353 /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
354 output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
355 // Update the output state.
356 CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
357 params->activation, projection_weights_ptr, projection_bias_ptr,
358 params->proj_clip, output_state_ptr, scratch2);
359 // Copy output state to the output. Note that the output's rows may not be
360 // contiguous (output_batch_leading_dim != n_output).
361 for (int b = 0; b < n_batch; b++)
363 std::copy_n(output_state_ptr + b * n_output, n_output,
364 output_ptr + b * output_batch_leading_dim);
#endif // __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__