Provide GatherND with original layout for inputs and output (#3002)
author Roman Kazantsev <roman.kazantsev@intel.com>
Tue, 10 Nov 2020 14:24:04 +0000 (17:24 +0300)
committer GitHub <noreply@github.com>
Tue, 10 Nov 2020 14:24:04 +0000 (17:24 +0300)
* Provide GatherND with original layout for inputs and output

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Address code review feedback #1

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/back/LayoutChangeForGatherND.py [new file with mode: 0644]
model-optimizer/extensions/back/LayoutChangeForGatherND_test.py [new file with mode: 0644]

index 5aebddc..925d2a4 100644
@@ -28,6 +28,7 @@ extensions/back/GroupedConvWeightsNormalize.py
 extensions/back/insert_compatibility_l2normalization.py
 extensions/back/InterpolateReshape.py
 extensions/back/kaldi_remove_memory_output.py
+extensions/back/LayoutChangeForGatherND.py
 extensions/back/LeakyReLUMutation.py
 extensions/back/LRNToNorm.py
 extensions/back/MatMulNormalizer.py
diff --git a/model-optimizer/extensions/back/LayoutChangeForGatherND.py b/model-optimizer/extensions/back/LayoutChangeForGatherND.py
new file mode 100644
index 0000000..8f6d1c7
--- /dev/null
@@ -0,0 +1,61 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import numpy as np
+
+from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, Port
+from mo.back.replacement import BackReplacementPattern
+
+
+class LayoutChangeForGatherND(BackReplacementPattern):
+    """
+    Restores the original (NHWC) layout for the inputs of the GatherND operation and
+    converts its output back to NCHW, since the operation itself is defined for the NHWC layout.
+    """
+    enabled = True
+    force_shape_inference = True
+    graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+
+    @staticmethod
+    def insert_transpose(graph: Graph, input_port: Port, before_input=True):
+        input_rank = len(input_port.data.get_shape())
+        if input_rank > 3:
+            if before_input:
+                axis_order = np.concatenate((int64_array([0]),
+                                             int64_array(list(range(2, input_rank))),
+                                             int64_array([1])))
+                source_node = input_port.get_source().node
+                transpose_name = source_node.soft_get('name', source_node.id) + '/TransposeToNHWC'
+            else:
+                axis_order = np.concatenate(
+                    (int64_array([0]),
+                     int64_array([input_rank - 1]),
+                     int64_array(list(range(1, input_rank - 1)))))
+                transpose_name = input_port.node.soft_get('name', input_port.node.id) + '/TransposeToNCHW'
+                input_port.node['need_shape_inference'] = True
+                input_port.node['override_output_shape'] = True
+            transpose = create_op_with_const_inputs(graph, Transpose, {1: axis_order}, {'name': transpose_name})
+            input_port.get_connection().insert_node(transpose)
+            transpose['need_shape_inference'] = True
+            transpose['override_output_shape'] = True
+
+    def find_and_replace_pattern(self, graph: Graph):
+        for gathernd in graph.get_op_nodes(type='GatherND'):
+            self.insert_transpose(graph, gathernd.in_port(0), before_input=True)
+            self.insert_transpose(graph, gathernd.in_port(1), before_input=True)
+            self.insert_transpose(graph, gathernd.out_port(0), before_input=False)
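For reference, for a rank-4 tensor the pass wraps each layout-dependent GatherND input in a Transpose with order [0, 2, 3, 1] (NCHW to NHWC) and the output in a Transpose with order [0, 3, 1, 2] (NHWC back to NCHW); the two permutations are inverses of each other. A minimal NumPy-only sketch of the axis orders that insert_transpose builds (not part of the patch; the helper names are illustrative only):

import numpy as np

def to_nhwc_order(rank):
    # [0, 2, ..., rank - 1, 1]: move the channel axis from position 1 to the end
    return np.concatenate(([0], list(range(2, rank)), [1])).astype(np.int64)

def to_nchw_order(rank):
    # [0, rank - 1, 1, ..., rank - 2]: move the trailing channel axis back to position 1
    return np.concatenate(([0], [rank - 1], list(range(1, rank - 1)))).astype(np.int64)

nhwc = to_nhwc_order(4)  # [0, 2, 3, 1] -- matches axis_1_const/axis_2_const in the test below
nchw = to_nchw_order(4)  # [0, 3, 1, 2] -- matches axis_3_const in the test below
assert np.array_equal(np.argsort(nhwc), nchw)  # the output permutation undoes the input one

Note that insert_transpose only inserts a Transpose for tensors of rank greater than 3, so lower-rank inputs (such as the rank-2 indices in the second test below) are left untouched.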
diff --git a/model-optimizer/extensions/back/LayoutChangeForGatherND_test.py b/model-optimizer/extensions/back/LayoutChangeForGatherND_test.py
new file mode 100644
index 0000000..7bf611e
--- /dev/null
@@ -0,0 +1,139 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.back.LayoutChangeForGatherND import LayoutChangeForGatherND
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph
+
+nodes_attributes = {
+    'placeholder_1': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    'placeholder_2': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
+    # GatherND
+    'gathernd': {'type': 'GatherND', 'kind': 'op', 'op': 'GatherND'},
+    'gathernd_data': {'value': None, 'shape': None, 'kind': 'data'},
+    # Result layer
+    'result': {'type': 'Result', 'kind': 'op', 'op': 'Result'},
+    # Transpose layers
+    'transpose_1': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
+    'transpose_1_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'axis_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
+    'axis_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
+    'transpose_2': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
+    'transpose_2_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'axis_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
+    'axis_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
+    'transpose_3': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose', 'need_shape_inference': True},
+    'transpose_3_data': {'value': None, 'shape': None, 'kind': 'data'},
+    'axis_3_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
+    'axis_3_const_data': {'kind': 'data', 'value': None, 'shape': None},
+}
+
+
+class LayoutChangeForGatherNDTests(unittest.TestCase):
+    def test_tf_all_ports(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('placeholder_1_data', 'gathernd'),
+                             ('placeholder_2_data', 'gathernd'),
+                             ('gathernd', 'gathernd_data'),
+                             ('gathernd_data', 'result'),
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
+                             'placeholder_2_data': {'shape': np.array([1, 3, 224, 224])},
+                             'gathernd_data': {'shape': np.array([1, 3, 224, 224])},
+                             })
+        graph.graph['fw'] = 'tf'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('placeholder_1_data', 'transpose_1'),
+                                 ('axis_1_const', 'axis_1_const_data'),
+                                 ('axis_1_const_data', 'transpose_1'),
+                                 ('transpose_1', 'transpose_1_data'),
+                                 ('placeholder_2_data', 'transpose_2'),
+                                 ('axis_2_const', 'axis_2_const_data'),
+                                 ('axis_2_const_data', 'transpose_2'),
+                                 ('transpose_2', 'transpose_2_data'),
+                                 ('transpose_1_data', 'gathernd'),
+                                 ('transpose_2_data', 'gathernd'),
+                                 ('gathernd', 'gathernd_data'),
+                                 ('gathernd_data', 'transpose_3'),
+                                 ('axis_3_const', 'axis_3_const_data'),
+                                 ('axis_3_const_data', 'transpose_3'),
+                                 ('transpose_3', 'transpose_3_data'),
+                                 ('transpose_3_data', 'result'),
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
+                                 'placeholder_2_data': {'shape': np.array([1, 3, 224, 224])},
+                                 'axis_1_const_data': {'value': int64_array([0, 2, 3, 1])},
+                                 'axis_2_const_data': {'value': int64_array([0, 2, 3, 1])},
+                                 'gathernd_data': {'shape': np.array([1, 3, 224, 224])},
+                                 'axis_3_const_data': {'value': int64_array([0, 3, 1, 2])},
+                                 })
+
+        pattern = LayoutChangeForGatherND()
+        pattern.find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_tf_one_port(self):
+        graph = build_graph(nodes_attributes,
+                            [('placeholder_1', 'placeholder_1_data'),
+                             ('placeholder_2', 'placeholder_2_data'),
+                             ('placeholder_1_data', 'gathernd'),
+                             ('placeholder_2_data', 'gathernd'),
+                             ('gathernd', 'gathernd_data'),
+                             ('gathernd_data', 'result'),
+                             ],
+                            {'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
+                             'placeholder_2_data': {'shape': np.array([1, 3])},
+                             'gathernd_data': {'shape': np.array([1, 3])},
+                             })
+        graph.graph['fw'] = 'tf'
+
+        graph_ref = build_graph(nodes_attributes,
+                                [('placeholder_1', 'placeholder_1_data'),
+                                 ('placeholder_2', 'placeholder_2_data'),
+                                 ('placeholder_1_data', 'transpose_1'),
+                                 ('axis_1_const', 'axis_1_const_data'),
+                                 ('axis_1_const_data', 'transpose_1'),
+                                 ('transpose_1', 'transpose_1_data'),
+                                 ('transpose_1_data', 'gathernd'),
+                                 ('placeholder_2_data', 'gathernd'),
+                                 ('gathernd', 'gathernd_data'),
+                                 ('gathernd_data', 'result'),
+                                 ],
+                                {'placeholder_1_data': {'shape': np.array([1, 3, 224, 224])},
+                                 'placeholder_2_data': {'shape': np.array([1, 3])},
+                                 'axis_1_const_data': {'value': int64_array([0, 2, 3, 1])},
+                                 'gathernd_data': {'shape': np.array([1, 3])}
+                                 })
+
+        pattern = LayoutChangeForGatherND()
+        pattern.find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
+        self.assertTrue(flag, resp)
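The new test case can also be run on its own with the standard unittest runner. A minimal sketch, assuming the model-optimizer directory is the current working directory and is on PYTHONPATH:

import unittest

from extensions.back.LayoutChangeForGatherND_test import LayoutChangeForGatherNDTests

# Load and run only the new test case with a verbose report.
suite = unittest.TestLoader().loadTestsFromTestCase(LayoutChangeForGatherNDTests)
unittest.TextTestRunner(verbosity=2).run(suite)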