"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
-
-from mo.middle.replacement import MiddleReplacementPattern
-from extensions.ops.lstm_sequence import LSTMSequence
from extensions.middle.FusePermutesSequence import FusePermutesSequence
-from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
from extensions.middle.permute_tensor_iterator import PermuteTensorIteratorLSTM
+from mo.graph.graph import Graph
from mo.middle.passes.eliminate import remove_op_node_with_data_node
from mo.middle.replacement import MiddleReplacementPattern
def run_after(self):
return [
- TensorIteratorMerge,
- LSTMSequenceNormalize,
- LSTMSequenceTensorIterator,
+ ONNXRNNSequenceNormalize,
+
FusePermutesSequence,
PermuteTensorIteratorLSTM,
]
+ def run_before(self):
+ from extensions.middle.pass_separator import MiddleFinish
+ return [MiddleFinish]
+
def pattern(self):
return dict(
nodes=[
('direct_reverse', dict(op='ReverseSequence')),
('input_reversed'),
('init_hidden'),
- ('init_cell'),
('ti', dict(kind='op', op='TensorIterator')),
('input_reversed', 'ti', {'in': 0}),
('init_hidden', 'ti', {'in': 1}),
- ('init_cell', 'ti', {'in': 2}),
('ti', 'output_reversed', {'out': 0}),
('output_reversed', 'inverse_reverse', {'in': 0}),
]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
ti = match['ti']
direct_reverse = match['direct_reverse']
inverse_reverse = match['inverse_reverse']
- assert direct_reverse.seq_dim == inverse_reverse.seq_dim
- assert direct_reverse.batch_dim is None and inverse_reverse.batch_dim is None or \
- direct_reverse.batch_dim == inverse_reverse.batch_dim
+ assert direct_reverse.seq_axis == inverse_reverse.seq_axis
+ assert direct_reverse.batch_axis is None and inverse_reverse.batch_axis is None or \
+ direct_reverse.batch_axis == inverse_reverse.batch_axis
# Modify stride in TI
for port_map in [ti.input_port_map, ti.output_port_map]:
for port in port_map:
if 'axis' in port and port['axis'] is not None and 'external_port_id' in port:
- assert port['axis'] == direct_reverse.seq_dim, \
- 'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], direct_reverse.seq_dim)
+ assert port['axis'] == direct_reverse.seq_axis, \
+ 'axis == {} != {} == direct_reverse.seq_dim'.format(port['axis'], direct_reverse.seq_axis)
if 'stride' not in port or port['stride'] is None:
port['stride'] = 1
assert port['stride'] in [-1, 1]