From: iliya mironov Date: Tue, 17 Nov 2020 13:28:27 +0000 (+0300) Subject: Fix graph clenup (#3159) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9cb3c2a6beed4bf0e6caae753e9dab9aee6573ee;p=platform%2Fupstream%2Fdldt.git Fix graph clenup (#3159) * Fix graph clenup * Refactoring graph clean up function * Change wa comment Co-authored-by: Your Name --- diff --git a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py index 0010a90..fc7b834 100644 --- a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py +++ b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py @@ -139,14 +139,10 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): else: mark_as_correct_data_layout(node) node['nchw_layout'] = True - if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up - node.out_node()['nchw_layout'] = True for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]: mark_as_correct_data_layout(node) node['nchw_layout'] = True - if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up - node.out_node()['nchw_layout'] = True @staticmethod def get_weighted_layer_type_to_in_weights_port(): diff --git a/model-optimizer/mo/middle/passes/eliminate.py b/model-optimizer/mo/middle/passes/eliminate.py index 8137a47..2aa689c 100644 --- a/model-optimizer/mo/middle/passes/eliminate.py +++ b/model-optimizer/mo/middle/passes/eliminate.py @@ -125,8 +125,16 @@ def mark_const_producer_nodes(graph): def eliminate_dead_nodes(graph): + from mo.graph.graph import Node nodes_to_remove = set() for node_name, node_attrs in graph.nodes(data=True): + # The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation. + # During graph clean-up the operation node is removed and the attribute is lost. + # This results in permutation of the Const shape in the IR and wrong inference results. + # Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation." + if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False): + Node(graph, node_name).out_node()['nchw_layout'] = True + if not node_attrs['is_output_reachable'] or \ (node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or node_attrs.get('force_dead_node', False))): @@ -153,13 +161,6 @@ def add_constant_operations(graph): graph.add_edges_from([(const_node.id, node.id, {'out': 0})]) -def remove_const_ops(graph): - - for node in graph.get_op_nodes(type='Const'): - graph.remove_edge(node.id, node.out_node().id) - graph.remove_node(node.id) - - def shape_inference(graph): for node in graph.pseudo_topological_sort(): if node.has_and_set('need_shape_inference'): @@ -252,5 +253,3 @@ def remove_edges_for_nodes(graph, node_attrs: dict, edge_attrs: dict): src_node, edge = nodes_edges[port] if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]): graph.remove_edge(src_node.id, node.id) - -