Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api_extension / lstm_dynamic_timeloop.hpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "api/primitive.hpp"
20 #include <vector>
21
22 namespace cldnn {
23 /// @addtogroup cpp_api C++ API
24 /// @{
25 /// @addtogroup cpp_topology Network Topology
26 /// @{
27 /// @addtogroup cpp_primitives Primitives
28 /// @{
29
30 /// @brief Performs forward calcaulations of input gates for dynamic lstm layer.
31 /// @details The current implementation of LSTM_DYNAMIC is described the following equations.
32 ///   it = f(Xt*(Wi^T) + Ht-1*Ri + Wbi)
33 ///   ft = f(Xt*(Wf^T) + Ht-1*Rf + Wbf)
34 ///   ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc)
35 ///   Ct = ft (.) Ct-1 + it (.) ct
36 ///   ot = f(Xt*(Wo^T) + Ht-1*Ro + Wbo)
37 ///   Ht = ot (.) h(Ct)
38 /// Where f = Sigmoid, g = Tanh, and h = Tanh.
39 struct lstm_dynamic_timeloop
40     : public primitive_base<lstm_dynamic_timeloop> {
41     CLDNN_DECLARE_PRIMITIVE(lstm_dynamic_timeloop)
42
43     /// @brief Constructs lstm_dynamic layer.
44     /// @param id This primitive id.
45     /// @param input Primitive id of input layer.
46     /// @param dyn_length Primitive id of ilayer containg dynamic length values (shape: 1D).
47     /// @param recurrent Primitive id containing recurrent data.
48     /// @param initial_hidden Primitive id containing initial_hidden data. Provide empty string if using lstm_dynamic without initial_hidden values.
49     /// @param initial_cell Primitive id containing initial_cell data. Provide empty string if using lstm_dynamic without initial_cell values.
50     /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
51     /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
52     lstm_dynamic_timeloop(const primitive_id& id,
53                           const primitive_id& input,
54                           const primitive_id& dyn_length,
55                           const primitive_id& recurrent,
56                           const primitive_id& last_hidden_state = "",
57                           const primitive_id& last_cell_state = "",
58                           const primitive_id& initial_hidden = "",
59                           const primitive_id& initial_cell = "",
60                           const float clip = 0.0f,
61                           const bool input_forget = 0,
62                           const padding& output_padding = padding())
63         : primitive_base(id, {input}, output_padding),
64           dyn_length(dyn_length),
65           recurrent(recurrent),
66           last_hidden_state(last_hidden_state),
67           last_cell_state(last_cell_state),
68           initial_hidden(initial_hidden),
69           initial_cell(initial_cell),
70           clip(clip),
71           input_forget(input_forget) {}
72
73     /// @brief Primitive id containing the dynamic sequence lengths.
74     primitive_id dyn_length;
75     /// @brief Primitive id containing recurrent data.
76     primitive_id recurrent;
77     /// @brief Primitive Id of mutable data primitive pointing to buffer, which will be filled with last hidden state.
78     primitive_id last_hidden_state;
79     /// @brief Primitive Id of mutable data primitive pointing to buffer, which will be filled with last cell state.
80     primitive_id last_cell_state;
81     /// @brief Primitive id containing the initial value of the hidden data.
82     primitive_id initial_hidden;
83     /// @brief Array of primitive ids containing the initial value of the hidden state data (Ht-1).
84     primitive_id initial_cell;
85     /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
86     float clip;
87     /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
88     bool input_forget;
89
90 protected:
91     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
92         std::vector<std::reference_wrapper<const primitive_id>> ret;
93         ret.push_back(dyn_length);
94         ret.push_back(recurrent);
95
96         if (!last_hidden_state.empty()) {
97             ret.push_back(last_hidden_state);
98         }
99         if (!last_cell_state.empty()) {
100             ret.push_back(last_cell_state);
101         }
102         if (!initial_hidden.empty()) {
103             ret.push_back(initial_hidden);
104         }
105         if (!initial_cell.empty()) {
106             ret.push_back(initial_cell);
107         }
108         return ret;
109     }
110 };
111 /// @}
112 /// @}
113 /// @}
114 }  // namespace cldnn