2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "api/primitive.hpp"
23 /// @addtogroup cpp_api C++ API
25 /// @addtogroup cpp_topology Network Topology
27 /// @addtogroup cpp_primitives Primitives
/// @brief Performs forward calculations of input gates for dynamic lstm layer.
/// @details The current implementation of LSTM_DYNAMIC is described by the following equations.
32 /// it = f(Xt*(Wi^T) + Ht-1*Ri + Wbi)
33 /// ft = f(Xt*(Wf^T) + Ht-1*Rf + Wbf)
34 /// ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc)
35 /// Ct = ft (.) Ct-1 + it (.) ct
/// ot = f(Xt*(Wo^T) + Ht-1*Ro + Wbo)
/// Ht = ot (.) h(Ct)
/// Where f = Sigmoid, g = Tanh, and h = Tanh.
39 struct lstm_dynamic_timeloop
40 : public primitive_base<lstm_dynamic_timeloop> {
41 CLDNN_DECLARE_PRIMITIVE(lstm_dynamic_timeloop)
43 /// @brief Constructs lstm_dynamic layer.
44 /// @param id This primitive id.
45 /// @param input Primitive id of input layer.
46 /// @param dyn_length Primitive id of ilayer containg dynamic length values (shape: 1D).
47 /// @param recurrent Primitive id containing recurrent data.
48 /// @param initial_hidden Primitive id containing initial_hidden data. Provide empty string if using lstm_dynamic without initial_hidden values.
49 /// @param initial_cell Primitive id containing initial_cell data. Provide empty string if using lstm_dynamic without initial_cell values.
50 /// @param clip Clip threshold. Provide 0 if using lstm without activations clip threshold.
51 /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
52 lstm_dynamic_timeloop(const primitive_id& id,
53 const primitive_id& input,
54 const primitive_id& dyn_length,
55 const primitive_id& recurrent,
56 const primitive_id& last_hidden_state = "",
57 const primitive_id& last_cell_state = "",
58 const primitive_id& initial_hidden = "",
59 const primitive_id& initial_cell = "",
60 const float clip = 0.0f,
61 const bool input_forget = 0,
62 const padding& output_padding = padding())
63 : primitive_base(id, {input}, output_padding),
64 dyn_length(dyn_length),
66 last_hidden_state(last_hidden_state),
67 last_cell_state(last_cell_state),
68 initial_hidden(initial_hidden),
69 initial_cell(initial_cell),
71 input_forget(input_forget) {}
73 /// @brief Primitive id containing the dynamic sequence lengths.
74 primitive_id dyn_length;
75 /// @brief Primitive id containing recurrent data.
76 primitive_id recurrent;
77 /// @brief Primitive Id of mutable data primitive pointing to buffer, which will be filled with last hidden state.
78 primitive_id last_hidden_state;
79 /// @brief Primitive Id of mutable data primitive pointing to buffer, which will be filled with last cell state.
80 primitive_id last_cell_state;
81 /// @brief Primitive id containing the initial value of the hidden data.
82 primitive_id initial_hidden;
83 /// @brief Array of primitive ids containing the initial value of the hidden state data (Ht-1).
84 primitive_id initial_cell;
85 /// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
87 /// @brief Couple the input and forget gates if input_forget is 1. Default is 0.
91 std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
92 std::vector<std::reference_wrapper<const primitive_id>> ret;
93 ret.push_back(dyn_length);
94 ret.push_back(recurrent);
96 if (!last_hidden_state.empty()) {
97 ret.push_back(last_hidden_state);
99 if (!last_cell_state.empty()) {
100 ret.push_back(last_cell_state);
102 if (!initial_hidden.empty()) {
103 ret.push_back(initial_hidden);
105 if (!initial_cell.empty()) {
106 ret.push_back(initial_cell);