else:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
- if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
- node.out_node()['nchw_layout'] = True
for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
- if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
- node.out_node()['nchw_layout'] = True
@staticmethod
def get_weighted_layer_type_to_in_weights_port():
def eliminate_dead_nodes(graph):
+ from mo.graph.graph import Node
nodes_to_remove = set()
for node_name, node_attrs in graph.nodes(data=True):
+ # The Const operation node may have set an attribute 'nchw_layout' attribute to prevent shape permutation.
+ # During graph clean-up the operation node is removed and the attribute is lost.
+ # This results in permutation of the Const shape in the IR and wrong inference results.
+ # Here we explicitly save the 'nchw_layout' attribute in the data node to prevent permutation."
+ if node_attrs.get('type', None) == 'Const' and node_attrs.get('nchw_layout', False):
+ Node(graph, node_name).out_node()['nchw_layout'] = True
+
if not node_attrs['is_output_reachable'] or \
(node_attrs['is_const_producer'] and (not node_attrs['is_undead'] or
node_attrs.get('force_dead_node', False))):
graph.add_edges_from([(const_node.id, node.id, {'out': 0})])
-def remove_const_ops(graph):
-
- for node in graph.get_op_nodes(type='Const'):
- graph.remove_edge(node.id, node.out_node().id)
- graph.remove_node(node.id)
-
-
def shape_inference(graph):
for node in graph.pseudo_topological_sort():
if node.has_and_set('need_shape_inference'):
src_node, edge = nodes_edges[port]
if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
graph.remove_edge(src_node.id, node.id)
-
-