Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / LSTM.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
19 #define __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__
20
21 #include "cker/TensorUtils.h"
22 #include "cker/Types.h"
23
24 namespace nnfw
25 {
26 namespace cker
27 {
28
29 // LINT.IfChange
30 // Calculates a single LSTM gate.
31 //
32 // Implements the following formula: (* is matrix multiply)
33 //   gate = activate(W_input    * input + W_aux       * aux_input   +
34 //                   W_peephole * cell  + W_recurrent * prev_output + bias)
35 // with layer norm:
36 //   gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
37 //
38 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
39 //
40 // Parameters:
41 // Input vectors (to LSTM):    | Size:                | Optional?
42 //   input                     | n_input              |
43 //   aux_input                 | n_aux_input          | y (bidir LSTM)
44 // Input vectors (persistent states):
45 //   output_state              | n_output             |
46 //   cell_state                | n_cell               |
47 // 'Constant' inputs:
48 //   input_to_gate_weights     | n_cell * n_input     |
49 //   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
50 //   recurrent_to_gate_weights | n_cell * n_output    |
51 //   cell_to_gate_weights      | n_cell               | y (peephole)
52 //   gate_bias                 | n_cell               |
53 //   layer_norm_coefficients   | n_cell               | y (layer norm)
54 // Output vector:
55 //   gate                      | n_cell               |
56 // Scalar parameters:
57 //   n_batch                                    - batch size / number of vectors
58 //   n_input, n_aux_input, n_output, n_cell     - size of vectors.
59 //   activation                                 - activation to use.
60 //   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
61 //   use_layer_norm                             - if doing layer norm LSTM.
62 inline void CalculateLstmGateFloat(const float *input, const float *input_to_gate_weights,
63                                    const float *aux_input, const float *aux_input_to_gate_weights,
64                                    const float *output_state,
65                                    const float *recurrent_to_gate_weights, const float *cell_state,
66                                    const float *cell_to_gate_weights,
67                                    const float *layer_norm_coefficients, const float *gate_bias,
68                                    const int n_batch, const int n_input, const int n_aux_input,
69                                    const int n_output, const int n_cell,
70                                    const FusedActivationFunctionType activation, float *gate,
71                                    const bool is_input_all_zeros, const bool is_aux_input_all_zeros)
72 {
73   const bool use_peephole = (cell_to_gate_weights != nullptr);
74   const bool use_layer_norm = (layer_norm_coefficients != nullptr);
75
76   // Initialize scratch buffers with bias for regular lstm or initialize with
77   // zero for layer norm lstm.
78   if (use_layer_norm)
79   {
80     std::fill_n(gate, n_cell * n_batch, 0.0f);
81   }
82   else
83   {
84     VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
85   }
86   // For each batch and cell: compute input_weight * input.
87   // Skip if input is all zeros.
88   if (!is_input_all_zeros)
89   {
90     MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, n_cell, n_input, input, n_batch,
91                                         gate, /*result_stride=*/1);
92   }
93   // For each batch and cell: compute aux_input_weight * aux_input.
94   // Skip if auxiliary input is not available or all zeros.
95   if (!is_aux_input_all_zeros)
96   {
97     MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
98                                         n_batch, gate, /*result_stride=*/1);
99   }
100   // For each batch and cell: compute recurrent_weight * output_state.
101   MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, n_cell, n_output, output_state,
102                                       n_batch, gate, /*result_stride=*/1);
103   // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
104   if (use_peephole)
105   {
106     VectorBatchVectorCwiseProductAccumulate(cell_to_gate_weights, n_cell, cell_state, n_batch,
107                                             gate);
108   }
109   // Do layer normalization (if layer norm LSTM)
110   if (use_layer_norm)
111   {
112     MeanStddevNormalization(gate, gate, n_cell, n_batch);
113     VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, gate, n_batch, gate);
114     VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
115   }
116   // Apply activation
117   ApplyActivationToVector(gate, n_batch * n_cell, activation, gate);
118 }
119
120 // Updates the LSTM cell state, used by both float and hybrid LSTM versions.
121 //
122 // Implements the following formula:
123 //   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
124 //
125 // With CIFG LSTM, input gate is replaced by (1-forget_gate).
126 //
127 // Parameters:
128 //  - n_batch, n_cell: sizes of vectors
129 //  - cell_state: input/output vector, size n_batch*n_cell
130 //  - input_gate: input vector, size n_batch*n_cell.
131 //  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
132 //  - cell_gate: input vector, size n_batch*n_cell.
133 //  - use_cifg: use 1-forget_gate instead of input_gate.
134 //  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
135 void UpdateLstmCellFloat(int n_batch, int n_cell, float *cell_state, const float *input_gate,
136                          float *forget_gate, const float *cell_gate, bool use_cifg, float clip)
137 {
138   // Define variable for 4th argument to avoid warning
139   // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
140   const float *cwise_product_rhs = cell_state;
141   VectorVectorCwiseProduct(forget_gate, cwise_product_rhs, n_batch * n_cell, cell_state);
142
143   if (use_cifg)
144   {
145     // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
146     // scratch, as input_gate array is not allocated in this case. (Be careful
147     // not to write to the scratch before reading the forget gate data.)
148     float *scratch = forget_gate;
149     Sub1Vector(forget_gate, n_batch * n_cell, scratch);
150     VectorVectorCwiseProductAccumulate(cell_gate, scratch, n_batch * n_cell, cell_state);
151   }
152   else
153   {
154     VectorVectorCwiseProductAccumulate(cell_gate, input_gate, n_batch * n_cell, cell_state);
155   }
156   if (clip > 0.0f)
157   {
158     CwiseClipping(cell_state, n_batch * n_cell, clip);
159   }
160 }
161
162 // Calculates the output state tensor of an LSTM step.
163 //
164 // Implements the following formula:
165 //   output_no_projection = output_gate .* activate(cell_state)
166 //     (elementwise vector product)
167 // If no projection is used:
168 //   output = output_state = output_no_projection
169 // With projection:
170 //   output = output_state = clip(W*output_no_projection + bias)
171 //
172 // Output might not have a different 'stride' than n_batch, so we need to copy.
173 //
174 // Parameters:
175 //  - n_batch: batches: the number of distinct vectors in each array.
176 //  - n_cell, n_output: sizes of vectors.
177 //  - cell_state, output_gate: input vectors, size n_batch*n_cell.
178 //  - projection_weights, projection_weights_scale, projection_bias:
179 //      constant inputs, describing projection matrix and bias.
180 //  - proj_clip: if > 0, clip the output of the projection.
181 //  - output_state: output vector, size n_batch*n_output. Must be contigous.
182 //  - scratch: scratch area, size n_batch*n_cell.
183 void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, const float *cell_state,
184                               const float *output_gate, FusedActivationFunctionType activation,
185                               const float *projection_weights, const float *projection_bias,
186                               const float proj_clip, float *output_state, float *scratch)
187 {
188   ApplyActivationToVector(cell_state, n_batch * n_cell, activation, scratch);
189
190   // Define variable for 4th argument to avoid warning
191   // Compiler warning: passing argument 4 to restrict-qualified parameter aliases with argument 2
192   const float *cwise_product_rhs = scratch;
193   VectorVectorCwiseProduct(output_gate, cwise_product_rhs, n_batch * n_cell, scratch);
194
195   const bool use_projection = (projection_weights != nullptr);
196   const bool use_projection_bias = (projection_bias != nullptr);
197
198   if (use_projection)
199   {
200     if (use_projection_bias)
201     {
202       VectorBatchVectorAssign(projection_bias, n_output, n_batch, output_state);
203     }
204     else
205     {
206       std::fill_n(output_state, n_batch * n_output, 0.0f);
207     }
208     MatrixBatchVectorMultiplyAccumulate(projection_weights, n_output, n_cell, scratch, n_batch,
209                                         output_state, /*result_stride=*/1);
210     if (proj_clip > 0.0f)
211     {
212       CwiseClipping(output_state, n_batch * n_output, proj_clip);
213     }
214   }
215   else
216   {
217     std::copy_n(scratch, n_batch * n_output, output_state);
218   }
219 }
220
221 // Performs an LSTM batch inference step for input specified by input_ptr.
222 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
223 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
224 // parameters:
225 //  - params: various LSTM params including activation, clipping, etc.,
226 //  - n_batch: size of batch,
227 //  - n_cell: number of cells (or units),
228 //  - n_input: the input size,
229 //  - n_aux_input: the auxiliary input size.
230 //  - n_output: the output size.
231 //  - output_batch_leading_dim: the leading dimension of the output buffer.
232 //
233 // Input of size 'n_batch * n_input':
234 //   input_ptr
235 // Input of size 'n_batch * n_aux_input':
236 //   aux_input_ptr                     - optional (can be nullptr)
237 //
238 // LSTM weights:
239 // Input weights of size 'n_cell * n_input':
240 //   input_to_input_weights            - optional
241 //   input_to_forget_weights
242 //   input_to_cell_weights
243 //   input_to_output_weights
244 // Auxiliary input weights of size 'n_cell * n_aux_input':
245 //   aux_input_to_input_weights        - optional
246 //   aux_input_to_forget_weights       - optional
247 //   aux_input_to_cell_weights         - optional
248 //   aux_input_to_output_weights       - optional
249 // Recurrent weights of size 'n_cell * n_output':
250 //   recurrent_to_input_weights        - optional
251 //   recurrent_to_forget_weights
252 //   recurrent_to_cell_weights
253 //   recurrent_to_input_weights
254 // Peephole weights of size 'n_cell', representing diagonal matrices.
255 //   cell_to_input_weights             - optional
256 //   cell_to_cell_weights              - optional
257 //   cell_to_output_weights            - optional
258 // Projection weights of size 'n_output * n_cell'
259 //   projection_weights_ptr            - optional
260 // Gate biases of size 'n_cell':
261 //   input_gate_bias_ptr               - optional
262 //   forget_gate_bias_ptr
263 //   cell_gate_bias_ptr
264 //   output_gate_bias_ptr
265 //
266 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
267 //   input_layer_norm_coefficients_ptr  - optional
268 //   forget_layer_norm_coefficients_ptr - optional
269 //   cell_layer_norm_coefficients_ptr   - optional
270 //   output_layer_norm_coefficients_ptr - optional
271 //
272 // The pointers to the cell and output state and the output are updated.
273 //
274 // The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
275 // in batch_major order, and each step processes batch_size many inputs from
276 // input_ptr, and updates batch_size many cell and output states.
277 //
278 // The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
279 // output tensor, and in most cases will be equal to n_output. It is usually not
280 // when we want to store the LSTM output into a slice of the output tensor, e.g.
281 // for bidirectional LSTMs with merge_outputs. In this case, the batched
282 // operations cannot be used since they assume that the batched outputs are
283 // contiguous, and we manually loop over the batched outputs.
284 // LINT.IfChange
285 inline void LstmStepFloat(
286   const float *input_ptr, const float *input_to_input_weights_ptr,
287   const float *input_to_forget_weights_ptr, const float *input_to_cell_weights_ptr,
288   const float *input_to_output_weights_ptr, const float *aux_input_ptr,
289   const float *aux_input_to_input_weights_ptr, const float *aux_input_to_forget_weights_ptr,
290   const float *aux_input_to_cell_weights_ptr, const float *aux_input_to_output_weights_ptr,
291   const float *recurrent_to_input_weights_ptr, const float *recurrent_to_forget_weights_ptr,
292   const float *recurrent_to_cell_weights_ptr, const float *recurrent_to_output_weights_ptr,
293   const float *cell_to_input_weights_ptr, const float *cell_to_forget_weights_ptr,
294   const float *cell_to_output_weights_ptr, const float *input_layer_norm_coefficients_ptr,
295   const float *forget_layer_norm_coefficients_ptr, const float *cell_layer_norm_coefficients_ptr,
296   const float *output_layer_norm_coefficients_ptr, const float *input_gate_bias_ptr,
297   const float *forget_gate_bias_ptr, const float *cell_gate_bias_ptr,
298   const float *output_gate_bias_ptr, const float *projection_weights_ptr,
299   const float *projection_bias_ptr, const LSTMParams *params, int n_batch, int n_cell, int n_input,
300   int n_aux_input, int n_output, int output_batch_leading_dim, float *output_state_ptr,
301   float *cell_state_ptr, float *scratch0, float *scratch1, float *scratch2, float *scratch3,
302   float *output_ptr)
303 {
304   // Since we have already checked that weights are all there or none, we can
305   // check the existence of only one to the get the condition.
306   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
307
308   // Make named scratch buffers.
309   float *input_gate_scratch = scratch0;
310   float *forget_gate_scratch = scratch1;
311   float *cell_gate_scratch = scratch2;
312   float *output_gate_scratch = scratch3;
313
314   // Check if inputs are all zeros so we can skip some computations.
315   const bool is_input_all_zeros = IsZeroVector(input_ptr, n_batch * n_input);
316   const bool is_aux_input_all_zeros =
317     (aux_input_ptr == nullptr || IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
318   if (!use_cifg)
319   {
320     // Calculate the input gate. (If not CIFG.)
321     CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
322                            aux_input_to_input_weights_ptr, output_state_ptr,
323                            recurrent_to_input_weights_ptr, cell_state_ptr,
324                            cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
325                            input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
326                            /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
327                            input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
328   }
329   // Calculate the forget gate.
330   CalculateLstmGateFloat(input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
331                          aux_input_to_forget_weights_ptr, output_state_ptr,
332                          recurrent_to_forget_weights_ptr, cell_state_ptr,
333                          cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
334                          forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
335                          /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
336                          forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
337   // Calculate the cell update gate.
338   CalculateLstmGateFloat(
339     input_ptr, input_to_cell_weights_ptr, aux_input_ptr, aux_input_to_cell_weights_ptr,
340     output_state_ptr, recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
341     /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr, n_batch,
342     n_input, n_aux_input, n_output, n_cell, params->activation, cell_gate_scratch,
343     is_input_all_zeros, is_aux_input_all_zeros);
344   // Update the cell state.
345   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
346                       cell_gate_scratch, use_cifg, params->cell_clip);
347   // Calculate output gate.
348   CalculateLstmGateFloat(input_ptr, input_to_output_weights_ptr, aux_input_ptr,
349                          aux_input_to_output_weights_ptr, output_state_ptr,
350                          recurrent_to_output_weights_ptr, cell_state_ptr,
351                          cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
352                          output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
353                          /*activation=kTfLiteActSigmoid*/ FusedActivationFunctionType::kSigmoid,
354                          output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros);
355   // Update the output state.
356   CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
357                            params->activation, projection_weights_ptr, projection_bias_ptr,
358                            params->proj_clip, output_state_ptr, scratch2);
359   // Copy output state to the output. Note that the output's rows may not be
360   // contiguous (output_batch_leading_dim != n_output).
361   for (int b = 0; b < n_batch; b++)
362   {
363     std::copy_n(output_state_ptr + b * n_output, n_output,
364                 output_ptr + b * output_batch_leading_dim);
365   }
366 }
367
368 } // namespace cker
369 } // namespace nnfw
370
371 #endif // __NNFW_CKER_UNIDIRECTIONALSEQUENCELSTM_H__