Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / lstm.hpp
index dd9e992..2276616 100644 (file)
@@ -29,14 +29,14 @@ namespace cldnn
 /// @{
 
 /// @brief Performs forward Long Short-Term Memory (LSTM) layer.
-/// @details The current implementation of LSTM supports Peepholes.
-///   it = f(Xt*(Wi^T) + Ht-1*Ri + Pi (.) Ct-1 + Wbi + Rbi)
-///   ft = f(Xt*(Wf^T) + Ht-1*Rf + Pf (.) Ct-1 + Wbf + Rbf)
-///   ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc + Rbc)
+/// @details The current implementation of LSTM is described by the following equations.
+///   it = f(Xt*(Wi^T) + Ht-1*Ri + Wbi)
+///   ft = f(Xt*(Wf^T) + Ht-1*Rf + Wbf)
+///   ct = g(Xt*(Wc^T) + Ht-1*Rc + Wbc)
 ///   Ct = ft (.) Ct-1 + it (.) ct
-///   ot = f(Xt*(Wo^T) + Ht-1*Ro + Po (.) Ct + Wbo + Rbo)
+///   ot = f(Xt*(Wo^T) + Ht-1*Ro + Wbo)
 ///   Ht = ot (.) h(Ct)
-/// Where f=Sigmoid, g=Tanh, and h = Tanh.
+/// Where f = Sigmoid, g = Tanh, and h = Tanh.
 struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
 {
     CLDNN_DECLARE_PRIMITIVE(lstm)
@@ -53,6 +53,7 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
     /// @param input_forget Provide 0 if using lstm without coupled input-forget gates.
     /// @param activations Vector of activations. Specify [f, g, h]. Default are [sigmoid, tanh, tanh]
     /// @param activation_params Vector of ativation params. Specify params for each [f, g, h] activation.
+    /// @param output_selection Output selection. By default, the entire hidden sequence is returned.
     /// @param offset_order Order of the concatenated weights, recurrent, and bias. ONNX default is iofz [input, output, forget, block].
     lstm(
         const primitive_id& id,
@@ -67,6 +68,7 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
         const bool input_forget = 0,
         const std::vector<cldnn_activation_func>& activations = {},
         const std::vector<cldnn_activation_additional_params> activation_params = {},
+        const cldnn_lstm_output output_selection = cldnn_lstm_output_sequence,
         const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
         const padding& output_padding = padding()
         )
@@ -81,6 +83,7 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
         , input_forget(input_forget)
         , activations(activations)
         , activation_params(activation_params)
+        , output_selection(output_selection)
         , offset_order(offset_order)
     {
     }
@@ -98,6 +101,7 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
         , input_forget(dto->input_forget)
                , activations(dto->activations, std::end(dto->activations))
                , activation_params(dto->activation_params, std::end(dto->activation_params))
+        , output_selection(dto->output_selection)
         , offset_order(dto->offset_order)
     {
     }
@@ -122,6 +126,8 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
     std::vector<cldnn_activation_func> activations;
     /// @brief Optional scaling values used by some activation functions. The values are consumed in the order of activation functions.
     std::vector<cldnn_activation_additional_params> activation_params;
+    /// @brief Output selection. By default, the entire hidden sequence is returned.
+    cldnn_lstm_output output_selection;
     /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
     cldnn_lstm_offset_order offset_order;
 
@@ -129,7 +135,7 @@ struct lstm : public primitive_base<lstm, CLDNN_PRIMITIVE_DESC(lstm)>
     // /// @brief Optional tensor specifying lengths of the sequences in a batch.
     // /// If not specified - assumed all sequences in the batch to have length `seq_length`. It has shape `[batch_size]`.
     // tensor sequence_lens;
-    // /// @brief The sequence output for the hidden??? This is not clearly specified in the ONNX definition.
+    // /// @brief The sequence output for the hidden.
     // uint32_t output_sequence;
 protected:
     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
@@ -160,6 +166,7 @@ protected:
         dto.peepholes = peepholes.c_str();
         dto.initial_hidden = initial_hidden.c_str();
         dto.initial_cell = initial_cell.c_str();
+        dto.output_selection = output_selection;
         dto.offset_order = offset_order;
         if (activations.size() == 3) {
             std::copy_n(activations.begin(), 3, dto.activations);
@@ -271,6 +278,7 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
         const std::vector<cldnn_activation_func> activations = {},
         const std::vector<cldnn_activation_additional_params> activation_params = {},
         const cldnn_lstm_offset_order offset_order = cldnn_lstm_offset_order_iofz,
+        const uint32_t direction = 0,
         const padding& output_padding = padding()
         )
         : primitive_base(id, {input}, output_padding)
@@ -280,6 +288,7 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
         , activations(activations)
         , activation_params(activation_params)
         , offset_order(offset_order)
+        , direction(direction)
     {
     }
 
@@ -292,6 +301,7 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
                , activations(dto->activations, std::end(dto->activations))
                , activation_params(dto->activation_params, std::end(dto->activation_params))
         , offset_order(dto->offset_order)
+        , direction(dto->direction)
     {
     }
 
@@ -307,6 +317,9 @@ struct lstm_elt : public primitive_base<lstm_elt, CLDNN_PRIMITIVE_DESC(lstm_elt)
     std::vector<cldnn_activation_additional_params> activation_params;
     /// @brief Weights, recurrent weights, and biases order. [iofz] : ONNX, [ifoz] : Caffe
     cldnn_lstm_offset_order offset_order;
+    /// @brief Direction: default = 0, bidirectional = 1.
+    uint32_t direction;
+
 protected:
     std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override
     {
@@ -328,6 +341,7 @@ protected:
         if (activation_params.size() == 3) {
             std::copy_n(activation_params.begin(), 3, dto.activation_params);
         }
+        dto.direction = direction;
     }
 };