Fix graph clenup (#3159)
authoriliya mironov <iliya.mironov@intel.com>
Tue, 17 Nov 2020 13:28:27 +0000 (16:28 +0300)
committerGitHub <noreply@github.com>
Tue, 17 Nov 2020 13:28:27 +0000 (16:28 +0300)
* Fix graph clenup

* Refactoring graph clean up function

* Change wa comment

Co-authored-by: Your Name <you@example.com>
model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py
model-optimizer/mo/middle/passes/eliminate.py

index 0010a90..fc7b834 100644 (file)
@@ -139,14 +139,10 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
             else:
                 mark_as_correct_data_layout(node)
             node['nchw_layout'] = True
-            if node.soft_get('type') == 'Const':  # WA for Const op deletion during clean_up
-                node.out_node()['nchw_layout'] = True
 
         for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]:
             mark_as_correct_data_layout(node)
             node['nchw_layout'] = True
-            if node.soft_get('type') == 'Const':  # WA for Const op deletion during clean_up
-                node.out_node()['nchw_layout'] = True
 
     @staticmethod
     def get_weighted_layer_type_to_in_weights_port():
index 8137a47..2aa689c 100644 (file)
@@ -125,8 +125,16 @@ def mark_const_producer_nodes(graph):
 
 
 def eliminate_dead_nodes(graph):
+    from mo.graph.graph import Node
     nodes_to_remove = set()
     for node_name, node_attrs in graph.nodes(data=True):
+        # The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation.
+        # During graph clean-up the operation node is removed and the attribute is lost.
+        # This results in permutation of the Const shape in the IR and wrong inference results.
+        # Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation."
+        if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False):
+            Node(graph, node_name).out_node()['nchw_layout'] = True
+
         if not node_attrs['is_output_reachable'] or \
                 (node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or
                                                       node_attrs.get('force_dead_node', False))):
@@ -153,13 +161,6 @@ def add_constant_operations(graph):
             graph.add_edges_from([(const_node.id, node.id, {'out': 0})])
 
 
-def remove_const_ops(graph):
-
-    for node in graph.get_op_nodes(type='Const'):
-        graph.remove_edge(node.id, node.out_node().id)
-        graph.remove_node(node.id)
-
-
 def shape_inference(graph):
     for node in graph.pseudo_topological_sort():
         if node.has_and_set('need_shape_inference'):
@@ -252,5 +253,3 @@ def remove_edges_for_nodes(graph, node_attrs: dict, edge_attrs: dict):
                 src_node, edge = nodes_edges[port]
                 if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
                     graph.remove_edge(src_node.id, node.id)
-
-