Fix for Reduce extractors and normalizer (#3136)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Mon, 16 Nov 2020 15:50:13 +0000 (18:50 +0300)
committerGitHub <noreply@github.com>
Mon, 16 Nov 2020 15:50:13 +0000 (18:50 +0300)
* Fixed extractor for ONNX ReduceXXX operations and fixed ReduceAxisNormalizer transformation

* Unit test for ReduceAxisNormalizer transformation

model-optimizer/extensions/front/onnx/reduce_ext.py
model-optimizer/extensions/front/reduce_axis_normalizer.py
model-optimizer/extensions/front/reduce_axis_normalizer_test.py [new file with mode: 0644]

index 2bb3b83..e1c2458 100644 (file)
@@ -22,7 +22,9 @@ from mo.graph.graph import Node
 
 
 def update_reduce_node_attrs_with(node: Node, c: callable):
-    axis = onnx_attr(node, 'axes', 'ints', default=None, dst_type=lambda x: int64_array(x))
+    axis = onnx_attr(node, 'axes', 'ints', default=None)
+    if axis is not None:
+        axis = int64_array(axis)
     keep_dims = onnx_attr(node, 'keepdims', 'i', default=True)
     c.update_node_stat(node, {'axis': axis, 'keep_dims': keep_dims})
 
index 79208ef..dee2492 100644 (file)
 from extensions.ops.ReduceOps import reduce_map
 from extensions.ops.range import Range
 from extensions.ops.rank import Rank
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
 from mo.front.subgraph_matcher import SubgraphMatch
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph
 from mo.ops.const import Const
 
 
 class ReduceAxisNormalizer(FrontReplacementSubgraph):
     """
-    Reduce operation requires information about axis, that is represented in original frameworks differently:
-        - by layer parameter
-        - by 1-port input value
-
-    ReduceAxisNormalizer reforms Reduce operations to store axis info in 1-port input.
+    Reduce operation requires information about axis, that is represented in original frameworks differently: as an
+    operation attribute or as a 1-st input port value. ReduceAxisNormalizer adds second input to Reduce operations with
+    axes to normalize if axes are specified as an attribute.
     """
     enabled = True
-    force_shape_inference = True
 
     def pattern(self):
         return dict(
@@ -50,23 +49,18 @@ class ReduceAxisNormalizer(FrontReplacementSubgraph):
 
             # if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal
             # to None. The infer function handles this case because the input shape is known at this stage only
-            if node.has('axis'):
+            if node.has_valid('axis'):
                 const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node()
                 node.add_input_port(1, skip_if_exist=True)
                 const.out_port(0).connect(node.in_port(1))
                 del graph.node[node.id]['axis']
             else:
                 # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
-
-                begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
-                step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
-                end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
-                axes = Range(graph, dict(name=node_name + '/axes_')).create_node()
-
-                begin_of_range.out_port(0).connect(axes.in_port(0))
+                axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
+                                                   dict(name=node_name + '/axes'))
+                end_of_range = Rank(graph, dict(name=node_name + '/range_end')).create_node()
+                node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
                 end_of_range.out_port(0).connect(axes.in_port(1))
-                step.out_port(0).connect(axes.in_port(2))
 
                 node.add_input_port(1, skip_if_exist=True)
                 axes.out_port(0).connect(node.in_port(1))
-                node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
diff --git a/model-optimizer/extensions/front/reduce_axis_normalizer_test.py b/model-optimizer/extensions/front/reduce_axis_normalizer_test.py
new file mode 100644 (file)
index 0000000..896a1d7
--- /dev/null
@@ -0,0 +1,76 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.front.reduce_axis_normalizer import ReduceAxisNormalizer
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, result, connect_front, regular_op
+
+nodes = {
+    **regular_op('parameter', {'type': 'Parameter'}),
+    **regular_op('reduce', {'op': 'ReduceSum', 'axis': None}),
+    **regular_op('axis', {'op': 'Const', 'type': 'Const', 'value': int64_array([1])}),
+    **result(),
+}
+
+edges = [
+    *connect_front('parameter:0', '0:reduce'),
+    *connect_front('reduce', 'output'),
+]
+
+
+class ReduceAxisNormalizerTest(unittest.TestCase):
+    def test_reduce_axis_is_None(self):
+        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
+        graph.stage = 'front'
+
+        ReduceAxisNormalizer().find_and_replace_pattern(graph)
+
+        ref_nodes = nodes.copy()
+        ref_nodes.update({**regular_op('rank', {'op': 'Rank', 'type': None}),
+                          **regular_op('range', {'op': 'Range', 'type': 'Range'}),
+                          **regular_op('begin', {'type': 'Const', 'value': int64_array([0])}),
+                          **regular_op('step', {'type': 'Const', 'value': int64_array([1])}),
+                          })
+        graph_ref = build_graph(ref_nodes, [
+            *edges,
+            *connect_front('parameter:0', 'rank'),
+            *connect_front('begin:0', '0:range'),
+            *connect_front('rank:0', '1:range'),
+            *connect_front('step:0', '2:range'),
+            *connect_front('range:0', '1:reduce'),
+        ], nodes_with_edges_only=True)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_reduce_axis_is_const(self):
+        graph = build_graph(nodes, edges, {'reduce': {'axis': 1}}, nodes_with_edges_only=True)
+        graph.stage = 'front'
+
+        graph_ref = build_graph(nodes, [
+            *edges,
+            *connect_front('axis', '1:reduce'),
+        ], {'axis': {'value': np.int64(1)}}, nodes_with_edges_only=True)
+
+        ReduceAxisNormalizer().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)