Extend MO for operation GatherND (#2540)
authorRoman Kazantsev <roman.kazantsev@intel.com>
Thu, 15 Oct 2020 04:40:58 +0000 (07:40 +0300)
committerGitHub <noreply@github.com>
Thu, 15 Oct 2020 04:40:58 +0000 (07:40 +0300)
* Extend MO for operation GatherND

* Update documentation

* Rename GatherNd.py to gathernd.py

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/onnx/gathernd_ext.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/gathernd_ext.py [new file with mode: 0644]
model-optimizer/extensions/middle/GatherNdNormalizer.py
model-optimizer/extensions/ops/GatherNd.py [deleted file]
model-optimizer/extensions/ops/gathernd.py [new file with mode: 0644]
model-optimizer/extensions/ops/gathernd_test.py [new file with mode: 0644]

index 50920f6..78b47d2 100644 (file)
@@ -158,7 +158,7 @@ Standard TensorFlow\* operations:
 | FloorDiv | No |
 | FusedBatchNorm | No |
 | Gather | No |
-| GatherNd | Supported if it can be replaced with Gather |
+| GatherNd | No |
 | GatherV2 | No |
 | Greater | No |
 | GreaterEqual | No |
@@ -337,6 +337,7 @@ Standard ONNX\* operators:
 | Floor | No |
 | GRU | No |
 | Gather | No |
+| GatherND | No |
 | GatherTree | No |
 | Gemm | No |
 | GlobalAveragePool | No |
index d8e58a8..60dcace 100644 (file)
@@ -258,6 +258,7 @@ extensions/front/onnx/expand_ext.py
 extensions/front/onnx/flatten_ext.py
 extensions/front/onnx/flattenONNX_to_reshape.py
 extensions/front/onnx/gather_ext.py
+extensions/front/onnx/gathernd_ext.py
 extensions/front/onnx/gemm_ext.py
 extensions/front/onnx/group_norm_ext.py
 extensions/front/onnx/gru_ext.py
@@ -382,6 +383,7 @@ extensions/front/tf/FlattenToReshape.py
 extensions/front/tf/floor_div_decomposition.py
 extensions/front/tf/floor_ext.py
 extensions/front/tf/gather_ext.py
+extensions/front/tf/gathernd_ext.py
 extensions/front/tf/GatherTree_ext.py
 extensions/front/tf/GNMT_DynamicSequenceLengths.py
 extensions/front/tf/identity_ext.py
@@ -617,7 +619,7 @@ extensions/ops/ExtractImagePatches.py
 extensions/ops/fake_output.py
 extensions/ops/fakequantize.py
 extensions/ops/gather.py
-extensions/ops/GatherNd.py
+extensions/ops/gathernd.py
 extensions/ops/GatherTree.py
 extensions/ops/gelu.py
 extensions/ops/grn.py
diff --git a/model-optimizer/extensions/front/onnx/gathernd_ext.py b/model-optimizer/extensions/front/onnx/gathernd_ext.py
new file mode 100644 (file)
index 0000000..34be3aa
--- /dev/null
@@ -0,0 +1,32 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from extensions.ops.gathernd import GatherND
+from mo.front.extractor import FrontExtractorOp
+from mo.front.onnx.extractors.utils import onnx_attr
+
+
+class GatherNDFrontExtractor(FrontExtractorOp):
+    op = 'GatherND'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attrs = {
+            'batch_dims': onnx_attr(node, 'batch_dims', 'i', default=0)
+        }
+        GatherND.update_node_stat(node, attrs)
+        return cls.enabled
diff --git a/model-optimizer/extensions/front/tf/gathernd_ext.py b/model-optimizer/extensions/front/tf/gathernd_ext.py
new file mode 100644 (file)
index 0000000..24c1a44
--- /dev/null
@@ -0,0 +1,30 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from extensions.ops.gathernd import GatherND
+from mo.front.extractor import FrontExtractorOp
+
+
+class GatherNDFrontExtractor(FrontExtractorOp):
+    op = 'GatherNd'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attrs = {
+            'batch_dims': 0,
+        }
+        GatherND.update_node_stat(node, attrs)
+        return cls.enabled
index 469433b..0a973ad 100644 (file)
@@ -25,11 +25,14 @@ from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.reshape import Reshape
 
 
-class GatherNdNormalize(MiddleReplacementPattern):
+class GatherNDNormalize(MiddleReplacementPattern):
     """
     Hot fix for new speech-to-text model enabling while GatherND is not implemented in IE.
-    We can replace GatherNd to Reshape + Gather in case when GatherNd indices have just one
+    We can replace GatherND to Reshape + Gather in case when GatherND indices have just one
     meaningful dimension.
+    TODO: Investigate whether we must replace GatherND with Reshape + Gather always (due to performance benefits)
+          for this particular case or only if the plugin does not support GatherND.
+          And the best place for the transformation is nGraph so we need to move it.
     """
     enabled = True
     force_clean_up = True
@@ -44,7 +47,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
 
     def pattern(self):
         return dict(
-            nodes=[('GatherNd', dict(kind='op', op='GatherNd'))],
+            nodes=[('GatherND', dict(kind='op', op='GatherND', batch_dims=0))],
             edges=[]
         )
 
@@ -67,7 +70,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
         return non_zero
 
     def replace_pattern(self, graph: Graph, match: dict):
-        gather = match['GatherNd']
+        gather = match['GatherND']
         gather_name = gather.soft_get('name', gather.id)
         input_shape = gather.in_node(0).shape
         indices = gather.in_node(1).value
@@ -75,16 +78,16 @@ class GatherNdNormalize(MiddleReplacementPattern):
             # We can't do such special pass without indices value
             return
 
-        # 0. All needed checks that we can replace GatherNd by Gather
+        # 0. All needed checks that we can replace GatherND by Gather
         gather_idx = self.indices_check(indices, input_shape)
         if gather_idx is None:
-            log.warning('Node {} with op=GatherN can\'t be normalized to op=Gather.'.format(gather_name))
+            log.warning('Node {} with op=GatherND can\'t be normalized to op=Gather.'.format(gather_name))
             return
 
         # 1. Add Reshape and connect
         new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
         reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
-                                                   {'name': gather_name + '/Reshape_for_GatherNd/'})
+                                                   {'name': gather_name + '/Reshape_for_GatherND/'})
         gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
 
         # 2. Change indices from Nd to 1d:
