"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
-
from extensions.ops.TensorIterator_ops import TensorIteratorCondition
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
+import numpy as np
class LoopConditionMatcher(MiddleReplacementPattern):
Const----
"""
enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ return []
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
@staticmethod
def pattern():
('Enter_2_less', dict(kind='op', op='Enter')),
('Enter_2_less_data', dict(kind='data')),
- ('minimum', dict(kind='op', op='Minimum')),
('minimum_data', dict(kind='data')),
('and', dict(kind='op', op='LogicalAnd')),
('loop_cond_data', dict(kind='data')),
('init_1', dict(kind='op', op='Const')),
- ('init_1_data', dict(kind='data')),
+ ('init_1_data', dict(kind='data')),
('Enter_1', dict(kind='op', op='Enter')),
- ('Enter_1_data', dict(kind='data')),
+ ('Enter_1_data', dict(kind='data')),
('init_2', dict(kind='op', op='Const')),
('init_2_data', dict(kind='data')),
('Identity_1', dict(kind='op', op='Identity')),
('Identity_1_data', dict(kind='data')),
('add_1', dict(kind='op', op='Add')),
- ('add_1_y', dict(kind='op', op='Const')),
+ ('add_1_y', dict(kind='op', op='Const')),
('add_1_y_data', dict(kind='data')),
('add_1_data', dict(kind='data')),
('NextIteration_1', dict(kind='op', op='NextIteration')),
edges=[
('Strided_slice', 'Strided_slice_data'),
('Strided_slice_data', 'Enter_1_less'),
- ('Strided_slice_data', 'minimum'),
('Enter_1_less', 'Enter_1_less_data'),
('Enter_1_less_data', 'Less_1'),
('Less_1', 'Less_1_data'),
('add_2', 'add_2_data'),
('add_2_data', 'NextIteration_2'),
- ('minimum', 'minimum_data'),
('minimum_data', 'Enter_2_less'),
('Enter_2_less', 'Enter_2_less_data'),
('Enter_2_less_data', 'Less_2'),
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def looking_for_iteration_counter(graph: Graph, match: dict):
+ types = ['TensorIteratorInput', 'TensorIteratorOutput']
+ candidates = np.array([match['Identity_1_data'], match['Identity_2_data']])
+ results = np.array([False for i in range(len(candidates))])
+ for i, candidat in enumerate(candidates):
+ for node in candidat.out_nodes():
+ if node['op'] in types:
+ results[i] = True
+ assert not np.all(results)
+ assert sum(results) == 1
+ return candidates[results == True][0]
+
+ def replace_pattern(self, graph: Graph, match: dict):
log.debug('================== ConditionFind ===============')
- max_node = match['minimum'].in_node(1).in_node()
- assert max_node['kind'] == 'op' and max_node['op'] == 'Maximum'
-
- #init_1
+ # init_1
init_1 = match['init_1_data'].value
assert init_1 is not None
init_1 = int(init_1)
- #init_2
+ # init_2
init_2 = match['init_2_data'].value
assert init_2 is not None
init_2 = int(init_2)
- #step_1
+ # step_1
assert match['add_1_y_data'].value is not None
step_1 = int(match['add_1_y_data'].value)
- #step_2
+ # step_2
assert match['add_2_y_data'].value is not None
step_2 = int(match['add_2_y_data'].value)
match['Identity_2_data'].value = None
# Create condition node and delete all useless nodes from condition pattern
- condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1), \
+ loop_condiiton = match['loop_cond_data']
+ iterator_data = self.looking_for_iteration_counter(graph, match)
+
+ condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
name=match['loop_cond'].name + '/TensorIteratorCondition_')
condition = TensorIteratorCondition(graph, attrs=condition_attrs)
condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
- data_nodes=[match['loop_cond_data'], match['Identity_2_data']])
+ data_nodes=[loop_condiiton, iterator_data])
# Delete useless nodes
- safe_nodes = ['loop_cond_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
+ safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
'minimum', 'minimum_data']
nodes_for_remove = []
for node in match.keys():
graph.remove_nodes_from(nodes_for_remove)
-class SimpleConditionMather(MiddleReplacementPattern):
+class SimpleConditionMatcher(MiddleReplacementPattern):
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ return [LoopConditionMatcher]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
+
@staticmethod
def pattern():
log.debug('+++++++++++++++ SimpleConditionMatching ++++++++++++++++')
('loop_cond_data', dict(kind='data')),
('init_1', dict(kind='op', op='Const')),
- ('init_1_data', dict(kind='data')),
+ ('init_1_data', dict(kind='data')),
('Enter_1', dict(kind='op', op='Enter')),
- ('Enter_1_data', dict(kind='data')),
-
+ ('Enter_1_data', dict(kind='data')),
('Switch_1', dict(kind='op', op='Switch')),
('Switch_1_data', dict(kind='data')),
('Identity_1', dict(kind='op', op='Identity')),
('Identity_1_data', dict(kind='data')),
('add_1', dict(kind='op', op='Add')),
- ('add_1_y', dict(kind='op', op='Const')),
+ ('add_1_y', dict(kind='op', op='Const')),
('add_1_y_data', dict(kind='data')),
('add_1_data', dict(kind='data')),
('NextIteration_1', dict(kind='op', op='NextIteration')),
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
log.debug('================== SimpleConditionFind ===============')
# init_1
init_1 = match['init_1_data'].value
match['loop_cond_data'].value = None
# Create condition node and delete all useless nodes from condition pattern
- condition_attrs = dict(iter=dict(init=init_1, step=step_1), \
+ condition_attrs = dict(iter=dict(init=init_1, step=step_1),
name=match['loop_cond'].name + '/TensorIteratorCondition_')
condition = TensorIteratorCondition(graph, attrs=condition_attrs)
condition.create_node_with_data(inputs=[match['Strided_slice_data']],
for node in match.keys():
if node not in safe_nodes:
nodes_for_remove.append(match[node].id)
- graph.remove_nodes_from(nodes_for_remove)
\ No newline at end of file
+ graph.remove_nodes_from(nodes_for_remove)