From c2394508c158e46b9d2bff61eda5f94d0e4178d9 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 20 Oct 2020 09:57:55 +0300 Subject: [PATCH] Implement LookupTableInsert shape inference (#2348) * Implement LookupTableInsertV2 shape inference It is needed if other nodes not beeing pruned in the graph have a conditional dependence on LookupTableInsertV2 node. Signed-off-by: Roman Kazantsev * Fix after core-review #1 Signed-off-by: Roman Kazantsev * Fix the code after review #2 Signed-off-by: Roman Kazantsev * Fix after code review #3 --- model-optimizer/automation/package_BOM.txt | 2 + .../extensions/front/tf/LookupTableInsert_ext.py | 38 ++++++++++++ .../extensions/ops/LookupTableInsert.py | 58 +++++++++++++++++ .../extensions/ops/LookupTableInsert_test.py | 72 ++++++++++++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 model-optimizer/extensions/front/tf/LookupTableInsert_ext.py create mode 100644 model-optimizer/extensions/ops/LookupTableInsert.py create mode 100644 model-optimizer/extensions/ops/LookupTableInsert_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 60dcace..b3a599a 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -390,6 +390,7 @@ extensions/front/tf/identity_ext.py extensions/front/tf/identityN_to_identity.py extensions/front/tf/InterpolateTransposes.py extensions/front/tf/IteratorGetNext_ext.py +extensions/front/tf/LookupTableInsert_ext.py extensions/front/tf/LoopCond_ext.py extensions/front/tf/lrn_ext.py extensions/front/tf/mask_rcnn_support.json @@ -630,6 +631,7 @@ extensions/ops/identity.py extensions/ops/instance_normalization.py extensions/ops/interp.py extensions/ops/interpolate.py +extensions/ops/LookupTableInsert.py extensions/ops/LSTM.py extensions/ops/lstm_cell.py extensions/ops/lstm_sequence.py diff --git a/model-optimizer/extensions/front/tf/LookupTableInsert_ext.py b/model-optimizer/extensions/front/tf/LookupTableInsert_ext.py new file mode 100644 index 0000000..609291f --- /dev/null +++ b/model-optimizer/extensions/front/tf/LookupTableInsert_ext.py @@ -0,0 +1,38 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from extensions.ops.LookupTableInsert import LookupTableInsert +from mo.front.extractor import FrontExtractorOp + + +class LookupTableInsertFrontExtractor(FrontExtractorOp): + op = 'LookupTableInsert' + enabled = True + + @classmethod + def extract(cls, node): + LookupTableInsert.update_node_stat(node, {}) + return cls.enabled + + +class LookupTableInsertV2FrontExtractor(FrontExtractorOp): + op = 'LookupTableInsertV2' + enabled = True + + @classmethod + def extract(cls, node): + LookupTableInsert.update_node_stat(node, {}) + return cls.enabled diff --git a/model-optimizer/extensions/ops/LookupTableInsert.py b/model-optimizer/extensions/ops/LookupTableInsert.py new file mode 100644 index 0000000..a225003 --- /dev/null +++ b/model-optimizer/extensions/ops/LookupTableInsert.py @@ -0,0 +1,58 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import numpy as np + +from mo.front.common.partial_infer.utils import int64_array +from mo.graph.graph import Node, Graph +from mo.ops.op import Op + + +class LookupTableInsert(Op): + ''' + This operation has only output control flow edges and no output data edges in some models. + And for these cases implementation of the shape inference is needed since the shape inference is executed + before control flow edges resolving. This operation has non-tensor output so the output shape is empty. + ''' + enabled = False + op = 'LookupTableInsert' + + def __init__(self, graph: Graph, attrs: dict): + mandatory_props = { + 'type': None, + 'op': self.op, + 'infer': self.infer, + 'in_ports_count': 3, + 'out_ports_count': 1, + } + super().__init__(graph, mandatory_props, attrs) + + @staticmethod + def infer(node: Node): + node_name = node.soft_get('name', node.id) + connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()] + assert len(connected_in_ports) == 3, \ + "Incorrect number of inputs for {} node".format(node_name) + + # check shapes of input tensors + keys_shape = node.in_port(1).data.get_shape() + values_shape = node.in_port(2).data.get_shape() + assert np.array_equal(keys_shape, values_shape), \ + 'Shapes of tensors with keys and values must be equal for {} node'.format(node_name) + + # set output shape that must be empty + # since output is not a tensor + node.out_port(0).data.set_shape(int64_array([])) diff --git a/model-optimizer/extensions/ops/LookupTableInsert_test.py b/model-optimizer/extensions/ops/LookupTableInsert_test.py new file mode 100644 index 0000000..bf822e3 --- /dev/null +++ b/model-optimizer/extensions/ops/LookupTableInsert_test.py @@ -0,0 +1,72 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +import numpy as np + +from extensions.ops.LookupTableInsert import LookupTableInsert +from mo.front.common.partial_infer.utils import int64_array +from mo.graph.graph import Node +from mo.utils.unittest.graph import build_graph + +nodes_attributes = {'table': {'kind': 'op'}, + 'table_data': {'shape': None, 'value': None, 'kind': 'data'}, + 'keys': {'kind': 'op'}, + 'keys_data': {'shape': None, 'value': None, 'kind': 'data'}, + 'values': {'kind': 'op'}, + 'values_data': {'shape': None, 'value': None, 'kind': 'data'}, + 'lookuptableinsert_node': {'op': 'LookupTableInsert', 'kind': 'op'}, + 'output': {'shape': None, 'value': None, 'kind': 'data'}} + +# graph 1 +edges1 = [('table', 'table_data'), + ('keys', 'keys_data'), + ('values', 'values_data'), + ('table_data', 'lookuptableinsert_node', {'in': 0}), + ('keys_data', 'lookuptableinsert_node', {'in': 1}), + ('values_data', 'lookuptableinsert_node', {'in': 2}), + ('lookuptableinsert_node', 'output')] + +# valid test case +inputs1 = {'table_data': {}, + 'keys_data': {'shape': int64_array([4])}, + 'values_data': {'shape': int64_array([4])}} + +# invalid test case +inputs2 = {'table_data': {}, + 'keys_data': {'shape': int64_array([5, 2])}, + 'values_data': {'shape': int64_array([4])}} + +class TestLookupTableInsert(unittest.TestCase): + def test_infer1(self): + graph = build_graph(nodes_attributes, edges1, inputs1) + lookuptableinsert_node = Node(graph, 'lookuptableinsert_node') + LookupTableInsert.infer(lookuptableinsert_node) + + # prepare reference results + ref_output_shape = int64_array([]) + + # get the result + res_output_shape = graph.node['output']['shape'] + + self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), + 'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) + + def test_infer_invalid1(self): + graph = build_graph(nodes_attributes, edges1, inputs2) + lookuptableinsert_node = Node(graph, 'lookuptableinsert_node') + self.assertRaises(AssertionError, LookupTableInsert.infer, lookuptableinsert_node) -- 2.7.4