From 265e3c7cbad122028b6e11c4bdedf445b603840e Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Tue, 2 Jun 2020 12:43:41 +0300 Subject: [PATCH] Remove TopKnormalizer from MO IR Reader transformation_list (#590) * Remove TopKnormalizer from transformation_list and added call of normalize_outputs to fix read/save of some models --- model-optimizer/extensions/back/TopKNormalizer.py | 13 ++++++++++--- model-optimizer/mo/utils/ir_reader/layer_to_class.py | 2 ++ model-optimizer/mo/utils/ir_reader/restore_graph.py | 2 -- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/model-optimizer/extensions/back/TopKNormalizer.py b/model-optimizer/extensions/back/TopKNormalizer.py index 2febe2e..257aaa6 100644 --- a/model-optimizer/extensions/back/TopKNormalizer.py +++ b/model-optimizer/extensions/back/TopKNormalizer.py @@ -19,7 +19,7 @@ from extensions.back.ScalarConstNormalize import ScalarNormalize from mo.back.replacement import BackReplacementPattern from mo.front.common.partial_infer.utils import int64_array from mo.front.tf.graph_utils import create_op_node_with_second_input -from mo.graph.graph import Graph +from mo.graph.graph import Graph, Node from mo.ops.reshape import Reshape from mo.ops.result import Result @@ -58,10 +58,17 @@ class TopKNormalizer(BackReplacementPattern): {'override_output_shape': True}) node.in_port(1).get_connection().insert_node(reshape) + TopKNormalizer.normalize_outputs(node) + + @staticmethod + def normalize_outputs(node: Node): + """ + This function adds missed outputs for TopK node. + """ if node.out_port(0).disconnected(): - output = Result(graph, {'name': node.name + '/Result_port_0/', + output = Result(node.graph, {'name': node.name + '/Result_port_0/', 'remove_from_xml': node.has_and_set('remove_values_output')}).create_node() node.out_port(0).get_connection().set_destination(output.in_port(0)) if node.out_port(1).disconnected(): - output = Result(graph, {'name': node.name + '/Result_port_1/'}).create_node() + output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node() node.out_port(1).get_connection().set_destination(output.in_port(0)) diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class.py b/model-optimizer/mo/utils/ir_reader/layer_to_class.py index c9ad763..3d23550 100644 --- a/model-optimizer/mo/utils/ir_reader/layer_to_class.py +++ b/model-optimizer/mo/utils/ir_reader/layer_to_class.py @@ -19,6 +19,7 @@ import os import numpy as np +from extensions.back.TopKNormalizer import TopKNormalizer from extensions.ops.Cast import Cast from extensions.ops.ReduceOps import ReduceOp from extensions.ops.activation_ops import Activation @@ -260,6 +261,7 @@ preprocessing_op_nodes = { postprocessing_op_nodes = { 'Assign': assign_add_output_result, 'TensorIterator': ti_add_edge_attrs, + 'TopK': TopKNormalizer.normalize_outputs, } diff --git a/model-optimizer/mo/utils/ir_reader/restore_graph.py b/model-optimizer/mo/utils/ir_reader/restore_graph.py index f1e202e..80ddf09 100644 --- a/model-optimizer/mo/utils/ir_reader/restore_graph.py +++ b/model-optimizer/mo/utils/ir_reader/restore_graph.py @@ -20,7 +20,6 @@ from extensions.back.ConvolutionNormalizer import ConvolutionNormalizer, Convolu from extensions.back.PackBinaryWeights import PackBinaryWeights from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer -from extensions.back.TopKNormalizer import TopKNormalizer from extensions.back.blob_normalizer import BlobNormalizer from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern from mo.graph.graph import Graph @@ -73,7 +72,6 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None): # List items order matters, do not change it. transformation_list = [ ConvolutionWithGroupsResolver, - TopKNormalizer, StridedSliceMasksNormalizer, PackBinaryWeights, BlobNormalizer, -- 2.7.4