Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / lstm.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/lstm.h"
20 #include "primitive.hpp"
21
22 namespace cldnn
23 {
24 /// @addtogroup cpp_api C++ API
25 /// @{
26 /// @addtogroup cpp_topology Network Topology
27 /// @{
28 /// @addtogroup cpp_primitives Primitives
29 /// @{
30
31 /// @brief Performs forward Long Short-Term Memory (LSTM) layer.
32 /// @details The current implementation of LSTM is described the following equations.
33 ///   it = f(Xt*(Wi^T) + Ht-1*Ri + Wbi)
34 ///   ft = f(Xt*(Wf^T) + Ht-1*Rf + Wbf)
35 ///   ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc)
36 ///   Ct = ft (.) Ct-1 + it (.) ct
37 ///   ot = f(Xt*(Wo^T) + Ht-1*Ro + Wbo)
38 ///   Ht = ot (.) h(Ct)
39 /// Where f = Sigmoid, g = Tanh, and h = Tanh.
40 struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
41 {
42     CLDNN_DECLARE_PRIMITIVE(lstm)
43
44     /// @brief Constructs lstm layer.
45     /// @param id This primitive id.
46     /// @param input Vector of primitive id.
47     /// @param weights Primitive id containing weights data.
48     /// @param bias Primitive id containing bias data. Provide empty string if using lstm without bias.
49     /// @param initial_hidden Primitive id containing initial_hidden data. Provide empty string if using lstm without initial_hidden values.
50     /// @param initial_cell Primitive id containing initial_cell data. Provide empty string if using lstm without initial_cell values.
51     /// @param peepholes Primitive id containing peepholes data. Provide empty string if using lstm without peepholes.
52     /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
53     /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
54     /// @param activations Vector of activations. Specify [f, g, h]. Default are [sigmoid, tanh, tanh]
55     /// @param activation_params Vector of ativation params. Specify params for each [f, g, h] activation.
56     /// @brief Output selection. Default the entire hidden sequence is returned.
57     /// @param offset_order Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
58     lstm(
59         const primitive_id& id,
60         const std::vector<primitive_id>& input,
61         const primitive_id& weights,
62         const primitive_id& recurrent,
63         const primitive_id& bias = "",
64         const primitive_id& initial_hidden = "",
65         const primitive_id& initial_cell = "",
66         const primitive_id& peepholes = "",
67         const float clip = 0,
68         const bool input_forget = 0,
69         const std::vector<cldnn_activation_func>& activations = {},
70         const std::vector<cldnn_activation_additional_params> activation_params = {},
71         const cldnn_lstm_output output_selection = cldnn_lstm_output_sequence,
72         const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
73         const padding& output_padding = padding()
74         )
75         : primitive_base(id, input, output_padding)
76         , weights(weights)
77         , recurrent(recurrent)
78         , bias(bias)
79         , initial_hidden(initial_hidden)
80         , initial_cell(initial_cell)
81         , peepholes(peepholes)
82         , clip(clip)
83         , input_forget(input_forget)
84         , activations(activations)
85         , activation_params(activation_params)
86         , output_selection(output_selection)
87         , offset_order(offset_order)
88     {
89     }
90
91     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
92     lstm(const dto* dto)
93         : primitive_base(dto)
94         , weights(dto->weights)
95         , recurrent(dto->recurrent)
96         , bias(dto->bias)
97         , initial_hidden(dto->initial_hidden)
98         , initial_cell(dto->initial_cell)
99         , peepholes(dto->peepholes)
100         , clip(dto->clip)
101         , input_forget(dto->input_forget)
102                 , activations(dto->activations, std::end(dto->activations))
103                 , activation_params(dto->activation_params, std::end(dto->activation_params))
104         , output_selection(dto->output_selection)
105         , offset_order(dto->offset_order)
106     {
107     }
108
109     /// @brief Primitive id containing weights data.
110     primitive_id weights;
111     /// @brief Primitive id containing recurrent data.
112     primitive_id recurrent;
113     /// @brief Primitive id containing bias data.
114     primitive_id bias;
115     /// @brief Primitive id containing the initial value of the hidden data.
116     primitive_id initial_hidden;
117     /// @brief Primitive id containing the initial value of the cell state data.
118     primitive_id initial_cell;
119     /// @brief Primitive id containing peepholes data.
120     primitive_id peepholes;
121     /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
122     float clip;
123     /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
124     bool input_forget;
125     /// @brief A list of 3 activation functions for the input, output, forget, cell, and hidden.
126     std::vector<cldnn_activation_func> activations;
127     /// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions.
128     std::vector<cldnn_activation_additional_params> activation_params;
129     /// @brief Output selection. Default the entire hidden sequence is returned.
130     cldnn_lstm_output output_selection;
131     /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
132     cldnn_lstm_offset_order offset_order;
133
134     // NOT SUPPORTED YET
135     // /// @brief Optional tensor specifying lengths of the sequences in a batch.
136     // /// If not specified - assumed all sequences in the batch to have length `seq_length`. It has shape `[batch_size]`.
137     // tensor sequence_lens;
138     // /// @brief The sequence output for the hidden.
139     // uint32_t output_sequence;
140 protected:
141     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
142     {
143         std::vector<std::reference_wrapper<const primitive_id>> ret;
144         ret.push_back(weights);
145         ret.push_back(recurrent);
146         if (!bias.empty())
147         {
148             ret.push_back(bias);
149         }
150         if (!initial_hidden.empty())
151         {
152             ret.push_back(initial_hidden);
153         }
154         if (!initial_cell.empty())
155         {
156             ret.push_back(initial_cell);
157         }
158         return ret;
159     }
160
161     void update_dto(dto& dto) const override
162     {
163         dto.weights = weights.c_str();
164         dto.recurrent = recurrent.c_str();
165         dto.bias = bias.c_str();
166         dto.peepholes = peepholes.c_str();
167         dto.initial_hidden = initial_hidden.c_str();
168         dto.initial_cell = initial_cell.c_str();
169         dto.output_selection = output_selection;
170         dto.offset_order = offset_order;
171         if (activations.size() == 3) {
172             std::copy_n(activations.begin(), 3, dto.activations);
173         }
174         if (activation_params.size() == 3) {
175             std::copy_n(activation_params.begin(), 3, dto.activation_params);
176         }
177         dto.clip = clip;
178         dto.input_forget = input_forget;
179     }
180 };
181
182 struct lstm_gemm : public primitive_base<lstm_gemm, CLDNN_PRIMITIVE_DESC(lstm_gemm)>
183 {
184     CLDNN_DECLARE_PRIMITIVE(lstm_gemm)
185
186     /// @brief Constructs lstm layer.
187     /// @param id This primitive id.
188     /// @param input input primitive id.
189     /// @param input weights Primitive id containing weights data.
190     /// @param input recurrent Primitive id containing recurrent data. It is required even for no hidden values.
191     /// @param input bias Primitive id containing bias data. Provide empty string if using lstm without bias.
192     /// @param input hidden Primitive id containing hidden data. Provide empty string if using lstm without hidden values.
193     /// @param direction default = 0, bidirectional = 1.
194     lstm_gemm(
195         const primitive_id& id,
196         const primitive_id& input,
197         const primitive_id& weights,
198         const primitive_id& recurrent,
199         const primitive_id& bias = "",
200         const primitive_id& hidden = "",
201         const uint32_t direction = 0,
202         const padding& output_padding = padding()
203         )
204         : primitive_base(id, {input}, output_padding)
205         , weights(weights)
206         , recurrent(recurrent)
207         , bias(bias)
208         , hidden(hidden)
209         , direction(direction)
210     {
211     }
212
213     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
214     lstm_gemm(const dto* dto)
215         : primitive_base(dto)
216         , weights(dto->weights)
217         , recurrent(dto->recurrent)
218         , bias(dto->bias)
219         , hidden(dto->hidden)
220         , direction(dto->direction)
221     {
222     }
223
224     /// @brief Primitive id containing weights data.
225     primitive_id weights;
226     /// @brief Primitive id containing recurrent data.
227     primitive_id recurrent;
228     /// @brief Primitive id containing bias data.
229     primitive_id bias;
230     /// @brief Primitive id containing the initial value of the hidden data.
231     primitive_id hidden;
232     /// @brief direction default = 0, bidirectional = 1.
233     uint32_t direction;
234
235 protected:
236     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
237     {
238         std::vector<std::reference_wrapper<const primitive_id>> ret;
239         ret.push_back(weights);
240         ret.push_back(recurrent);
241         if (!bias.empty())
242             ret.push_back(bias);
243         if (!hidden.empty())
244             ret.push_back(hidden);
245         return ret;
246     }
247
248     void update_dto(dto& dto) const override
249     {
250         dto.weights = weights.c_str();
251         dto.recurrent = recurrent.c_str();
252         dto.bias = bias.c_str();
253         dto.hidden = hidden.c_str();
254         dto.direction = direction;
255     }
256 };
257
258 struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)>
259 {
260     CLDNN_DECLARE_PRIMITIVE(lstm_elt)
261     using vec_activation = std::vector<cldnn_activation_func>;
262     using vec_activation_param = std::vector<cldnn_activation_additional_params>;
263
264     /// @brief Constructs lstm layer.
265     /// @param id This primitive id.
266     /// @param input input primitive id.
267     /// @param input cell Primitive id containing cell data. Provide empty string if using lstm without cell values.
268     /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
269     /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
270     /// @param offset_order. Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
271     /// @param direction default = 0, bidirectional = 1.
272     lstm_elt(
273         const primitive_id& id,
274         const primitive_id& input,
275         const primitive_id& cell = "",
276         const float clip = 0,
277         const bool input_forget = 0,
278         const std::vector<cldnn_activation_func> activations = {},
279         const std::vector<cldnn_activation_additional_params> activation_params = {},
280         const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
281         const uint32_t direction = 0,
282         const padding& output_padding = padding()
283         )
284         : primitive_base(id, {input}, output_padding)
285         , cell(cell)
286         , clip(clip)
287         , input_forget(input_forget)
288         , activations(activations)
289         , activation_params(activation_params)
290         , offset_order(offset_order)
291         , direction(direction)
292     {
293     }
294
295     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{lstm}
296     lstm_elt(const dto* dto)
297         : primitive_base(dto)
298         , cell(dto->cell)
299         , clip(dto->clip)
300         , input_forget(dto->input_forget)
301                 , activations(dto->activations, std::end(dto->activations))
302                 , activation_params(dto->activation_params, std::end(dto->activation_params))
303         , offset_order(dto->offset_order)
304         , direction(dto->direction)
305     {
306     }
307
308     /// @brief Primitive id containing the initial value of the cell state data.
309     primitive_id cell;
310     /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
311     float clip;
312     /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
313     bool input_forget;
314     /// @brief A list of 3 activation functions for the input, output, forget, cell, and hidden.
315     std::vector<cldnn_activation_func> activations;
316     /// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions.
317     std::vector<cldnn_activation_additional_params> activation_params;
318     /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
319     cldnn_lstm_offset_order offset_order;
320     /// @brief direction default = 0, bidirectional = 1.
321     uint32_t direction;
322
323 protected:
324     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
325     {
326         std::vector<std::reference_wrapper<const primitive_id>> ret;
327         if (!cell.empty())
328             ret.push_back(cell);
329         return ret;
330     }
331
332     void update_dto(dto& dto) const override
333     {
334         dto.cell = cell.c_str();
335         dto.offset_order = offset_order;
336         dto.clip = clip;
337         dto.input_forget = input_forget;
338         if (activations.size() == 3) {
339             std::copy_n(activations.begin(), 3, dto.activations);
340         }
341         if (activation_params.size() == 3) {
342             std::copy_n(activation_params.begin(), 3, dto.activation_params);
343         }
344         dto.direction = direction;
345     }
346 };
347
348 /// @}
349 /// @}
350 /// @}
351 }