"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import logging as log
-import networkx as nx
+
import numpy as np
from extensions.ops.TensorIterator_ops import TensorIteratorInput
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
|__________________________________________________|
"""
- enabled = False # called from mo.pipeline.tf directly
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ from extensions.middle.TensorIterator_utils import DeleteSelect
+ return [DeleteSelect]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
@staticmethod
def pattern():
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
log.debug('================== SmartInputFind ===============')
assert match['Enter_data'].value is not None
# axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
# condition)
input_node = TensorIteratorInput(graph, dict(axis=0, start=start, stride=None, part_size=None,
- external_port_id=str(match['Enter_data'].value),
- internal_layer_id=match['TensorArrayRead_data'].id,
- name=match['TensorArrayRead'].name + '/TensorIteratorInput_'
- ))
+ external_port_id=str(match['Enter_data'].value),
+ internal_layer_id=match['TensorArrayRead_data'].id,
+ name=match['TensorArrayRead'].name + '/TensorIteratorInput_'
+ ))
input_node.create_node_with_data(inputs=[ta_size_data, value, match['Condition_data']],
- data_nodes=[match['TensorArrayRead_data']])
+ data_nodes=[match['TensorArrayRead_data']])
# Delete useless nodes
safe_nodes = ['TensorArrayRead_data', 'Condition', 'Condition_data']
class SimpleInputMatcher(MiddleReplacementPattern):
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
- enabled = False # called from mo.pipeline.tf directly
+ def run_after(self):
+ from extensions.middle.DeleteNotExecutable import DeleteNotExecutable
+ return [DeleteNotExecutable]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
"""
This pattern match simple inputs (without partitions) in while loops in TF (this inputs are set by Enter nodes).
"""
+
@staticmethod
def pattern():
return dict(
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
log.debug('================== SimpletInputFind ===============')
input_node = TensorIteratorInput(graph, dict(external_port_id=None,
- internal_layer_id=None,
- name=match['Enter'].name + '/TensorIteratorInput_'
- ))
+ internal_layer_id=None,
+ name=match['Enter'].name + '/TensorIteratorInput_'
+ ))
input_node.create_node_with_data(inputs=[match['Enter'].in_node()], data_nodes=[match['Enter'].out_node()])
# Delete useless nodes
class BackEdgeSimpleInputMatcher(MiddleReplacementPattern):
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
- enabled = False # called from mo.pipeline.tf directly
+ def run_after(self):
+ return [SimpleInputMatcher]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
@staticmethod
def pattern():
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
log.debug('================== SimpleBackEdgeInputFind ===============')
assert len(match['BackEdge'].in_nodes()) == 3
cycle_input = match['BackEdge'].in_node(1)
# We need to create new TensorItertorInput node only if this node doesn't exist already.
- if len(init_input.in_nodes()) == 0:
+ if len(init_input.in_nodes()) == 0 or\
+ (len(init_input.in_nodes()) == 1 and init_input.has_valid('value')):
+
input_node = TensorIteratorInput(graph, dict(external_port_id=None,
- internal_layer_id=None,
- name=match['BackEdge'].name + '/TensorIteratorInput_'
- ))
+ internal_layer_id=None,
+ name=match['BackEdge'].name + '/TensorIteratorInput_'
+ ))
+
+ # In case if data node has Constant producer
+ if len(init_input.in_nodes()) == 1:
+ graph.remove_edge(init_input.in_node(0).id, init_input.id)
+
input_data_node = input_node.create_node_with_data(inputs=[init_input])
input_data_node.shape = np.array(init_input.shape, dtype=np.int64)
graph.remove_edges_from([(init_input.id, match['BackEdge'].id)])