from extensions.ops.ReduceOps import reduce_map
from extensions.ops.range import Range
from extensions.ops.rank import Rank
+from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
+from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
class ReduceAxisNormalizer(FrontReplacementSubgraph):
"""
- Reduce operation requires information about axis, that is represented in original frameworks differently:
- - by layer parameter
- - by 1-port input value
-
- ReduceAxisNormalizer reforms Reduce operations to store axis info in 1-port input.
+ Reduce operation requires information about axis, that is represented in original frameworks differently: as an
+ operation attribute or as a 1-st input port value. ReduceAxisNormalizer adds second input to Reduce operations with
+ axes to normalize if axes are specified as an attribute.
"""
enabled = True
- force_shape_inference = True
def pattern(self):
return dict(
# if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal
# to None. The infer function handles this case because the input shape is known at this stage only
- if node.has('axis'):
+ if node.has_valid('axis'):
const = Const(graph, {'name': node_name + '/axis', 'value': node.axis}).create_node()
node.add_input_port(1, skip_if_exist=True)
const.out_port(0).connect(node.in_port(1))
del graph.node[node.id]['axis']
else:
# The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
-
- begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
- step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
- end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
- axes = Range(graph, dict(name=node_name + '/axes_')).create_node()
-
- begin_of_range.out_port(0).connect(axes.in_port(0))
+ axes = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
+ dict(name=node_name + '/axes'))
+ end_of_range = Rank(graph, dict(name=node_name + '/range_end')).create_node()
+ node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
end_of_range.out_port(0).connect(axes.in_port(1))
- step.out_port(0).connect(axes.in_port(2))
node.add_input_port(1, skip_if_exist=True)
axes.out_port(0).connect(node.in_port(1))
- node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
--- /dev/null
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+
+from extensions.front.reduce_axis_normalizer import ReduceAxisNormalizer
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph, result, connect_front, regular_op
+
+nodes = {
+ **regular_op('parameter', {'type': 'Parameter'}),
+ **regular_op('reduce', {'op': 'ReduceSum', 'axis': None}),
+ **regular_op('axis', {'op': 'Const', 'type': 'Const', 'value': int64_array([1])}),
+ **result(),
+}
+
+edges = [
+ *connect_front('parameter:0', '0:reduce'),
+ *connect_front('reduce', 'output'),
+]
+
+
+class ReduceAxisNormalizerTest(unittest.TestCase):
+ def test_reduce_axis_is_None(self):
+ graph = build_graph(nodes, edges, nodes_with_edges_only=True)
+ graph.stage = 'front'
+
+ ReduceAxisNormalizer().find_and_replace_pattern(graph)
+
+ ref_nodes = nodes.copy()
+ ref_nodes.update({**regular_op('rank', {'op': 'Rank', 'type': None}),
+ **regular_op('range', {'op': 'Range', 'type': 'Range'}),
+ **regular_op('begin', {'type': 'Const', 'value': int64_array([0])}),
+ **regular_op('step', {'type': 'Const', 'value': int64_array([1])}),
+ })
+ graph_ref = build_graph(ref_nodes, [
+ *edges,
+ *connect_front('parameter:0', 'rank'),
+ *connect_front('begin:0', '0:range'),
+ *connect_front('rank:0', '1:range'),
+ *connect_front('step:0', '2:range'),
+ *connect_front('range:0', '1:reduce'),
+ ], nodes_with_edges_only=True)
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+ self.assertTrue(flag, resp)
+
+ def test_reduce_axis_is_const(self):
+ graph = build_graph(nodes, edges, {'reduce': {'axis': 1}}, nodes_with_edges_only=True)
+ graph.stage = 'front'
+
+ graph_ref = build_graph(nodes, [
+ *edges,
+ *connect_front('axis', '1:reduce'),
+ ], {'axis': {'value': np.int64(1)}}, nodes_with_edges_only=True)
+
+ ReduceAxisNormalizer().find_and_replace_pattern(graph)
+
+ (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+ self.assertTrue(flag, resp)