From 0a9d883d7813cae4d07cc104118e4425d436d30c Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Mon, 16 Nov 2020 18:50:13 +0300 Subject: [PATCH] Fix for Reduce extractors and normalizer (#3136) * Fixed extractor for ONNX ReduceXXX operations and fixed ReduceAxisNormalizer transformation * Unit test for ReduceAxisNormalizer transformation --- .../extensions/front/onnx/reduce_ext.py | 4 +- .../extensions/front/reduce_axis_normalizer.py | 26 +++----- .../front/reduce_axis_normalizer_test.py | 76 ++++++++++++++++++++++ 3 files changed, 89 insertions(+), 17 deletions(-) create mode 100644 model-optimizer/extensions/front/reduce_axis_normalizer_test.py diff --git a/model-optimizer/extensions/front/onnx/reduce_ext.py b/model-optimizer/extensions/front/onnx/reduce_ext.py index 2bb3b83..e1c2458 100644 --- a/model-optimizer/extensions/front/onnx/reduce_ext.py +++ b/model-optimizer/extensions/front/onnx/reduce_ext.py @@ -22,7 +22,9 @@ from mo.graph.graph import Node def update_reduce_node_attrs_with(node: Node, c: callable): - axis = onnx_attr(node, 'axes', 'ints', default=None, dst_type=lambda x: int64_array(x)) + axis = onnx_attr(node, 'axes', 'ints', default=None) + if axis is not None: + axis = int64_array(axis) keep_dims = onnx_attr(node, 'keepdims', 'i', default=True) c.update_node_stat(node, {'axis': axis, 'keep_dims': keep_dims}) diff --git a/model-optimizer/extensions/front/reduce_axis_normalizer.py b/model-optimizer/extensions/front/reduce_axis_normalizer.py index 79208ef..dee2492 100644 --- a/model-optimizer/extensions/front/reduce_axis_normalizer.py +++ b/model-optimizer/extensions/front/reduce_axis_normalizer.py @@ -17,22 +17,21 @@ from extensions.ops.ReduceOps import reduce_map from extensions.ops.range import Range from extensions.ops.rank import Rank +from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementSubgraph from mo.front.subgraph_matcher import SubgraphMatch +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Graph from mo.ops.const import Const class ReduceAxisNormalizer(FrontReplacementSubgraph): """ - Reduce operation requires information about axis, that is represented in original frameworks differently: - - by layer parameter - - by 1-port input value - - ReduceAxisNormalizer reforms Reduce operations to store axis info in 1-port input. + Reduce operation requires information about axis, that is represented in original frameworks differently: as an + operation attribute or as a 1-st input port value. ReduceAxisNormalizer adds second input to Reduce operations with + axes to normalize if axes are specified as an attribute. """ enabled = True - force_shape_inference = True def pattern(self): return dict( @@ -50,23 +49,18 @@ class ReduceAxisNormalizer(FrontReplacementSubgraph): # if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal # to None. The infer function handles this case because the input shape is known at this stage only - if node.has('axis'): + if node.has_valid('axis'): const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node() node.add_input_port(1, skip_if_exist=True) const.out_port(0).connect(node.in_port(1)) del graph.node[node.id]['axis'] else: # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor. - - begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node() - step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node() - end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node() - axes = Range(graph, dict(name=node_name + '/axes_')).create_node() - - begin_of_range.out_port(0).connect(axes.in_port(0)) + axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)}, + dict(name=node_name + '/axes')) + end_of_range = Rank(graph, dict(name=node_name + '/range_end')).create_node() + node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0)) end_of_range.out_port(0).connect(axes.in_port(1)) - step.out_port(0).connect(axes.in_port(2)) node.add_input_port(1, skip_if_exist=True) axes.out_port(0).connect(node.in_port(1)) - node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0)) diff --git a/model-optimizer/extensions/front/reduce_axis_normalizer_test.py b/model-optimizer/extensions/front/reduce_axis_normalizer_test.py new file mode 100644 index 0000000..896a1d7 --- /dev/null +++ b/model-optimizer/extensions/front/reduce_axis_normalizer_test.py @@ -0,0 +1,76 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +import numpy as np + +from extensions.front.reduce_axis_normalizer import ReduceAxisNormalizer +from mo.front.common.partial_infer.utils import int64_array +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, result, connect_front, regular_op + +nodes = { + **regular_op('parameter', {'type': 'Parameter'}), + **regular_op('reduce', {'op': 'ReduceSum', 'axis': None}), + **regular_op('axis', {'op': 'Const', 'type': 'Const', 'value': int64_array([1])}), + **result(), +} + +edges = [ + *connect_front('parameter:0', '0:reduce'), + *connect_front('reduce', 'output'), +] + + +class ReduceAxisNormalizerTest(unittest.TestCase): + def test_reduce_axis_is_None(self): + graph = build_graph(nodes, edges, nodes_with_edges_only=True) + graph.stage = 'front' + + ReduceAxisNormalizer().find_and_replace_pattern(graph) + + ref_nodes = nodes.copy() + ref_nodes.update({**regular_op('rank', {'op': 'Rank', 'type': None}), + **regular_op('range', {'op': 'Range', 'type': 'Range'}), + **regular_op('begin', {'type': 'Const', 'value': int64_array([0])}), + **regular_op('step', {'type': 'Const', 'value': int64_array([1])}), + }) + graph_ref = build_graph(ref_nodes, [ + *edges, + *connect_front('parameter:0', 'rank'), + *connect_front('begin:0', '0:range'), + *connect_front('rank:0', '1:range'), + *connect_front('step:0', '2:range'), + *connect_front('range:0', '1:reduce'), + ], nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test_reduce_axis_is_const(self): + graph = build_graph(nodes, edges, {'reduce': {'axis': 1}}, nodes_with_edges_only=True) + graph.stage = 'front' + + graph_ref = build_graph(nodes, [ + *edges, + *connect_front('axis', '1:reduce'), + ], {'axis': {'value': np.int64(1)}}, nodes_with_edges_only=True) + + ReduceAxisNormalizer().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) -- 2.7.4