Fix bidirectional mode in reference implementations of GRU/LSTM/RNN Sequences (#2264)
author: Ivan Tikhonov <ivan.tikhonov@intel.com>
Fri, 18 Sep 2020 07:14:01 +0000 (10:14 +0300)
committer: GitHub <noreply@github.com>
Fri, 18 Sep 2020 07:14:01 +0000 (10:14 +0300)
* fix the bidirectional case in the reference implementations of the sequence ops; enable decomposition of bidirectional cases in CommonOptimizations

* introduce new opset5, include GRU/RNN/LSTM Sequences to opset5

* Revert "introduce new opset5, include GRU/RNN/LSTM Sequences to opset5"

This reverts commit 73c22a11dbd724d2cfa9212ff211db74ef09cf2a.

inference-engine/src/transformations/src/transformations/bidirectional_sequences_decomposition.cpp
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_sequence.cpp
inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_sequence.cpp
inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_sequence.cpp
ngraph/core/reference/include/ngraph/runtime/reference/sequences.hpp

index 281d470..1df3c61 100644 (file)
@@ -19,6 +19,9 @@ ngraph::pass::BidirectionalLSTMSequenceDecomposition::BidirectionalLSTMSequenceD
             return false;
         }
 
+        if (lstm_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(lstm_sequence->input_value(1), axis_1, 2);
@@ -84,6 +87,9 @@ ngraph::pass::BidirectionalGRUSequenceDecomposition::BidirectionalGRUSequenceDec
             return false;
         }
 
+        if (gru_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(gru_sequence->input_value(1), axis_1, 2);
@@ -145,6 +151,9 @@ ngraph::pass::BidirectionalRNNSequenceDecomposition::BidirectionalRNNSequenceDec
             return false;
         }
 
+        if (rnn_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(rnn_sequence->input_value(1), axis_1, 2);
index 4ec3c63..e45776a 100644 (file)
@@ -21,6 +21,7 @@
 #include "transformations/hswish_fusion.hpp"
 #include "transformations/normalize_l2_fusion.hpp"
 #include "transformations/convert_quantize_dequantize.hpp"
+#include "transformations/bidirectional_sequences_decomposition.hpp"
 
 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/constant_folding.hpp>
@@ -50,6 +51,9 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::HSwishFusion>();
     manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
     manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
+    manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
+    manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
+    manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
 
     manager.set_callback(m_transformation_callback);
     manager.run_passes(f);
index 1327831..f1d8afe 100644 (file)
@@ -84,11 +84,6 @@ namespace LayerTestsDefinitions {
         ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
                                      std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
         function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }
 
 
index b1edaa9..d910194 100644 (file)
@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
                                      std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
                                      std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
         function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }
 
 
index 63f9e85..90ac191 100644 (file)
@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
         ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
                                      std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
         function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }
 
 
index e236bbd..894f1c3 100644 (file)
@@ -218,15 +218,15 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> C_split(
-                        2, std::vector<char>(ngraph::shape_size(C_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(C_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
@@ -234,13 +234,17 @@ namespace ngraph
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
                     reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
                     std::vector<std::vector<char>> forward_res(
-                        3, std::vector<char>(H_shape[0] * H_shape[2]));
+                        2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
                     std::vector<std::vector<char>> reverse_res(
-                        3, std::vector<char>(H_shape[0] * H_shape[2]));
+                        2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
 
                     CellArgs args;
                     args.activation_f = activation_f;
@@ -249,6 +253,13 @@ namespace ngraph
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
+                    // update H,C,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    shapes[3][1] = 1;
+                    for (int i = 4; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(
                         CellType::LSTM,
@@ -260,7 +271,7 @@ namespace ngraph
                          r_pointers[0],
                          b_pointers[0]},
                         shapes,
-                        {forward_res[0].data(), forward_res[1].data(), forward_res[2].data()},
+                        {forward_res_y.data(), forward_res[0].data(), forward_res[1].data()},
                         args,
                         false);
                     // reverse pass
@@ -274,32 +285,34 @@ namespace ngraph
                          r_pointers[1],
                          b_pointers[1]},
                         shapes,
-                        {reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()},
+                        {reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()},
                         args,
                         true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h_c = {{H_shape[0], 1, H_shape[2]},
+                                                        {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h_c{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h_c,
+                                               output_shape_h_c,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()},
+                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
                                                Co,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h_c,
+                                               output_shape_h_c,
                                                1,
                                                sizeof(T));
                 }
@@ -351,25 +364,27 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
                     char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
-                    std::vector<std::vector<char>> forward_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
-                    std::vector<std::vector<char>> reverse_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
 
                     CellArgs args;
                     args.activation_f = activation_f;
@@ -378,6 +393,12 @@ namespace ngraph
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+                    // update H,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    for (int i = 3; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(CellType::GRU,
                                  {X,
@@ -387,7 +408,7 @@ namespace ngraph
                                   r_pointers[0],
                                   b_pointers[0]},
                                  shapes,
-                                 {forward_res[0].data(), forward_res[1].data()},
+                                 {forward_res_y.data(), forward_res_h.data()},
                                  args,
                                  false);
                     // reverse pass
@@ -399,25 +420,28 @@ namespace ngraph
                                   r_pointers[1],
                                   b_pointers[1]},
                                  shapes,
-                                 {reverse_res[0].data(), reverse_res[1].data()},
+                                 {reverse_res_y.data(), reverse_res_h.data()},
                                  args,
                                  true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+                                                      {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h,
+                                               output_shape_h,
                                                1,
                                                sizeof(T));
                 }
@@ -465,31 +489,39 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
                     char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
-                    std::vector<std::vector<char>> forward_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
-                    std::vector<std::vector<char>> reverse_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
 
                     CellArgs args;
                     args.activation_f = activation_f;
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+                    // update H,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    for (int i = 3; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(CellType::RNN,
                                  {X,
@@ -499,7 +531,7 @@ namespace ngraph
                                   r_pointers[0],
                                   b_pointers[0]},
                                  shapes,
-                                 {forward_res[0].data(), forward_res[1].data()},
+                                 {forward_res_y.data(), forward_res_h.data()},
                                  args,
                                  false);
                     // reverse pass
@@ -511,25 +543,28 @@ namespace ngraph
                                   r_pointers[1],
                                   b_pointers[1]},
                                  shapes,
-                                 {reverse_res[0].data(), reverse_res[1].data()},
+                                 {reverse_res_y.data(), reverse_res_h.data()},
                                  args,
                                  true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+                                                      {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h,
+                                               output_shape_h,
                                                1,
                                                sizeof(T));
                 }