Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_rnn.cpp
index ba32285..af11763 100644 (file)
@@ -1,11 +1,10 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
 #include "mkldnn_rnn.h"
 #include "mkldnn_extension_utils.h"
 #include "desc_iterator.hpp"
-#include <ie_layers_prv.h>
 
 #include <string>
 #include <utility>
@@ -22,19 +21,40 @@ inline bool one_of(T val, P item, Args... item_others) {
     return val == item || one_of(val, item_others...);
 }
 
-rnn_direction ie2mkl(RNNLayer::Direction &direction) {
-    return direction == RNNLayer::RNN_FWD ? unidirectional_left2right
-         : direction == RNNLayer::RNN_BWD ? unidirectional_right2left
-         : direction == RNNLayer::RNN_BDR ? bidirectional_concat
-                                          : unidirectional;
+using _RNN = RNNSequenceLayer;  // alias
+
+static rnn_direction ie2mkl(_RNN::Direction &direction) {
+    return direction == _RNN::FWD ? unidirectional_left2right
+         : direction == _RNN::BWD ? unidirectional_right2left
+         : direction == _RNN::BDR ? bidirectional_concat
+         : unidirectional;
+}
+
+static algorithm ie2mkl(std::string act_type) {
+    return act_type == "sigmoid" ? eltwise_logistic
+         : act_type == "tanh"    ? eltwise_tanh
+         : act_type == "relu"    ? eltwise_relu
+         : algorithm_undef;
+}
+
+static algorithm ie2mkl(RNNCellBase::CellType cell_type) {
+    switch (cell_type) {
+        case RNNCellBase::LSTM: return vanilla_lstm;
+        case RNNCellBase::GRU:  return vanilla_gru;
+        case RNNCellBase::GRU_LBR:  return gru_linear_before_reset;
+        case RNNCellBase::RNN:  return vanilla_rnn;
+        default:
+            THROW_IE_EXCEPTION << "Unsoupported cell type";
+            return algorithm_undef;
+    }
 }
 
 MKLDNNRNN::MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
-    is_cell = layer->type == "LSTMCell";
+    is_cell = one_of(layer->type, "LSTMCell", "GRUCell", "RNNCell");
 }
 
 bool MKLDNNRNN::created() const {
-    return getType() == (is_cell ? LSTMCell : RNN);
+    return getType() == (is_cell ? RNNCell : RNNSeq);
 }
 
 void MKLDNNRNN::getSupportedDescriptors() {
@@ -46,48 +66,59 @@ void MKLDNNRNN::getSupportedDescriptors() {
 
 void MKLDNNRNN::fillCellDesc() {
     if (!descs.empty()) return;
-    auto cellLayer = std::dynamic_pointer_cast<InferenceEngine::LSTMCell>(getCnnLayer());
+    auto cellLayer = std::dynamic_pointer_cast<RNNCellBase>(getCnnLayer());
 
     if (!cellLayer)
-        THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";
+        THROW_IE_EXCEPTION << "No original layer for RNNCell.";
+
+    algorithm cell_type = ie2mkl(cellLayer->cellType);
+    algorithm cell_act = ie2mkl(cellLayer->activations[0]);  // Works only for RNN with one gate
+
+    cell_desc = {cell_type, cell_act};
+    if (cellLayer->clip != 0.0f)
+        cell_desc.set_clipping(cellLayer->clip);
 
     auto &ins = cellLayer->insData;
     auto &outs = cellLayer->outData;
 
-    if (ins.size() != 3)
+    if (!one_of(ins.size(), 3, 2))
         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
-    if (outs.size() != 2)
+    if (!one_of(outs.size(), 2, 1))
         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
 
     auto in_data_dims = getParentEdgeAt(0)->getDims();
     auto in_h_state_dims = getParentEdgeAt(1)->getDims();
-    auto in_c_state_dims = getParentEdgeAt(2)->getDims();
-
     auto out_h_state_dims = getChildEdgeAt(0)->getDims();
-    auto out_c_state_dims = getChildEdgeAt(1)->getDims();
 
-    if (in_data_dims.ndims() != 2
-        || in_h_state_dims.ndims() != 2
-        || in_c_state_dims.ndims() != 2
-        || out_h_state_dims.ndims() != 2
-        || out_c_state_dims.ndims() != 2)
+    if (in_data_dims.ndims() != 2 || in_h_state_dims.ndims() != 2)
         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
 
+    G = cell_desc.get_gates_count();
+    S = cell_desc.get_state_count();
     T = 1;
     N  = in_data_dims[0];
     DC = in_data_dims[1];
     SC = in_h_state_dims[1];
 
+    Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
+
     // Expected shapes
     MKLDNNDims D_shape {N, DC}, S_shape {N, SC};
 
     if (in_data_dims != D_shape
         || in_h_state_dims != S_shape
-        || in_c_state_dims != S_shape
-        || out_h_state_dims != S_shape
-        || out_c_state_dims != S_shape)
+        || out_h_state_dims != S_shape)
         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
 
+    if (S == 2) {
+        auto in_c_state_dims = getParentEdgeAt(2)->getDims();
+        auto out_c_state_dims = getChildEdgeAt(1)->getDims();
+
+        if (in_c_state_dims != S_shape
+            || out_c_state_dims != S_shape)
+            THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
+    }
+
     auto blobs = cellLayer->blobs;
     Blob::Ptr weights, bias;
     if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
@@ -99,7 +130,7 @@ void MKLDNNRNN::fillCellDesc() {
     if (weights->size() != G*SC*(SC+DC))
         THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
 
-    if (bias && bias->size() != G*SC)
+    if (bias && bias->size() != Gb*SC)
         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
 
     // Shapes and Attributes are correct. Can start internal stuff initialization.
@@ -114,44 +145,55 @@ void MKLDNNRNN::fillCellDesc() {
     w_state_d  = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
 
     if (bias)
-        w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};
+        w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
 
-    std::vector<TensorDesc> in_candidate;
+    std::vector<TensorDesc> in_candidate, out_candidate;
     in_candidate.emplace_back(MKLDNNMemoryDesc {D_shape, memory::f32, memory::nc});
     in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
-    in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
-
-    std::vector<TensorDesc> out_candidate;
-    out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
     out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
 
+    if (S == 2) {
+        in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
+        out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
+    }
+
     createDescriptor(in_candidate, out_candidate);
 }
 
 void MKLDNNRNN::fillSeqDesc() {
     if (!descs.empty()) return;
-    auto rnnLayer = std::dynamic_pointer_cast<RNNLayer>(getCnnLayer());
+    auto rnnLayer = std::dynamic_pointer_cast<RNNSequenceLayer>(getCnnLayer());
 
     if (!rnnLayer)
-        THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";
+        THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNSequenceLayer.";
+
+    if (!one_of(rnnLayer->cellType, _RNN::LSTM, _RNN::GRU, _RNN::GRU_LBR, _RNN::RNN))
+        THROW_IE_EXCEPTION << "RNN layer supports only LSTM/GRU/RNN cell";
 
-    if (!one_of(rnnLayer->cellType, "LSTM"))
-        THROW_IE_EXCEPTION << "RNN layer supports only LSTM like cell";
+    algorithm cell_type = ie2mkl(rnnLayer->cellType);
+    algorithm cell_act = algorithm_undef;
+    if (!rnnLayer->activations.empty())
+        cell_act = ie2mkl(rnnLayer->activations[0]);  // Works only for RNN with one gate
+
+    cell_desc = {cell_type, cell_act};
+
+    if (rnnLayer->clip != 0.0f)
+        cell_desc.set_clipping(rnnLayer->clip);
 
     if (!one_of(rnnLayer->axis, 0, 1))
         THROW_IE_EXCEPTION << "RNN layer supports only sequence axis 0 or 1";
     nativeOrder = rnnLayer->axis == 0;
 
-    if (!one_of(rnnLayer->direction, RNNLayer::RNN_FWD, RNNLayer::RNN_BWD))
+    if (!one_of(rnnLayer->direction, _RNN::FWD, _RNN::BWD))
         THROW_IE_EXCEPTION << "RNN layer supports only unidirectional RNN layer";
     direction = ie2mkl(rnnLayer->direction);
 
     auto &ins = rnnLayer->insData;
     auto &outs = rnnLayer->outData;
 
-    if (!one_of(ins.size(), 3, 1))
+    if (!one_of(ins.size(), 3, 2, 1))
         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
-    if (!one_of(outs.size(), 3, 1))
+    if (!one_of(outs.size(), 3, 2, 1))
         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
 
     auto in_data_dims = getParentEdgeAt(0)->getDims();
@@ -165,32 +207,32 @@ void MKLDNNRNN::fillSeqDesc() {
         std::swap(out_data_dims[0], out_data_dims[1]);
     }
 
+    G = cell_desc.get_gates_count();
+    S = cell_desc.get_state_count();
     T = in_data_dims[0];
     N = in_data_dims[1];
     DC = in_data_dims[2];
     SC = out_data_dims[2];
 
+    Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
+
     MKLDNNDims ID_shape {T, N, DC}, OD_shape {T, N, SC}, S_shape {N, SC};
 
     if (out_data_dims != OD_shape)
         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
 
-    if (ins.size() == 3) {
-        auto state_dims1 = getParentEdgeAt(1)->getDims();
-        auto stats_dims2 = getParentEdgeAt(2)->getDims();
-
-        if (state_dims1 != S_shape || stats_dims2 != S_shape)
-            THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
+    if (ins.size() > 1) {
+        for (int i = 1; i < ins.size(); i++)
+            if (getParentEdgeAt(i)->getDims() != S_shape)
+                THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
 
         in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
     }
 
-    if (outs.size() == 3) {
-        auto state_dims1 = getChildEdgeAt(1)->getDims();
-        auto stats_dims2 = getChildEdgeAt(2)->getDims();
-
-        if (state_dims1 != S_shape || stats_dims2 != S_shape)
-            THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
+    if (outs.size() > 1) {
+        for (int i = 1; i < outs.size(); i++)
+            if (getChildEdgeAt(i)->getDims() != S_shape)
+                THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
 
         out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
     }
@@ -209,11 +251,11 @@ void MKLDNNRNN::fillSeqDesc() {
     w_data_d  = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
     w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
 
-    if (bias && bias->size() != G*SC)
+    if (bias && bias->size() != Gb*SC)
         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
 
     if (bias)
-        w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};
+        w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
 
     // Try to create descriptor and corresponding configuration
     in_data_d = {in_data_dims, memory::f32, memory::tnc};
@@ -225,10 +267,8 @@ void MKLDNNRNN::fillSeqDesc() {
     else
         in_candidate.push_back(MKLDNNMemoryDesc{{N, T, DC}, memory::f32, memory::ntc});
 
-    if (ins.size() == 3) {
-        in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
+    for (int i = 1; i < ins.size(); i++)
         in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
-    }
 
     std::vector<TensorDesc> out_candidate;
     if (nativeOrder)
@@ -236,10 +276,8 @@ void MKLDNNRNN::fillSeqDesc() {
     else
         out_candidate.push_back(MKLDNNMemoryDesc{{N, T, SC}, memory::f32, memory::ntc});
 
-    if (outs.size() == 3) {
-        out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
+    for (int i = 1; i < outs.size(); i++)
         out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
-    }
 
     createDescriptor(in_candidate, out_candidate);
 }
@@ -247,8 +285,7 @@ void MKLDNNRNN::fillSeqDesc() {
 void MKLDNNRNN::createDescriptor(const std::vector<TensorDesc> &inputDesc,
                                  const std::vector<TensorDesc> &outputDesc) {
     MKLDNNDescriptor desc(std::shared_ptr<rnn_forward::desc>(
-            new rnn_forward::desc(forward_scoring,
-                    {algorithm::vanilla_lstm, algorithm::eltwise_tanh },
+            new rnn_forward::desc(forward_scoring, cell_desc,
                     direction,
                     /* In Data       */ in_data_d,
                     /* In State      */ in_state_d,
@@ -305,7 +342,6 @@ void MKLDNNRNN::createPrimitive() {
 
     {
         /* Copy Weight data
-         *
          * IE format:
          *   W - [gates, out_state_size, in_data_size + in_state_size]
          *   B - [gates, out_state_size]
@@ -316,11 +352,46 @@ void MKLDNNRNN::createPrimitive() {
          *   B - [gates, out_state_size]
          *
          *   Gate order
+         *   ====== LSTM ======
          *   Caffe - IFOC, ONNX   - IOFC
          *   IE    - FICO, mkldnn - IFCO
+         *
+         *   ====== GRU ======
+         *   IE - URO, mkldnn - URO
          */
-        // FICO -> IFCO
-        const int gate_map[] = {1, 0, 2, 3};
+        const int gate_map_lstm[] = {1, 0, 2, 3};  // FICO -> IFCO
+        const int gate_map_gru[]  = {0, 1, 2, 3};
+        const int gate_map_rnn[]  = {0};
+        const int *gate_map;
+        const int gate_map_lstm_size = sizeof(gate_map_lstm) / sizeof(int);
+        const int gate_map_gru_size = sizeof(gate_map_gru) / sizeof(int);
+        const int gate_map_rnn_size = sizeof(gate_map_rnn) / sizeof(int);
+        if (cell_desc.get_cell_kind() == vanilla_lstm) {
+            gate_map = gate_map_lstm;
+            if (G > gate_map_lstm_size) {
+                THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
+            }
+        } else if (cell_desc.get_cell_kind() == vanilla_gru) {
+            gate_map = gate_map_gru;
+            if (G > gate_map_gru_size) {
+                THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
+            }
+        } else if (cell_desc.get_cell_kind() == gru_linear_before_reset) {
+            gate_map = gate_map_gru;
+            if (G > gate_map_gru_size) {
+                THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
+            }
+        } else if (cell_desc.get_cell_kind() == vanilla_rnn) {
+            gate_map = gate_map_rnn;
+            if (G > gate_map_rnn_size) {
+                THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
+            }
+        } else {
+            gate_map = gate_map_gru;
+            if (G > gate_map_gru_size) {
+                THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
+            }
+        }
 
         auto ie_w_ptr = getCnnLayer()->blobs["weights"]->buffer().as<const float*>();
         auto w_ptr = static_cast<float*>(w_data_mem->GetData());
@@ -348,7 +419,7 @@ void MKLDNNRNN::createPrimitive() {
         if (w_bias_d) {
             auto ie_b_ptr = getCnnLayer()->blobs["biases"]->buffer().as<const float*>();
             auto b_ptr = static_cast<float*>(w_bias_mem->GetData());
-            for (int g = 0; g < G; g++) {
+            for (int g = 0; g < Gb; g++) {
                 float *l_b_ptr = b_ptr + gate_map[g]*SC;
                 for (int out_i = 0; out_i < SC; out_i++) {
                     *l_b_ptr = *ie_b_ptr;
@@ -363,53 +434,44 @@ void MKLDNNRNN::createPrimitive() {
     src_state_mem->Create(in_state_d);
     internalBlobMemory.push_back(src_state_mem);
     if (in_state_d) {
-        /* create copy/concat primitive */
-        auto src_stat_1 = getParentEdgeAt(1)->getMemory().GetPrimitive();
-        auto src_stat_2 = getParentEdgeAt(2)->getMemory().GetPrimitive();
-
-        auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
-        low_half_state_mem->Create(
-                src_stat_1.get_primitive_desc().desc(),
-                src_state_mem->GetPrimitive().get_data_handle());
-        internalBlobMemory.push_back(low_half_state_mem);
-
-        auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
-        high_half_state_mem->Create(
-                src_stat_2.get_primitive_desc().desc(),
-                static_cast<uint8_t*>(src_state_mem->GetPrimitive().get_data_handle()) +
-                src_stat_1.get_primitive_desc().get_size());
-        internalBlobMemory.push_back(high_half_state_mem);
-
-        exec_before.emplace_back(src_stat_1, low_half_state_mem->GetPrimitive());
-        exec_before.emplace_back(src_stat_2, high_half_state_mem->GetPrimitive());
+        int offset = 0;
+        for (int i = 0; i < S; i++) {
+            /* create copy/concat primitive */
+            auto src_stat = getParentEdgeAt(i+1)->getMemory().GetPrimitive();
+
+            auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
+            state_mem->Create(
+                    src_stat.get_primitive_desc().desc(),
+                    static_cast<uint8_t *>(src_state_mem->GetPrimitive().get_data_handle()) + offset);
+            offset += src_stat.get_primitive_desc().get_size();
+
+            internalBlobMemory.push_back(state_mem);
+
+            exec_before.emplace_back(src_stat, state_mem->GetPrimitive());
+        }
     }
 
     auto dst_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
     dst_state_mem->Create(out_state_d);
     internalBlobMemory.push_back(dst_state_mem);
     if (out_state_d) {
-        int idx_H = is_cell ? 0 : 1;
-        int idx_C = is_cell ? 1 : 2;
-        /* create copy/split primitive */
-        auto dst_stat_1 = getChildEdgeAt(idx_H)->getMemory().GetPrimitive();
-        auto dst_stat_2 = getChildEdgeAt(idx_C)->getMemory().GetPrimitive();
-
-        auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
-        low_half_state_mem->Create(
-                dst_stat_1.get_primitive_desc().desc(),
-                dst_state_mem->GetPrimitive().get_data_handle());
-        internalBlobMemory.push_back(low_half_state_mem);
-
-        auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
-        high_half_state_mem->Create(
-                dst_stat_2.get_primitive_desc().desc(),
-                static_cast<uint8_t*>(dst_state_mem->GetPrimitive().get_data_handle()) +
-                        dst_stat_1.get_primitive_desc().get_size());
-        internalBlobMemory.push_back(high_half_state_mem);
-
-
-        if (!is_cell) exec_after.emplace_back(low_half_state_mem->GetPrimitive(),  dst_stat_1);
-        exec_after.emplace_back(high_half_state_mem->GetPrimitive(), dst_stat_2);
+        int offset = 0;
+        int idx_start = is_cell ? 0 : 1;
+        for (int i = 0; i < S; i++) {
+            /* create copy/split primitive */
+            auto dst_stat = getChildEdgeAt(idx_start + i)->getMemory().GetPrimitive();
+
+            auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
+            state_mem->Create(
+                    dst_stat.get_primitive_desc().desc(),
+                    static_cast<uint8_t *>(dst_state_mem->GetPrimitive().get_data_handle()) + offset);
+            offset += dst_stat.get_primitive_desc().get_size();
+
+            internalBlobMemory.push_back(state_mem);
+
+            if (is_cell && i == 0) continue;
+            exec_after.emplace_back(state_mem->GetPrimitive(), dst_stat);
+        }
     }
 
     auto workspace_mem = std::make_shared<MKLDNNMemory>(getEngine());