diff --git a/model-optimizer/extensions/ops/GatherNd.py b/model-optimizer/extensions/ops/GatherNd.py
deleted file mode 100644 (file)
index 219da66..0000000
+++ /dev/null
@@ -1,47 +0,0 @@
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import numpy as np
-
-from mo.graph.graph import Node, Graph
-from mo.ops.op import Op
-
-
-class GatherNd(Op):
-    op = 'GatherNd'
-
-    def __init__(self, graph: Graph, attrs: dict):
-        mandatory_props = {
-            'op': __class__.op,
-            'infer': __class__.infer,
-            'in_ports_count': 2,
-            'out_ports_count': 1,
-        }
-        super().__init__(graph, mandatory_props, attrs)
-
-    def supported_attrs(self):
-        return []
-
-    @staticmethod
-    def infer(node: Node):
-        input_node = node.in_node(0)
-        indices = node.in_node(1).value
-
-        assert indices is not None
-
-        output_shape = list(indices.shape[:-1]) + list(input_node.shape[indices.shape[-1]:])
-        node.out_node().shape = np.array(output_shape, dtype=np.int64)
-        # TODO: implement constant path
diff --git a/model-optimizer/extensions/ops/gathernd.py b/model-optimizer/extensions/ops/gathernd.py
new file mode 100644 (file)
index 0000000..ff69ce7
--- /dev/null
@@ -0,0 +1,102 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import numpy as np
+
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node, Graph
+from mo.ops.op import Op
+
+
+class GatherND(Op):
+    op = 'GatherND'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        mandatory_props = {
+            'type': self.op,
+            'op': self.op,
+            'version': 'opset5',
+            'infer': self.infer,
+            'in_ports_count': 2,
+            'out_ports_count': 1,
+            'batch_dims': 0
+        }
+        super().__init__(graph, mandatory_props, attrs)
+
+    def backend_attrs(self):
+        return ['batch_dims']
+
+    @staticmethod
+    def infer(node: Node):
+        node_name = node.soft_get('name', node.id)
+        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
+        assert len(connected_in_ports) == 2, \
+            "Incorrect number of inputs for {} node".format(node_name)
+
+        data_shape = node.in_port(0).data.get_shape()
+        data_value = node.in_port(0).data.get_value()
+        indices_shape = node.in_port(1).data.get_shape()
+        indices_value = node.in_port(1).data.get_value()
+
+        assert node.has_valid('batch_dims'),  "Node {} must contain `batch_dims` attribute".format(node_name)
+        batch_dims = node.batch_dims
+
+        # check that a number of batch dimensions is less than both ranks of data and indices tensors
+        assert batch_dims < len(data_shape), "Number of batch dimensions must be less than a rank of data"
+        assert batch_dims < len(indices_shape), "Number of batch dimensions must be less than a rank of indices"
+
+        # check that batch dimensions of data and indices are the same
+        for batch_dim in range(batch_dims):
+            assert data_shape[batch_dim] == indices_shape[batch_dim], \
+                "The dimension {} for data and indices tensors must be the same".format(batch_dim)
+
+        # check ranks of input tensors
+        assert len(data_shape) > 0, "Data must not be a scalar"
+        assert len(indices_shape) > 0, "Indices must not be a scalar"
+        assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
+            "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"
+
+        # compute output shape
+        number_batches = [np.prod(data_shape[:batch_dims]).tolist()] if batch_dims > 0 else list()
+        slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
+        output_shape = number_batches + list(indices_shape[batch_dims:-1]) + slice_shape
+        node.out_port(0).data.set_shape(int64_array(output_shape))
+
+        # compute output value if all input values are defined
+        if data_value is not None and indices_value is not None:
+            output_value = np.zeros(output_shape, dtype=data_value.dtype)
+            if batch_dims == 0:
+                output_indices_range = int64_array(indices_shape[:-1])
+                for output_index in np.ndindex(tuple(output_indices_range)):
+                    indices_tuple = indices_value[output_index]
+                    output_value[output_index] = data_value[tuple(indices_tuple.T)]
+            else:
+                batch_dims_range = int64_array(indices_shape[:batch_dims])
+                for batch_indices in np.ndindex(tuple(batch_dims_range)):
+                    # compute batch index in output tensor
+                    batch_ind = 0
+                    num_elements = 1
+                    for ind in reversed(range(len(batch_dims_range))):
+                        batch_ind += batch_indices[ind] * num_elements
+                        num_elements *= batch_dims_range[ind]
+                    output_indices_range = int64_array(indices_shape[batch_dims:-1])
+                    for output_index in np.ndindex(tuple(output_indices_range)):
+                        tmp_ind = batch_indices + output_index
+                        indices_tuple = tuple(indices_value[tmp_ind].T)
+                        full_input_ind = batch_indices + indices_tuple
+                        full_output_ind = tuple(np.array([batch_ind]).T) + output_index
+                        output_value[full_output_ind] = data_value[full_input_ind]
+            node.out_port(0).data.set_value(output_value)
diff --git a/model-optimizer/extensions/ops/gathernd_test.py b/model-optimizer/extensions/ops/gathernd_test.py
new file mode 100644 (file)
index 0000000..da27f49
--- /dev/null
@@ -0,0 +1,254 @@
+"""
+ Copyright (C) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.ops.gathernd import GatherND
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Node
+from mo.utils.unittest.graph import build_graph
+
+nodes_attributes = {'data': {'kind': 'op'},
+                    'data_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'indices': {'kind': 'op'},
+                    'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
+                    'gathernd_node': {'op': 'ScatterNDUpdate', 'kind': 'op', 'batch_dims': 0},
+                    'output': {'shape': None, 'value': None, 'kind': 'data'}}
+
+# graph 1
+edges = [('data', 'data_data', {'in': 0}),
+         ('indices', 'indices_data', {'in': 1}),
+         ('data_data', 'gathernd_node', {'in': 0}),
+         ('indices_data', 'gathernd_node', {'in': 1}),
+         ('gathernd_node', 'output', {'out': 0})]
+
+# test data for partial infer: gather elements
+inputs1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
+           'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
+
+# test data for partial infer: gather slices
+inputs2 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
+           'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
+
+# test data for partial infer: gather slices and batch_dims=2
+inputs3 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
+           'indices_data': {'shape': int64_array([10, 40, 3, 5, 1]), 'value': None}}
+
+# test data for constant folding: gather elements, batch_dims = 0
+inputs4 = {'data_data': {'shape': int64_array([2, 2]), 'value': int64_array([[1, 2],
+                                                                             [3, 4]])},
+           'indices_data': {'shape': int64_array([2, 2]), 'value': int64_array([[0, 0],
+                                                                                [1, 0]])}}
+output4 = int64_array([1, 3])
+
+# test data for constant folding: gather slices, batch_dims = 0
+inputs5 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
+                                                                                 [5, 6, 7, 8],
+                                                                                 [9, 10, 11, 12]],
+                                                                                [[13, 14, 15, 16],
+                                                                                 [17, 18, 19, 20],
+                                                                                 [21, 22, 23, 24]]])},
+           'indices_data': {'shape': int64_array([3, 2]), 'value': int64_array([[0, 1],
+                                                                                [1, 0],
+                                                                                [1, 2]])}}
+output5 = int64_array([[5, 6, 7, 8],
+                       [13, 14, 15, 16],
+                       [21, 22, 23, 24]])
+
+# test data for constant folding: gather slices, batch_dims = 1
+inputs6 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
+                                                                                 [5, 6, 7, 8],
+                                                                                 [9, 10, 11, 12]],
+                                                                                [[13, 14, 15, 16],
+                                                                                 [17, 18, 19, 20],
+                                                                                 [21, 22, 23, 24]]])},
+           'indices_data': {'shape': int64_array([2, 1]), 'value': int64_array([[1],
+                                                                                [0]])}}
+output6 = int64_array([[5, 6, 7, 8],
+                       [13, 14, 15, 16]])
+
+# test data for constant folding: gather slices with leading dimensions, batch_dims = 2
+inputs7 = {'data_data': {'shape': int64_array([2, 3, 4]), 'value': int64_array([[[1, 2, 3, 4],
+                                                                                 [5, 6, 7, 8],
+                                                                                 [9, 10, 11, 12]],
+                                                                                [[13, 14, 15, 16],
+                                                                                 [17, 18, 19, 20],
+                                                                                 [21, 22, 23, 24]]])},
+           'indices_data': {'shape': int64_array([2, 3, 1, 1]), 'value': int64_array([[[[1]],
+                                                                                       [[0]],
+                                                                                       [[2]]],
+                                                                                      [[[0]],
+                                                                                       [[2]],
+                                                                                       [[2]]]])}}
+output7 = int64_array([[2], [5], [11], [13], [19], [23]])
+
+# test data for constant folding: gather elements, batch_dims = 2
+inputs8 = {'data_data': {'shape': int64_array([2, 3, 4, 2]),
+                         'value': int64_array([[[[1, 2], [3, 4], [5, 6], [7, 8]],
+                                                [[9, 10], [11, 12], [13, 14], [15, 16]],
+                                                [[17, 18], [19, 20], [21, 22], [23, 24]]],
+                                               [[[25, 26], [27, 28], [29, 30], [31, 32]],
+                                                [[33, 34], [35, 36], [37, 38], [39, 40]],
+                                                [[41, 42], [43, 44], [45, 46], [47, 48]]]])},
+           'indices_data': {'shape': int64_array([2, 3, 3, 2]),
+                            'value': int64_array([[[[1, 0], [3, 1], [2, 1]],
+                                                   [[0, 1], [1, 1], [2, 0]],
+                                                   [[3, 0], [3, 1], [2, 1]]],
+                                                  [[[2, 0], [1, 1], [3, 1]],
+                                                   [[1, 1], [2, 0], [2, 0]],
+                                                   [[0, 0], [3, 1], [3, 1]]]])}}
+output8 = int64_array([[3, 8, 6],
+                       [10, 12, 13],
+                       [23, 24, 22],
+                       [29, 28, 32],
+                       [36, 37, 37],
+                       [41, 48, 48]])
+
+# invalid test case with incorrect rank for indices
+inputs_inv1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
+               'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}
+
+# invalid test case with unequal batch dimensions, batch_dims = 2
+inputs_inv2 = {'data_data': {'shape': int64_array([10, 40, 20]), 'value': None},
+               'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}
+
+# invalid test case with indices rank greater than a rank of data excluding batch dimensions, batch_dims = 2
+inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value': None},
+               'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
+
+class TestScatterNDUpdate(unittest.TestCase):
+    def setUp(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 0
+
+    def test_partial_infer_gather_element(self):
+        graph = build_graph(nodes_attributes, edges, inputs1)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # prepare reference results
+        ref_output_shape = int64_array([3])
+
+        # get the result
+        res_output_shape = graph.node['output']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    def test_partial_infer_gather_slice(self):
+        graph = build_graph(nodes_attributes, edges, inputs2)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # prepare reference results
+        ref_output_shape = int64_array([3, 30])
+
+        # get the result
+        res_output_shape = graph.node['output']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    def test_partial_infer_gather_slice_batch_dims2(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 2
+        graph = build_graph(nodes_attributes, edges, inputs3)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # prepare reference results
+        ref_output_shape = int64_array([400, 3, 5, 9])
+
+        # get the result
+        res_output_shape = graph.node['output']['shape']
+
+        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
+                        'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
+
+    def test_infer4(self):
+        graph = build_graph(nodes_attributes, edges, inputs4)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output4, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output4, res_output_value))
+
+    def test_infer5(self):
+        graph = build_graph(nodes_attributes, edges, inputs5)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output5, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output4, res_output_value))
+
+    def test_infer6(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 1
+        graph = build_graph(nodes_attributes, edges, inputs6)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output6, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output4, res_output_value))
+
+    def test_infer7(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 2
+        graph = build_graph(nodes_attributes, edges, inputs7)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output7, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output4, res_output_value))
+
+    def test_infer8(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 2
+        graph = build_graph(nodes_attributes, edges, inputs8)
+        gathernd_node = Node(graph, 'gathernd_node')
+        GatherND.infer(gathernd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output8, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output4, res_output_value))
+
+    def test_infer_invalid1(self):
+        graph = build_graph(nodes_attributes, edges, inputs_inv1)
+        gathernd_node = Node(graph, 'gathernd_node')
+        self.assertRaises(AssertionError, GatherND.infer, gathernd_node)
+
+    def test_infer_invalid2(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 2
+        graph = build_graph(nodes_attributes, edges, inputs_inv2)
+        gathernd_node = Node(graph, 'gathernd_node')
+        self.assertRaises(AssertionError, GatherND.infer, gathernd_node)
+
+    def test_infer_invalid3(self):
+        nodes_attributes['gathernd_node']['batch_dims'] = 2
+        graph = build_graph(nodes_attributes, edges, inputs_inv3)
+        gathernd_node = Node(graph, 'gathernd_node')
+        self.assertRaises(AssertionError, GatherND.infer, gathernd_node)