"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
-from copy import deepcopy
-from mo.graph.graph import copy_node, Node, dict_includes
-from mo.utils.error import Error
-from mo.middle.passes.eliminate import remove_op_node_with_data_node
-from mo.middle.pattern_match import find_isomorphisms, find_pattern_matches
-from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.op import Op
-from extensions.ops.lstm_sequence import LSTMSequence
from extensions.middle.FusePermutesSequence import FusePermutesSequence
+from extensions.middle.LSTMRNNSequenceToTensorIterator import LSTMToTensorIterator
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
-from extensions.middle.decompose_bi_lstm import DecomposeBiLSTM
+from mo.graph.graph import dict_includes, Graph
+from mo.middle.passes.eliminate import remove_op_node_with_data_node
+from mo.middle.pattern_match import find_isomorphisms
+from mo.middle.replacement import MiddleReplacementPattern
class PermuteTensorIteratorLSTM(MiddleReplacementPattern):
- ''' Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
+ """ Fuses Permute(1,0,2) --> TI --> Permute(1,0,2) pattern to a single TI with changed axis.
WARNING This transformation is limited to support of very special case of TI but
code doesn't check all the cases.
- '''
+ """
enabled = True
def run_after(self):
- return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, FusePermutesSequence, DecomposeBiLSTM]
+ return [TensorIteratorMerge, ONNXRNNSequenceNormalize, LSTMToTensorIterator, FusePermutesSequence]
+
+
+ def run_before(self):
+ return []
def pattern(self):
return dict(
('input', 'direct_permute'),
('direct_permute', 'input_permuted'),
- ('input_permuted', 'ti', {'in': 0}), # affected by permute
+ ('input_permuted', 'ti', {'in': 0}), # affected by permute
('init_hidden', 'ti', {'in': 1}),
('init_cell', 'ti', {'in': 2}),
- ('ti', 'output_permuted', {'out': 0}), # affected by permute
+ ('ti', 'output_permuted', {'out': 0}), # affected by permute
('output_permuted', 'inverse_permute'),
('inverse_permute', 'output'),
]
)
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
# This transformation works if and only if a body of TI
# matches the following topology (Reshape -> LSTMCell -> Reshape)
- nodes=[
+ nodes = [
('input_unsqueezed'),
('squeeze', dict(op='Reshape')),
('input_squeezed'),
('output_cell'),
('unsqueeze', dict(op='Reshape')),
('output_unsqueezed'),
+
+ ('const_w', dict(op='Const')),
+ ('const_b', dict(op='Const')),
+
+ ('op_output', dict(op='OpOutput')),
+ ('op_output_1', dict(op='OpOutput')),
+ ('op_output_2', dict(op='OpOutput'))
+
]
- edges=[
+ edges = [
('input_unsqueezed', 'squeeze'),
('squeeze', 'input_squeezed'),
('weights', 'lstm', {'in': 3}),
('biases', 'lstm', {'in': 4}),
+ ('const_w', 'weights'),
+ ('const_b', 'biases'),
+
('lstm', 'output_hidden', {'out': 0}),
('lstm', 'output_cell', {'out': 1}),
('output_hidden', 'unsqueeze'),
('unsqueeze', 'output_unsqueezed'),
+
+ ('output_unsqueezed', 'op_output'),
+ ('output_hidden', 'op_output_1'),
+ ('output_cell', 'op_output_2'),
+
]
ti = match['ti']
isomorphisms = find_isomorphisms(ti.body, nodes, edges)
if not inverse_permute.has_valid('order') or not np.array_equal(inverse_permute.order, permute_order):
return
-
def find_ports(port_map: list, attrs: dict):
""" Find all ports in a given port map with specified attributes """
result = []