"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
-
-from mo.graph.graph import copy_node
-from mo.utils.error import Error
+from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
+from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+from mo.graph.graph import Graph
from mo.middle.pattern_match import find_isomorphisms
from mo.middle.replacement import MiddleReplacementPattern
-from extensions.ops.lstm_sequence import LSTMSequence
-from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
-from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
-from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
+from mo.utils.error import Error
+from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize, permute_before_and_after
class TensorIteratorLSTM(MiddleReplacementPattern):
enabled = False
def run_after(self):
- return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, TensorFlowLSTMtoGeneric]
+ return [TensorIteratorMerge, ONNXRNNSequenceNormalize, TensorFlowLSTMtoGeneric]
def pattern(self):
return dict(
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
- nodes=[
+ def replace_pattern(graph: Graph, match: dict):
+ nodes = [
('input_unsqueezed'),
('squeeze', dict(op='Reshape')),
('input_squeezed'),
('unsqueeze', dict(op='Reshape')),
('output_unsqueezed'),
]
- edges=[
+ edges = [
('input_unsqueezed', 'squeeze'),
('squeeze', 'input_squeezed'),
'Please modify the original network '
'to meet the requirements.'.format(ti.soft_get('name')))
# TODO Additional checks for port indices
- if body_match['lstm'].has_valid('mark_supported_by_IE'):
- body_match['lstm'].mark_supported_by_IE(body_match['lstm'])
-
-
-class CheckUnsupportedLSTMCell(MiddleReplacementPattern):
- """ Finds all unsupported LSTMCell.
-
- Initiates the second translation round if find any not supported LSTMCell instances.
- """
-
- enabled = False
-
- def run_after(self):
- return [TensorIteratorLSTM]
-
- def pattern(self):
- return dict(
- nodes=[
- ('lstm', dict(op='LSTMCell')),
- ],
- edges=[
- ]
- )
-
- @staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
- lstmcell = match['lstm']
- if lstmcell.has_valid('finalize_first_round'):
- lstmcell.finalize_first_round()
- if not lstmcell.has_and_set('supported_by_IE'):
- # this is a signal for the main translation pipeline to repeat the entire conversion process
- graph.graph['repeat_conversion'] = True
- # in case when there is no lstmcell.finalize_first_round then this cell wasn't created with the pattern
- # (for example in ONNX) and we don't initiate the second round.