See the License for the specific language governing permissions and
limitations under the License.
"""
-from copy import deepcopy
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.op import Op
from mo.ops.reshape import Reshape
+from mo.ops.shape import Shape
+from mo.ops.unsqueeze import Unsqueeze
+from mo.utils.shape import node_to_get_shape_value_of_indices
class RNNSequenceNormalize(MiddleReplacementPattern):
@staticmethod
def unsqueeze_num_directions(graph: Graph, match: dict):
- """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Reshape
+ """ Assuming the considered LSTM/GRU/RNN node should have num_directions in its output shape, add an Unsqueeze
to match it.
"""
rnn_layer = match['rnn_layer']
+ rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
# num_directions is at 1st position in output shape, and at 0th position in hidden and cell states
# please refer to docs in this transform
graph.remove_edge(rnn_layer.id, old_data_node.id)
graph.add_edge(rnn_layer.id, data.id, key=0, out=i)
- reshape = Reshape(graph, dict(dim=old_shape))
+ unsqueeze = Unsqueeze(graph, dict())
- reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/SqueezeNumDirections/{}/Dim'.format(i),
- 'value': old_shape}).create_node_with_data()
- reshape.create_node_with_data([data, reshape_dim_data],
- dict(name=rnn_layer.name + '/SqueezeNumDirections/{}'.format(i)),
- data_nodes=[old_data_node])
+ unsqueeze_dim_data = Const(graph, {'name': rnn_layer.name + '/UnsqueezeNumDirections/{}/Dim'.format(i),
+ 'value': int64_array([direction_dim[i]])}).create_node_with_data()
+ unsqueeze.create_node_with_data([data, unsqueeze_dim_data],
+ dict(name=rnn_layer_name + '/UnsqueezeNumDirections/{}'.format(i)),
+ data_nodes=[old_data_node])
@staticmethod
def squeeze_initial_states(graph: Graph, match: dict):
"""
cell_init_port = 6
rnn_layer = match['rnn_layer']
-
# Add input ports to rnn_layer
rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
-
- reshape = Reshape(graph, {})
+ rnn_layer_name = rnn_layer.soft_get('name', rnn_layer.id)
assert hidden_init_port in rnn_layer.in_nodes()
- init_h = rnn_layer.in_node(hidden_init_port)
- edge_attrs = deepcopy(graph.get_edge_data(init_h.id, rnn_layer.id)[0])
- edge_attrs['in'] = hidden_init_port
- graph.remove_edge(init_h.id, rnn_layer.id)
-
- new_dim = int64_array([rnn_layer.in_node(0).shape[rnn_layer.batch_dim], rnn_layer.hidden_size])
- reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/HiddenStateResizeDim',
- 'value': new_dim}).create_node_with_data()
- new_init_h = reshape.create_node_with_data([init_h, reshape_dim_data], dict(name=rnn_layer.name + '/HiddenStateResize'))
- graph.add_edge(new_init_h.id, rnn_layer.id, **edge_attrs)
+ hidden_size = rnn_layer.hidden_size
+ shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
+ rnn_layer.in_port(0).get_source().connect(shape.in_port(0))
+
+ batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
+ new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
+ op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
+ in_ports_count=2, axis=0), input_node=batch)
+ reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
+ new_dim.out_port(0).connect(reshape_h.in_port(1))
+ rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)
if rnn_layer.op == 'LSTM':
assert cell_init_port in rnn_layer.in_nodes()
- init_c = rnn_layer.in_node(cell_init_port)
- edge_attrs = deepcopy(graph.get_edge_data(init_c.id, rnn_layer.id)[0])
- edge_attrs['in'] = cell_init_port
- graph.remove_edge(init_c.id, rnn_layer.id)
- reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/CellStateResizeDim',
- 'value': new_dim}).create_node_with_data()
- new_init_c = reshape.create_node_with_data([init_c, reshape_dim_data],
- dict(name=rnn_layer.name + '/CellStateResize'))
- graph.add_edge(new_init_c.id, rnn_layer.id, **edge_attrs)
+ reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
+ new_dim.out_port(0).connect(reshape_c.in_port(1))
+ rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
@staticmethod
def reordering_inputs(graph: Graph, match: dict):