From 5cc8114322d7fcd8057a80a3229a9bb16276fa70 Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Fri, 29 May 2020 09:11:22 +0300 Subject: [PATCH] [ MO: CVS-32286 ] IdentityN fix (#668) --- .../extensions/front/tf/identityN_to_identity.py | 15 ++++++++++++++- .../extensions/front/tf/identityN_to_identity_test.py | 17 +++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/model-optimizer/extensions/front/tf/identityN_to_identity.py b/model-optimizer/extensions/front/tf/identityN_to_identity.py index 4e3d38f..7578ef9 100644 --- a/model-optimizer/extensions/front/tf/identityN_to_identity.py +++ b/model-optimizer/extensions/front/tf/identityN_to_identity.py @@ -29,6 +29,11 @@ class IdentityN_to_Identity(FrontReplacementPattern): IdentityN Identity Identity / \ | | output_0 output_1 output_0 output_1 + + ATTENTION: not all in/outputs of the IdentityN may survive during ModelOptimizer pipeline. + And it breaks the original operation semantics. + For example, output_1 may be not be used during network output computations. + To preserve this unused in/output ports we disconnect the corresponding out/input port. """ enabled = True @@ -41,12 +46,20 @@ class IdentityN_to_Identity(FrontReplacementPattern): dtypes = node.data_types for idx, port in node.in_ports().items(): - assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name) + if not node.is_in_port_connected(idx) or not node.is_out_port_connected(idx): + # ATTENTION section in the description above + continue assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes) identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node() port.get_connection().set_destination(identity.in_port(0)) node.out_port(idx).get_connection().set_source(identity.out_port(0)) + # ATTENTION section in the description above + for in_port in node.in_ports().values(): + in_port.disconnect() + for out_port in node.out_ports().values(): + out_port.disconnect() + def find_and_replace_pattern(self, graph: Graph): for identityN in graph.get_op_nodes(op='IdentityN'): self.replace_identityN(identityN) diff --git a/model-optimizer/extensions/front/tf/identityN_to_identity_test.py b/model-optimizer/extensions/front/tf/identityN_to_identity_test.py index f6422ce..71571d7 100644 --- a/model-optimizer/extensions/front/tf/identityN_to_identity_test.py +++ b/model-optimizer/extensions/front/tf/identityN_to_identity_test.py @@ -61,3 +61,20 @@ class TestIdentityN(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_identityN_unused_ports(self): + graph = build_graph(nodes, [ + *connect('placeholder_0', '0:identityN'), + *connect('placeholder_1', '1:identityN'), + *connect('identityN:0', 'output0'), + ], nodes_with_edges_only=True) + + IdentityN_to_Identity().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes, [ + *connect('placeholder_0', 'identity0'), + *connect('identity0', 'output0'), + ], nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True) + self.assertTrue(flag, resp) -- 2.7.4