2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once

#include "../C/lstm.h"
#include "primitive.hpp"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <vector>
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
31 /// @brief Performs forward Long Short-Term Memory (LSTM) layer.
32 /// @details The current implementation of LSTM is described the following equations.
33 /// it = f(Xt*(Wi^T) + Ht-1*Ri + Wbi)
34 /// ft = f(Xt*(Wf^T) + Ht-1*Rf + Wbf)
35 /// ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc)
36 /// Ct = ft (.) Ct-1 + it (.) ct
37 /// ot = f(Xt*(Wo^T) + Ht-1*Ro + Wbo)
39 /// Where f = Sigmoid, g = Tanh, and h = Tanh.
40 struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
42 CLDNN_DECLARE_PRIMITIVE(lstm)
44 /// @brief Constructs lstm layer.
45 /// @param id This primitive id.
46 /// @param input Vector of primitive id.
47 /// @param weights Primitive id containing weights data.
48 /// @param bias Primitive id containing bias data. Provide empty string if using lstm without bias.
49 /// @param initial_hidden Primitive id containing initial_hidden data. Provide empty string if using lstm without initial_hidden values.
50 /// @param initial_cell Primitive id containing initial_cell data. Provide empty string if using lstm without initial_cell values.
51 /// @param peepholes Primitive id containing peepholes data. Provide empty string if using lstm without peepholes.
52 /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
53 /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
54 /// @param activations Vector of activations. Specify [f, g, h]. Default are [sigmoid, tanh, tanh]
55 /// @param activation_params Vector of ativation params. Specify params for each [f, g, h] activation.
56 /// @brief Output selection. Default the entire hidden sequence is returned.
57 /// @param offset_order Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
59 const primitive_id& id,
60 const std::vector<primitive_id>& input,
61 const primitive_id& weights,
62 const primitive_id& recurrent,
63 const primitive_id& bias = "",
64 const primitive_id& initial_hidden = "",
65 const primitive_id& initial_cell = "",
66 const primitive_id& peepholes = "",
68 const bool input_forget = 0,
69 const std::vector<cldnn_activation_func>& activations = {},
70 const std::vector<cldnn_activation_additional_params> activation_params = {},
71 const cldnn_lstm_output output_selection = cldnn_lstm_output_sequence,
72 const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
73 const padding& output_padding = padding()
75 : primitive_base(id, input, output_padding)
77 , recurrent(recurrent)
79 , initial_hidden(initial_hidden)
80 , initial_cell(initial_cell)
81 , peepholes(peepholes)
83 , input_forget(input_forget)
84 , activations(activations)
85 , activation_params(activation_params)
86 , output_selection(output_selection)
87 , offset_order(offset_order)
91 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
94 , weights(dto->weights)
95 , recurrent(dto->recurrent)
97 , initial_hidden(dto->initial_hidden)
98 , initial_cell(dto->initial_cell)
99 , peepholes(dto->peepholes)
101 , input_forget(dto->input_forget)
102 , activations(dto->activations, std::end(dto->activations))
103 , activation_params(dto->activation_params, std::end(dto->activation_params))
104 , output_selection(dto->output_selection)
105 , offset_order(dto->offset_order)
109 /// @brief Primitive id containing weights data.
110 primitive_id weights;
111 /// @brief Primitive id containing recurrent data.
112 primitive_id recurrent;
113 /// @brief Primitive id containing bias data.
115 /// @brief Primitive id containing the initial value of the hidden data.
116 primitive_id initial_hidden;
117 /// @brief Primitive id containing the initial value of the cell state data.
118 primitive_id initial_cell;
119 /// @brief Primitive id containing peepholes data.
120 primitive_id peepholes;
121 /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
123 /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
125 /// @brief A list of 3 activation functions for the input, output, forget, cell, and hidden.
126 std::vector<cldnn_activation_func> activations;
127 /// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions.
128 std::vector<cldnn_activation_additional_params> activation_params;
129 /// @brief Output selection. Default the entire hidden sequence is returned.
130 cldnn_lstm_output output_selection;
131 /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
132 cldnn_lstm_offset_order offset_order;
135 // /// @brief Optional tensor specifying lengths of the sequences in a batch.
136 // /// If not specified - assumed all sequences in the batch to have length `seq_length`. It has shape `[batch_size]`.
137 // tensor sequence_lens;
138 // /// @brief The sequence output for the hidden.
139 // uint32_t output_sequence;
141 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
143 std::vector<std::reference_wrapper<const primitive_id>> ret;
144 ret.push_back(weights);
145 ret.push_back(recurrent);
150 if (!initial_hidden.empty())
152 ret.push_back(initial_hidden);
154 if (!initial_cell.empty())
156 ret.push_back(initial_cell);
161 void update_dto(dto& dto) const override
163 dto.weights = weights.c_str();
164 dto.recurrent = recurrent.c_str();
165 dto.bias = bias.c_str();
166 dto.peepholes = peepholes.c_str();
167 dto.initial_hidden = initial_hidden.c_str();
168 dto.initial_cell = initial_cell.c_str();
169 dto.output_selection = output_selection;
170 dto.offset_order = offset_order;
171 if (activations.size() == 3) {
172 std::copy_n(activations.begin(), 3, dto.activations);
174 if (activation_params.size() == 3) {
175 std::copy_n(activation_params.begin(), 3, dto.activation_params);
178 dto.input_forget = input_forget;
182 struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_gemm)>
184 CLDNN_DECLARE_PRIMITIVE(lstm_gemm)
186 /// @brief Constructs lstm layer.
187 /// @param id This primitive id.
188 /// @param input input primitive id.
189 /// @param input weights Primitive id containing weights data.
190 /// @param input recurrent Primitive id containing recurrent data. It is required even for no hidden values.
191 /// @param input bias Primitive id containing bias data. Provide empty string if using lstm without bias.
192 /// @param input hidden Primitive id containing hidden data. Provide empty string if using lstm without hidden values.
193 /// @param direction default = 0, bidirectional = 1.
195 const primitive_id& id,
196 const primitive_id& input,
197 const primitive_id& weights,
198 const primitive_id& recurrent,
199 const primitive_id& bias = "",
200 const primitive_id& hidden = "",
201 const uint32_t direction = 0,
202 const padding& output_padding = padding()
204 : primitive_base(id, {input}, output_padding)
206 , recurrent(recurrent)
209 , direction(direction)
213 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
214 lstm_gemm(const dto* dto)
215 : primitive_base(dto)
216 , weights(dto->weights)
217 , recurrent(dto->recurrent)
219 , hidden(dto->hidden)
220 , direction(dto->direction)
224 /// @brief Primitive id containing weights data.
225 primitive_id weights;
226 /// @brief Primitive id containing recurrent data.
227 primitive_id recurrent;
228 /// @brief Primitive id containing bias data.
230 /// @brief Primitive id containing the initial value of the hidden data.
232 /// @brief direction default = 0, bidirectional = 1.
236 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
238 std::vector<std::reference_wrapper<const primitive_id>> ret;
239 ret.push_back(weights);
240 ret.push_back(recurrent);
244 ret.push_back(hidden);
248 void update_dto(dto& dto) const override
250 dto.weights = weights.c_str();
251 dto.recurrent = recurrent.c_str();
252 dto.bias = bias.c_str();
253 dto.hidden = hidden.c_str();
254 dto.direction = direction;
258 struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)>
260 CLDNN_DECLARE_PRIMITIVE(lstm_elt)
261 using vec_activation = std::vector<cldnn_activation_func>;
262 using vec_activation_param = std::vector<cldnn_activation_additional_params>;
264 /// @brief Constructs lstm layer.
265 /// @param id This primitive id.
266 /// @param input input primitive id.
267 /// @param input cell Primitive id containing cell data. Provide empty string if using lstm without cell values.
268 /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
269 /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
270 /// @param offset_order. Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
271 /// @param direction default = 0, bidirectional = 1.
273 const primitive_id& id,
274 const primitive_id& input,
275 const primitive_id& cell = "",
276 const float clip = 0,
277 const bool input_forget = 0,
278 const std::vector<cldnn_activation_func> activations = {},
279 const std::vector<cldnn_activation_additional_params> activation_params = {},
280 const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
281 const uint32_t direction = 0,
282 const padding& output_padding = padding()
284 : primitive_base(id, {input}, output_padding)
287 , input_forget(input_forget)
288 , activations(activations)
289 , activation_params(activation_params)
290 , offset_order(offset_order)
291 , direction(direction)
295 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
296 lstm_elt(const dto* dto)
297 : primitive_base(dto)
300 , input_forget(dto->input_forget)
301 , activations(dto->activations, std::end(dto->activations))
302 , activation_params(dto->activation_params, std::end(dto->activation_params))
303 , offset_order(dto->offset_order)
304 , direction(dto->direction)
308 /// @brief Primitive id containing the initial value of the cell state data.
310 /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
312 /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
314 /// @brief A list of 3 activation functions for the input, output, forget, cell, and hidden.
315 std::vector<cldnn_activation_func> activations;
316 /// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions.
317 std::vector<cldnn_activation_additional_params> activation_params;
318 /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
319 cldnn_lstm_offset_order offset_order;
320 /// @brief direction default = 0, bidirectional = 1.
324 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
326 std::vector<std::reference_wrapper<const primitive_id>> ret;
332 void update_dto(dto& dto) const override
334 dto.cell = cell.c_str();
335 dto.offset_order = offset_order;
337 dto.input_forget = input_forget;
338 if (activations.size() == 3) {
339 std::copy_n(activations.begin(), 3, dto.activations);
341 if (activation_params.size() == 3) {
342 std::copy_n(activation_params.begin(), 3, dto.activation_params);
344 dto.direction = direction;