From b6a05c232ed8536689be5e67d7fc597a4f654fef Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Mon, 25 May 2020 10:52:58 +0300 Subject: [PATCH] [ MO TF ] IdentityN support (#529) --- model-optimizer/automation/package_BOM.txt | 1 + .../extensions/front/caffe/split_to_identity.py | 4 +- .../extensions/front/mxnet/block_grad_ext.py | 4 +- model-optimizer/extensions/front/mxnet/copy_ext.py | 4 +- .../extensions/front/mxnet/dropout_ext.py | 4 +- .../extensions/front/onnx/dropout_ext.py | 4 +- .../extensions/front/tf/identityN_to_identity.py | 52 ++++++++++++++++++ .../front/tf/identityN_to_identity_test.py | 63 ++++++++++++++++++++++ .../extensions/front/tf/identity_ext.py | 23 ++++++-- model-optimizer/extensions/ops/identity.py | 27 +++++++--- .../mo/front/kaldi/extractors/clip_ext.py | 4 +- .../mo/front/kaldi/extractors/noop_ext.py | 4 +- 12 files changed, 170 insertions(+), 24 deletions(-) create mode 100644 model-optimizer/extensions/front/tf/identityN_to_identity.py create mode 100644 model-optimizer/extensions/front/tf/identityN_to_identity_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 3fba4a5..104d86f 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -379,6 +379,7 @@ extensions/front/tf/gather_ext.py extensions/front/tf/GatherTree_ext.py extensions/front/tf/GNMT_DynamicSequenceLengths.py extensions/front/tf/identity_ext.py +extensions/front/tf/identityN_to_identity.py extensions/front/tf/InterpolateTransposes.py extensions/front/tf/IteratorGetNext_ext.py extensions/front/tf/LoopCond_ext.py diff --git a/model-optimizer/extensions/front/caffe/split_to_identity.py b/model-optimizer/extensions/front/caffe/split_to_identity.py index dded37f..5d39d2d 100644 --- a/model-optimizer/extensions/front/caffe/split_to_identity.py +++ b/model-optimizer/extensions/front/caffe/split_to_identity.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.common.replacement import FrontReplacementOp from mo.graph.graph import Graph @@ -33,7 +33,7 @@ class SplitToIdentity(FrontReplacementOp): def replace_sub_graph(self, graph: Graph, match: dict): node = match['op'] - identity = IdentityOp(graph, {'name': node.soft_get('name', node.id)}).create_node() + identity = Identity(graph, {'name': node.soft_get('name', node.id)}).create_node() node.in_port(0).get_connection().set_destination(identity.in_port(0)) for idx, port in node.out_ports().items(): diff --git a/model-optimizer/extensions/front/mxnet/block_grad_ext.py b/model-optimizer/extensions/front/mxnet/block_grad_ext.py index 270d651..695e4cc 100644 --- a/model-optimizer/extensions/front/mxnet/block_grad_ext.py +++ b/model-optimizer/extensions/front/mxnet/block_grad_ext.py @@ -14,7 +14,7 @@ limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp from mo.graph.graph import Node @@ -25,5 +25,5 @@ class BlockGradExt(FrontExtractorOp): @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, {}) + Identity.update_node_stat(node, {}) return cls.enabled diff --git a/model-optimizer/extensions/front/mxnet/copy_ext.py b/model-optimizer/extensions/front/mxnet/copy_ext.py index 1341ec2..6082865 100644 --- a/model-optimizer/extensions/front/mxnet/copy_ext.py +++ b/model-optimizer/extensions/front/mxnet/copy_ext.py @@ -14,7 +14,7 @@ limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp from mo.graph.graph import Node @@ -25,5 +25,5 @@ class CopyExt(FrontExtractorOp): @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, {}) + Identity.update_node_stat(node, {}) return cls.enabled diff --git a/model-optimizer/extensions/front/mxnet/dropout_ext.py b/model-optimizer/extensions/front/mxnet/dropout_ext.py index 497308d..72c084f 100644 --- a/model-optimizer/extensions/front/mxnet/dropout_ext.py +++ b/model-optimizer/extensions/front/mxnet/dropout_ext.py @@ -14,7 +14,7 @@ limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp from mo.graph.graph import Node @@ -25,5 +25,5 @@ class DropoutExt(FrontExtractorOp): @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, {}) + Identity.update_node_stat(node, {}) return cls.enabled diff --git a/model-optimizer/extensions/front/onnx/dropout_ext.py b/model-optimizer/extensions/front/onnx/dropout_ext.py index 6d23b3e..f18e8aa 100644 --- a/model-optimizer/extensions/front/onnx/dropout_ext.py +++ b/model-optimizer/extensions/front/onnx/dropout_ext.py @@ -14,7 +14,7 @@ limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp from mo.front.onnx.extractors.utils import onnx_attr from mo.utils.error import Error @@ -32,5 +32,5 @@ class DropoutFrontExtractor(FrontExtractorOp): raise Error('Dropout node {} has more than one consumer. Unsupported.', node.name) if not is_test: raise Error('Dropout node {} has is_test: 0. This means training mode which is not supported.', node.name) - IdentityOp.update_node_stat(node) + Identity.update_node_stat(node) return cls.enabled diff --git a/model-optimizer/extensions/front/tf/identityN_to_identity.py b/model-optimizer/extensions/front/tf/identityN_to_identity.py new file mode 100644 index 0000000..4e3d38f --- /dev/null +++ b/model-optimizer/extensions/front/tf/identityN_to_identity.py @@ -0,0 +1,52 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from extensions.ops.identity import Identity +from mo.front.common.replacement import FrontReplacementPattern +from mo.graph.graph import Graph, Node + + +class IdentityN_to_Identity(FrontReplacementPattern): + """ + Replaces IdentityN op with several Identity ops. + + Example: + input_0 input_1 input_0 input_1 + \ / | | + IdentityN Identity Identity + / \ | | + output_0 output_1 output_0 output_1 + """ + enabled = True + + @staticmethod + def replace_identityN(node: Node): + graph = node.graph + name = node.soft_get('name', node.id) + + assert node.has_valid('data_types'), 'IdentityN {} has no `data_types` attribute'.format(name) + dtypes = node.data_types + + for idx, port in node.in_ports().items(): + assert node.is_out_port_connected(idx), 'IdentityN {} has inconsistent input and output ports'.format(name) + assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes) + identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx), 'data_type': dtypes[idx]}).create_node() + port.get_connection().set_destination(identity.in_port(0)) + node.out_port(idx).get_connection().set_source(identity.out_port(0)) + + def find_and_replace_pattern(self, graph: Graph): + for identityN in graph.get_op_nodes(op='IdentityN'): + self.replace_identityN(identityN) diff --git a/model-optimizer/extensions/front/tf/identityN_to_identity_test.py b/model-optimizer/extensions/front/tf/identityN_to_identity_test.py new file mode 100644 index 0000000..f6422ce --- /dev/null +++ b/model-optimizer/extensions/front/tf/identityN_to_identity_test.py @@ -0,0 +1,63 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +import numpy as np + +from extensions.front.tf.identityN_to_identity import IdentityN_to_Identity +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import result, regular_op_with_shaped_data, \ + regular_op_with_empty_data, build_graph, connect, empty_data + +nodes = { + **regular_op_with_shaped_data('placeholder_0', [1, 227, 227, 3], {'type': 'Parameter'}), + **regular_op_with_shaped_data('placeholder_1', [1, 227, 227, 3], {'type': 'Parameter'}), + + **regular_op_with_empty_data('identityN', {'op': 'IdentityN', 'type': None, 'data_types': [np.int32, np.float], + 'name': 'my_identity'}), + **empty_data('identityN_1_d'), + **regular_op_with_empty_data('identity0', {'op': 'Identity', 'type': None, 'data_type': np.int32, + 'name': 'my_identity/0_port'}), + **regular_op_with_empty_data('identity1', {'op': 'Identity', 'type': None, 'data_type': np.float, + 'name': 'my_identity/1_port'}), + + **result('output0'), + **result('output1'), +} + + +class TestIdentityN(unittest.TestCase): + def test_identityN(self): + graph = build_graph(nodes, [ + *connect('placeholder_0', '0:identityN'), + *connect('placeholder_1', '1:identityN'), + *connect('identityN:0', 'output0'), + ('identityN', 'identityN_1_d', {'out': 1}), + ('identityN_1_d', 'output1', {'out': 1}), + ], nodes_with_edges_only=True) + + IdentityN_to_Identity().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes, [ + *connect('placeholder_0', 'identity0'), + *connect('placeholder_1', 'identity1'), + *connect('identity0', 'output0'), + *connect('identity1', 'output1'), + ], nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True) + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/tf/identity_ext.py b/model-optimizer/extensions/front/tf/identity_ext.py index b876050..bdd9158 100644 --- a/model-optimizer/extensions/front/tf/identity_ext.py +++ b/model-optimizer/extensions/front/tf/identity_ext.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity, IdentityN from mo.front.extractor import FrontExtractorOp from mo.front.tf.extractors.utils import tf_dtype_extractor from mo.graph.graph import Node @@ -25,19 +25,34 @@ class IdentityFrontExtractor(FrontExtractorOp): @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, { + Identity.update_node_stat(node, { 'data_type': tf_dtype_extractor(node.pb.attr["T"].type), }) return cls.enabled +class IdentityNFrontExtractor(FrontExtractorOp): + op = 'IdentityN' + enabled = True + + @classmethod + def extract(cls, node: Node): + dtypes = [tf_dtype_extractor(t) for t in node.pb.attr["T"].list.type] + IdentityN.update_node_stat(node, { + 'data_types': dtypes, + 'in_ports_count': len(dtypes), + 'out_ports_count': len(dtypes), + }) + return cls.enabled + + class ReadVariableOpFrontExtractor(FrontExtractorOp): op = 'ReadVariableOp' enabled = True @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, { + Identity.update_node_stat(node, { 'data_type': tf_dtype_extractor(node.pb.attr["T"].type), }) return cls.enabled @@ -49,5 +64,5 @@ class StopGradientExtractor(FrontExtractorOp): @classmethod def extract(cls, node: Node): - IdentityOp.update_node_stat(node, {'op': 'StopGradient'}) + Identity.update_node_stat(node, {'op': 'StopGradient'}) return cls.enabled diff --git a/model-optimizer/extensions/ops/identity.py b/model-optimizer/extensions/ops/identity.py index ea8fd99..72dcfe7 100644 --- a/model-optimizer/extensions/ops/identity.py +++ b/model-optimizer/extensions/ops/identity.py @@ -13,24 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. """ -from mo.front.common.partial_infer.elemental import copy_shape_infer, copy_value from mo.graph.graph import Graph from mo.ops.op import Op -class IdentityOp(Op): +class Identity(Op): op = 'Identity' enabled = True def __init__(self, graph: Graph, attrs: dict): super().__init__(graph, { - 'op': __class__.op, + 'op': self.op, + 'type': None, + 'identity': True, + 'infer': self.infer, + 'in_ports_count': 1, 'out_ports_count': 1, - 'infer': IdentityOp.shape_infer }, attrs) @staticmethod - def shape_infer(node): - copy_shape_infer(node, value_infer=copy_value) + def infer(node): + node.out_port(0).data.set_shape(node.in_port(0).data.get_shape()) + if node.in_port(0).data.get_value() is not None: + node.out_port(0).data.set_value(node.in_port(0).data.get_value()) + + +class IdentityN(Op): + op = 'IdentityN' + enabled = True + + def __init__(self, graph: Graph, attrs: dict): + super().__init__(graph, { + 'op': self.op, + 'type': None, + }, attrs) diff --git a/model-optimizer/mo/front/kaldi/extractors/clip_ext.py b/model-optimizer/mo/front/kaldi/extractors/clip_ext.py index c5f1e3a..5d07dc6 100644 --- a/model-optimizer/mo/front/kaldi/extractors/clip_ext.py +++ b/model-optimizer/mo/front/kaldi/extractors/clip_ext.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp @@ -23,5 +23,5 @@ class ClipGradientComponentFrontExtractor(FrontExtractorOp): @classmethod def extract(cls, node): - IdentityOp.update_node_stat(node, {}) + Identity.update_node_stat(node, {}) return cls.enabled diff --git a/model-optimizer/mo/front/kaldi/extractors/noop_ext.py b/model-optimizer/mo/front/kaldi/extractors/noop_ext.py index a197068..26258f5 100644 --- a/model-optimizer/mo/front/kaldi/extractors/noop_ext.py +++ b/model-optimizer/mo/front/kaldi/extractors/noop_ext.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from extensions.ops.identity import IdentityOp +from extensions.ops.identity import Identity from mo.front.extractor import FrontExtractorOp @@ -23,5 +23,5 @@ class NoOpFrontExtractor(FrontExtractorOp): @classmethod def extract(cls, node): - IdentityOp.update_node_stat(node) + Identity.update_node_stat(node) return cls.enabled -- 2.7.4