[ MO ][ RESHAPE ] Fixes in RNNSequenceNormalize transformation (#1909)
author    Yegor Kruglov <yegor.kruglov@intel.com>
Sat, 29 Aug 2020 22:19:22 +0000 (01:19 +0300)
committer GitHub <noreply@github.com>
Sat, 29 Aug 2020 22:19:22 +0000 (01:19 +0300)
* initial commit

* first reshape-able variant

* right version for reshape

* comment update

* fixes for failed e2e

* set data type to ngraph TensorIterator

* Fix dynamic shapes for cells ops

* clean up

Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
model-optimizer/extensions/middle/RNNSequenceNormalizeToIE.py
ngraph/core/src/op/gru_cell.cpp
ngraph/core/src/op/lstm_cell.cpp
ngraph/core/src/op/rnn_cell.cpp
ngraph/core/src/op/tensor_iterator.cpp

index b481d8b..e5de8c4 100644 (file)
  See the License for the specific language governing permissions and
  limitations under the License.
 """
-from copy import deepcopy
 
 import numpy as np
 
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_node_with_second_input
 from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.concat import Concat
 from mo.ops.const import Const
 from mo.ops.op import Op
 from mo.ops.reshape import Reshape
+from mo.ops.shape import Shape
+from mo.ops.unsqueeze import Unsqueeze
+from mo.utils.shape import node_to_get_shape_value_of_indices
 
 
 class RNNSequenceNormalize(MiddleReplacementPattern):
@@ -136,11 +140,12 @@ class RNNSequenceNormalize(MiddleReplacementPattern):
 
     @staticmethod
     def unsqueeze_num_directions(graph: Graph, match: dict):
-        """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Reshape
+        """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Unsqueeze
             to match it.
         """
 
         rnn_layer = match['rnn_layer']
+        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
        # num_directions is at the 1st position in the output shape, and at the 0th position in the hidden and cell states
         # please refer to docs in this transform
 
@@ -154,14 +159,14 @@ class RNNSequenceNormalize(MiddleReplacementPattern):
             graph.remove_edge(rnn_layer.id, old_data_node.id)
             graph.add_edge(rnn_layer.id, data.id, key=0, out=i)
 
-            reshape = Reshape(graph, dict(dim=old_shape))
+            unsqueeze = Unsqueeze(graph, dict())
 
-            reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/SqueezeNumDirections/{}/Dim'.format(i),
-                                             'value': old_shape}).create_node_with_data()
-            reshape.create_node_with_data([data, reshape_dim_data],
-                                          dict(name=rnn_layer.name + '/SqueezeNumDirections/{}'.format(i)),
-                                          data_nodes=[old_data_node])
+            unsqueeze_dim_data = Const(graph, {'name': rnn_layer_name + '/UnsqueezeNumDirections/{}/Dim'.format(i),
+                                               'value': int64_array([direction_dim[i]])}).create_node_with_data()
 
+            unsqueeze.create_node_with_data([data, unsqueeze_dim_data],
+                                            dict(name=rnn_layer_name + '/UnsqueezeNumDirections/{}'.format(i)),
+                                            data_nodes=[old_data_node])
     @staticmethod
     def squeeze_initial_states(graph: Graph, match: dict):
         """
@@ -171,36 +176,29 @@ class RNNSequenceNormalize(MiddleReplacementPattern):
         cell_init_port = 6
 
         rnn_layer = match['rnn_layer']
-
         # Add input ports to rnn_layer
         rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
-
-        reshape = Reshape(graph, {})
+        rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
 
         assert hidden_init_port in rnn_layer.in_nodes()
-        init_h = rnn_layer.in_node(hidden_init_port)
-        edge_attrs = deepcopy(graph.get_edge_data(init_h.id, rnn_layer.id)[0])
-        edge_attrs['in'] = hidden_init_port
-        graph.remove_edge(init_h.id, rnn_layer.id)
-
-        new_dim = int64_array([rnn_layer.in_node(0).shape[rnn_layer.batch_dim], rnn_layer.hidden_size])
-        reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/HiddenStateResizeDim',
-                                         'value': new_dim}).create_node_with_data()
-        new_init_h = reshape.create_node_with_data([init_h, reshape_dim_data], dict(name=rnn_layer.name + '/HiddenStateResize'))
-        graph.add_edge(new_init_h.id, rnn_layer.id, **edge_attrs)
+        hidden_size = rnn_layer.hidden_size
+        shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
+        rnn_layer.in_port(0).get_source().connect(shape.in_port(0))
+
+        batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
+        new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
+                                                   op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
+                                                                 in_ports_count=2, axis=0), input_node=batch)
+        reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
+        new_dim.out_port(0).connect(reshape_h.in_port(1))
+        rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)
 
         if rnn_layer.op == 'LSTM':
             assert cell_init_port in rnn_layer.in_nodes()
 
-            init_c = rnn_layer.in_node(cell_init_port)
-            edge_attrs = deepcopy(graph.get_edge_data(init_c.id, rnn_layer.id)[0])
-            edge_attrs['in'] = cell_init_port
-            graph.remove_edge(init_c.id, rnn_layer.id)
-            reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/CellStateResizeDim',
-                                             'value': new_dim}).create_node_with_data()
-            new_init_c = reshape.create_node_with_data([init_c, reshape_dim_data],
-                                                       dict(name=rnn_layer.name + '/CellStateResize'))
-            graph.add_edge(new_init_c.id, rnn_layer.id, **edge_attrs)
+            reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
+            new_dim.out_port(0).connect(reshape_c.in_port(1))
+            rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
 
     @staticmethod
     def reordering_inputs(graph: Graph, match: dict):
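The Python hunks above replace a Reshape whose target shape was a compile-time constant (the batch size frozen from the input's static shape) with a ShapeOf -> Gather -> Concat subgraph, so the initial hidden/cell state shape follows the actual batch size and the converted model stays reshape-able. A minimal numpy sketch of the value that subgraph computes at runtime (build_init_state_shape is an illustrative name, not MO API):

import numpy as np

def build_init_state_shape(input_shape, batch_dim, hidden_size):
    # Gather(ShapeOf(input), [batch_dim]): read the batch size from the
    # runtime input shape instead of freezing it at conversion time.
    batch = np.array(input_shape, dtype=np.int64)[[batch_dim]]
    # Concat(batch, [hidden_size], axis=0): the Reshape target for the
    # initial hidden/cell state.
    return np.concatenate([batch, np.array([hidden_size], dtype=np.int64)])

# An input of shape [seq_len=10, batch=4, input_size=16] with hidden_size=128
# gives the state shape [4, 128] for whatever batch the model is reshaped to:
assert list(build_init_state_shape((10, 4, 16), 1, 128)) == [4, 128]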
index fb1a327..6d09d06 100644 (file)
@@ -106,6 +106,8 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor)
 
 void op::v3::GRUCell::pre_validate_and_infer_types()
 {
+    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+
     if (is_dynamic())
     {
         return;
index 2e1fb3b..71c3af1 100644 (file)
@@ -135,7 +135,8 @@ bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
 
 void op::LSTMCell::pre_validate_and_infer_types()
 {
-    set_output_size(2);
+    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+    set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
     if (is_dynamic())
     {
         return;
index 65ab2ba..e724d75 100644 (file)
@@ -80,6 +80,8 @@ bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
 
 void op::RNNCell::pre_validate_and_infer_types()
 {
+    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+
     if (is_dynamic())
     {
         return;
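The three cell ops (GRUCell, LSTMCell, RNNCell) all get the same fix: pre_validate_and_infer_types() now declares each output with the input element type and a fully dynamic shape before the early return on dynamic inputs, so outputs are never left untyped when static inference is skipped. A toy Python model of that control flow (illustrative only, not the ngraph API):

DYNAMIC = None  # stands in for PartialShape::dynamic()

class ToyCell:
    def __init__(self, input_type, input_shape):
        self.input_type, self.input_shape = input_type, input_shape
        self.output_type, self.output_shape = None, None

    def pre_validate_and_infer_types(self):
        # The added lines: type the output and mark its shape dynamic
        # up front, before any early return.
        self.output_type, self.output_shape = self.input_type, DYNAMIC
        if self.input_shape is DYNAMIC:  # is_dynamic()
            return  # the output is already typed at this point
        # ... static rank checks and shape inference would follow here ...

cell = ToyCell('f32', DYNAMIC)
cell.pre_validate_and_infer_types()
assert cell.output_type == 'f32' and cell.output_shape is DYNAMIC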
index 4f0168c..d1ee19e 100644 (file)
@@ -578,6 +578,7 @@ void op::v0::TensorIterator::validate_and_infer_types()
                 as_type_ptr<ConcatOutputDescription>(output_description))
         {
             auto body_value_partial_shape = body_value.get_partial_shape();
+            set_output_type(index, body_value.get_element_type(), PartialShape::dynamic());
             if (body_value_partial_shape.is_static())
             {
                 auto body_value_shape = body_value_partial_shape.to_shape();
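TensorIterator gets the analogous default for its concatenated outputs: each ConcatOutputDescription output is first set to the body value's element type with a dynamic shape, and refined only when the body shape is static. A rough sketch in the same toy style (the real code also folds in the slice part size; infer_concat_output is an illustrative name):

DYNAMIC = None  # stands in for PartialShape::dynamic()

def infer_concat_output(body_type, body_shape, num_iterations, axis):
    # Default added by this patch: typed output, dynamic shape.
    out_type, out_shape = body_type, DYNAMIC
    if body_shape is not DYNAMIC:  # body_value_partial_shape.is_static()
        out_shape = list(body_shape)
        out_shape[axis] *= num_iterations  # concatenation along `axis`
    return out_type, out_shape

assert infer_concat_output('f32', DYNAMIC, 8, 0) == ('f32', DYNAMIC)
assert infer_concat_output('f32', (1, 4, 16), 8, 0) == ('f32', [8, 4, 16])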