"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
-
-import networkx as nx
import numpy as np
from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.mxnet_lstm_sequence_normalize import MXNetLSTMSequenceNormalize
+from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
from extensions.ops.lstm_cell import LSTMCell
from extensions.ops.tensor_iterator import TensorIterator
+from mo.graph.graph import Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
-class LSTMSequenceTensorIterator(MiddleReplacementPattern):
- """ Converts normalized LSTMSequence op to TensorIterator.
+class LSTMToTensorIterator(MiddleReplacementPattern):
+ """ Converts a normalized RNNSequence with op=LSTM to a TensorIterator.
- Normalized LSTMSequence means that it should be processed by
- LSTMSequenceNormalize transform that ensures its stict form.
+ A normalized RNNSequence is one that has already been processed by the
+ RNNSequenceNormalize transform, which ensures its strict form.
- This transformation builds an altenative sub-graph for LSTMSequence
+ This transformation builds an alternative sub-graph for LSTMSequence
with TensorIterator connected in the same way as an original LSTMSequence
node and with internal body represented as LSTMCell op node with necessary
squeezes and unsqueezes around.
"""
enabled = True
-
+ force_clean_up = True
+ id = 'lstm_to_tensor_iterator'
+
def run_after(self):
- return [LSTMSequenceNormalize, MXNetLSTMSequenceNormalize]
+ return [RNNSequenceNormalize]
def run_before(self):
return [FusePermutesSequence]
def pattern(self):
return dict(
nodes=[
- ('lstm', dict(kind='op', op='LSTMSequence')),
+ ('lstm', dict(kind='op', op='LSTM', type='RNNSequence')),
('input', dict(kind='data')),
('weights', dict(kind='data')),
('biases', dict(kind='data')),
]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
lstm = match['lstm']
# Build TensorIterator body first
- body = nx.MultiDiGraph(name=lstm.name + '/sub_graph', layout=graph.graph['layout'])
+ body = Graph(name=lstm.name + '/sub_graph')
+ body.graph = graph.graph
+
+ # 1. Input squeeze Reshape
inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
{'shape': lstm.in_node(inp).shape.copy(),
'value': lstm.in_node(inp).value.copy()
if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
- for inp in [0, 3, 4, 1, 2]]
+ for inp in [0, 4, 5, 1, 2]] # X, h_init, c_init, WR, B
+
inputs[0].shape[lstm.sequence_dim] = 1
reshape_dim = inputs[0].shape.copy()
reshape_dim[lstm.batch_dim] = -1
dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
)
inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
- lstm_cell_op = LSTMCell(body, dict(hidden_size=match['lstm'].hidden_size, name=lstm.name + '/LSTMCell',
- internal_layer_id=1))
+
+ # 2. Output unsqueeze Reshape
outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
{'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
- else lstm.in_node(3).shape.copy(), 'is_output': True}) for out in [0, 1]]
+ else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
+ for out in outputs:
+ add_opoutput(body, out.id, 0, False)
+
unsqueezed_output_shape = outputs[0].shape.copy()
unsqueezed_output_shape[lstm.sequence_dim] = 1
squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
unsqueezed_output_shape[lstm.batch_dim] = -1
output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', dim=unsqueezed_output_shape,
internal_layer_id=2))
- # TODO edge attributes should be assigned by the op itself
+
+ # 3. LSTMCell
+ lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
+ activations=lstm.activations,
+ activation_alpha=lstm.activation_alpha,
+ activation_beta=lstm.activation_beta,
+ clip=lstm.clip,
+ input_forget=lstm.input_forget,
+ name=lstm.name + '/LSTMCell',
+ internal_layer_id=1))
lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
edge_attrs=[{}, {'internal_port_id': 1},
{'internal_port_id': 2}, {'bin': 'weights'},
lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
- lstm_cell_node[0]['is_output'] = True
+ add_opoutput(body, lstm_cell_node[0].id, 0, False)
+ # 4. Create the TensorIterator layer
assert lstm.direction in ['forward', 'reverse']
if lstm.direction == 'forward':
stride = 1
'external_port_id': 3,
'internal_layer_id': 2,
'internal_port_id': 3,
+
'axis': lstm.sequence_dim,
'stride': stride,
'start': start,
'part_size': 1,
}]
+ # Adding h_state, c_state to outputs
if len(lstm.out_nodes()) == 3:
output_port_map.extend([{
'external_port_id': 4,
ti_op = TensorIterator(graph, {
'name': lstm.name + '/TensorIterator',
'body': body,
+ 'in_ports_count': 3,
+ 'out_ports_count': len(lstm.out_nodes()),
'input_port_map': [
{
'external_port_id': 0,
'internal_layer_id': 0,
'internal_port_id': 0,
+
'axis': lstm.sequence_dim,
'stride': stride,
'start': start,
assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
"There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
- outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 3, 4]],
+
+ outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]], # X, h_init, c_init
data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
{'external_port_id': 2}])