* SplitConcatPairToInterpolate transformation was moved to middle stage and is applied only for 4D and 5D inputs.
extensions/front/tf/sparse_to_dense_ext.py
extensions/front/tf/sparse_weighted_sum.py
extensions/front/tf/split_ext.py
-extensions/front/tf/SplitConcatPairToInterpolate.py
extensions/front/tf/ssd_support.json
extensions/front/tf/ssd_support_api_v1.14.json
extensions/front/tf/ssd_support_api_v1.15.json
extensions/middle/SliceLikeToStridedSlice.py
extensions/middle/space_to_depth.py
extensions/middle/sparse_reshape.py
+extensions/middle/SplitConcatPairToInterpolate.py
extensions/middle/ssd_anchors_to_const.py
extensions/middle/SwapAxesMiddleReplacer.py
extensions/middle/TensorIterator_utils.py
+++ /dev/null
-"""
- Copyright (c) 2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import logging as log
-from typing import Optional
-
-from extensions.ops.elementwise import Mul
-from extensions.ops.interpolate import Interpolate
-from mo.front.common.partial_infer.utils import int64_array
-from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import Graph, Node
-from mo.ops.const import Const
-from mo.ops.shape import Shape
-from mo.ops.strided_slice import StridedSlice
-
-
-def get_concat_after_split(split: Node) -> Optional[Node]:
- # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
- split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
- names_of_split_outputs = set([n.name for n in split_outputs])
- if len(names_of_split_outputs) != 1:
- return
-
- groups_of_inputs = [[d.idx for d in p.get_connection().get_destinations()] for _, p in split.out_ports().items()]
- sizes_of_groups = set([len(g) for g in groups_of_inputs])
- # If numbers of consumer ports are various for various output ports of 'split', then the transformation
- # is not applicable.
- if len(sizes_of_groups) != 1:
- return
- # The transformation is applicable iff output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
- # output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
- # goes to ports [i * m, ..., i * m + (m - 1)], and so on.
- flatten_groups = [i for g in groups_of_inputs for i in g]
- if flatten_groups != list(range(0, len(flatten_groups))):
- return
-
- dest = split.out_port(0).get_destinations()[0].node
- # The transformation is applicable, only if next node is Concat.
- return dest if dest.soft_get('type') == 'Concat' else None
-
-
-def get_interpolate_pattern(split: Node) -> dict:
- concat = get_concat_after_split(split)
- if concat is None:
- return {}
- return {'split': split, 'concat': concat}
-
-
-def get_split_scale(split: Node) -> int:
- split_dests = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
- num_of_split_dests = len(split_dests)
- num_of_split_out_ports = len(split.out_ports())
- fractional_part = num_of_split_dests / num_of_split_out_ports - num_of_split_dests // num_of_split_out_ports
- assert fractional_part == 0, "Number of output ports of Split must be multiple of number of inputs of Concat"
- return len(split_dests) // len(split.out_ports())
-
-
-def replace_interpolate_pattern(graph: Graph, match: dict):
- split = match['split']
- scale = int64_array([get_split_scale(split)])
- axis = int(split.in_port(1).get_connection().get_source().node.value)
- split_node_name = split.name
-
- shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
- scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
- mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
- scales_node.out_port(0).connect(mul_node.in_port(1))
-
- slice_begin = Const(graph, dict(name=split_node_name + '/slice_begin_',
- value=int64_array([axis]))).create_node()
- slice_end = Const(graph, dict(name=split_node_name + '/slice_end_',
- value=int64_array([axis + 1]))).create_node()
-
- strided_slice_node = StridedSlice(graph,
- {'name': split_node_name + '/StridedSlice_',
- 'begin_mask': int64_array([1]),
- 'end_mask': int64_array([1]),
- 'new_axis_mask': int64_array([0]),
- 'shrink_axis_mask': int64_array([0]),
- 'ellipsis_mask': int64_array([0]),
- }).create_node([shape_node, slice_begin, slice_end])
- strided_slice_node.out_port(0).connect(mul_node.in_port(0))
-
- interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
- axes=int64_array([axis]),
- mode='nearest')).create_node()
- mul_node.out_port(0).connect(interp_node.in_port(1))
-
- match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
-
- split_connection = split.in_port(0).get_connection()
- split_connection.set_destination(interp_node.in_port(0))
- split_connection.get_source().connect(shape_node.in_port(0))
-
-
-class SplitConcatPairToInterpolate(FrontReplacementSubgraph):
- """
- This transformation looks for Interpolation layer implemented using simple operations, i.e. Split and Concat,
- and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
-
- Found pattern:
- nodes=[
- ('split', dict(kind='op', op='Split')),
- ('concat', dict(kind='op', op='Concat')),
- ],
- edges=[
- ('split', 'concat'),
- ]
-
- Here we assume that
- 1) 'split' is in NDHWC layout and is a 5D-tensor;
- 2) split dimensions for 'split' belongs to {1, 2, 3};
- 3) all outputs of 'split' go to only inputs of 'concat';
- 4) 'concat' takes inputs only from 'split';
- 5) split_dim of 'split' is equal to axis of 'concat'.
-
- Found pattern will be replaced with
- nodes=[
- ('shape', dict(kind='op', op='Shape')),
- ('strided_slice', dict(kind='op', op='StridedSlice')),
- ('scales', dict(kind='op', op='Const')),
- ('scaled_shape', dict(kind='op', op='Mul')),
- ('interp', dict(kind='op', op='Interpolate'))
- ],
- edges=[
- ('shape', 'strided_slice', {'in': 0}),
- ('strided_slice', 'scaled_shape', {'in': 0}),
- ('scales', 'scaled_shape', {'in': 1}),
- ('scaled_shape', 'interp', {'in': 1}),
- ]
-
- Here scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
- by number of output ports of 'split'.
- """
- enabled = True
-
- def find_and_replace_pattern(self, graph: Graph):
- log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
- splits = graph.get_op_nodes(op='Split')
- patterns = []
-
- for split_node in splits:
- interpolate_pattern = get_interpolate_pattern(split_node)
- if interpolate_pattern:
- patterns.append(interpolate_pattern)
-
- for pattern in patterns:
- replace_interpolate_pattern(graph, pattern)
+++ /dev/null
-"""
- Copyright (c) 2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-
-import unittest
-
-import numpy as np
-
-from extensions.front.tf.SplitConcatPairToInterpolate import SplitConcatPairToInterpolate
-from mo.front.common.partial_infer.utils import int64_array
-from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-graph_node_attrs_for_2d_spatial_case = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
- 'split_axis_const': {
- 'kind': 'op',
- 'value': np.array(3, dtype=np.int64),
- 'op': 'Const',
- 'type': 'Const'
- },
- 'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
- 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
- 'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
- 'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
- 'split_data_2': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
- 'concat_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-
-graph_node_attrs_for_3d_spatial_case = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 3, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
- 'split_axis_const': {
- 'kind': 'op',
- 'value': np.array(4, dtype=np.int64),
- 'op': 'Const',
- 'type': 'Const'
- },
- 'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
- 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
- 'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
- 'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
- 'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
- 'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-
-graph_edges = [
- ('placeholder', 'placeholder_data'),
- ('placeholder_data', 'split', {'in': 0}),
- ('split_axis_const', 'split_axis_const_data'),
- ('split_axis_const_data', 'split', {'in': 1}),
- ('split', 'split_data_0', {'out': 0}),
- ('split', 'split_data_1', {'out': 1}),
- ('split', 'split_data_2', {'out': 2}),
- ('split_data_0', 'concat', {'in': 0}),
- ('split_data_0', 'concat', {'in': 1}),
- ('split_data_1', 'concat', {'in': 2}),
- ('split_data_1', 'concat', {'in': 3}),
- ('split_data_2', 'concat', {'in': 4}),
- ('split_data_2', 'concat', {'in': 5}),
- ('concat', 'concat_data'),
- ('concat_data', 'abs'),
- ('abs', 'abs_data'),
- ('abs_data', 'output')
- ]
-
-
-ref_graph_edges = [
- ('placeholder', 'placeholder_data'),
- ('placeholder_data', 'interpolate', {'in': 0}),
- ('placeholder_data', 'shape'),
- ('shape', 'sslice', {'in': 0}),
- ('slice_begin', 'sslice', {'in': 1}),
- ('slice_end', 'sslice', {'in': 2}),
- ('sslice', 'sslice_data'),
- ('scales', 'scales_data'),
- ('sslice_data', 'mul', {'in': 0}),
- ('scales_data', 'mul', {'in': 1}),
- ('mul', 'mul_data'),
- ('mul_data', 'interpolate', {'in': 1}),
- ('interpolate', 'interpolate_data'),
- ('interpolate_data', 'abs'),
- ('abs', 'abs_data'),
- ('abs_data', 'output'),
- ]
-
-
-ref_graph_node_attrs_for_2d_spatial_case_1 = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'interpolate': {
- 'type': 'Interpolate',
- 'kind': 'op',
- 'op': 'Interpolate',
- 'axes': int64_array([3]),
- 'mode': 'nearest'
- },
- 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
- 'slice_begin': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([3]),
- 'shape': int64_array([1])
- },
- 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4])},
- 'sslice': {
- 'kind': 'op',
- 'type': 'StridedSlice',
- 'op': 'StridedSlice',
- 'begin_mask': int64_array([1]),
- 'end_mask': int64_array([1]),
- 'new_axis_mask': int64_array([0]),
- 'shrink_axis_mask': int64_array([0]),
- 'ellipsis_mask': int64_array([0]),
- },
- 'sslice_data': {'kind': 'data', 'shape': None},
- 'scales': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([2]),
- 'shape': int64_array([1])
- },
- 'scales_data': {'kind': 'data', 'shape': None},
- 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
- 'mul_data': {'kind': 'data', 'shape': None},
- 'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-ref_graph_node_attrs_for_2d_spatial_case_2 = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'interpolate': {
- 'type': 'Interpolate',
- 'kind': 'op',
- 'op': 'Interpolate',
- 'axes': int64_array([2]),
- 'mode': 'nearest'
- },
- 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
- 'slice_begin': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([2]),
- 'shape': int64_array([1])
- },
- 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([3])},
- 'sslice': {
- 'kind': 'op',
- 'type': 'StridedSlice',
- 'op': 'StridedSlice',
- 'begin_mask': int64_array([1]),
- 'end_mask': int64_array([1]),
- 'new_axis_mask': int64_array([0]),
- 'shrink_axis_mask': int64_array([0]),
- 'ellipsis_mask': int64_array([0]),
- },
- 'sslice_data': {'kind': 'data', 'shape': None},
- 'scales': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([2]),
- 'shape': int64_array([1])
- },
- 'scales_data': {'kind': 'data', 'shape': None},
- 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
- 'mul_data': {'kind': 'data', 'shape': None},
- 'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-
-ref_graph_node_attrs_for_3d_spatial_case_1 = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 3, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'interpolate': {
- 'type': 'Interpolate',
- 'kind': 'op',
- 'op': 'Interpolate',
- 'axes': int64_array([4]),
- 'mode': 'nearest'
- },
- 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
- 'slice_begin': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([4]),
- 'shape': int64_array([1])
- },
- 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
- 'sslice': {
- 'kind': 'op',
- 'type': 'StridedSlice',
- 'op': 'StridedSlice',
- 'begin_mask': int64_array([1]),
- 'end_mask': int64_array([1]),
- 'new_axis_mask': int64_array([0]),
- 'shrink_axis_mask': int64_array([0]),
- 'ellipsis_mask': int64_array([0]),
- },
- 'sslice_data': {'kind': 'data', 'shape': None},
- 'scales': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([2]),
- 'shape': int64_array([1])
- },
- 'scales_data': {'kind': 'data', 'shape': None},
- 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
- 'mul_data': {'kind': 'data', 'shape': None},
- 'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-
-ref_graph_node_attrs_for_3d_spatial_case_2 = {
- 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_data': {
- 'value': None,
- 'shape': int64_array([1, 3, 100, 120, 150]),
- 'kind': 'data',
- 'data_type': None
- },
- 'interpolate': {
- 'type': 'Interpolate',
- 'kind': 'op',
- 'op': 'Interpolate',
- 'axes': int64_array([3]),
- 'mode': 'nearest'
- },
- 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
- 'slice_begin': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([4]),
- 'shape': int64_array([1])
- },
- 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
- 'sslice': {
- 'kind': 'op',
- 'type': 'StridedSlice',
- 'op': 'StridedSlice',
- 'begin_mask': int64_array([1]),
- 'end_mask': int64_array([1]),
- 'new_axis_mask': int64_array([0]),
- 'shrink_axis_mask': int64_array([0]),
- 'ellipsis_mask': int64_array([0]),
- },
- 'sslice_data': {'kind': 'data', 'shape': None},
- 'scales': {
- 'type': 'Const',
- 'op': 'Const',
- 'kind': 'op',
- 'value': int64_array([2]),
- 'shape': int64_array([1])
- },
- 'scales_data': {'kind': 'data', 'shape': None},
- 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
- 'mul_data': {'kind': 'data', 'shape': None},
- 'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
- 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
- 'output': {'kind': 'op', 'op': 'Result'},
- }
-
-
-class SplitConcatPairToInterpolateTest(unittest.TestCase):
- def test_spatial_2d_split_concat_1(self):
- graph = build_graph(
- nodes_attrs=graph_node_attrs_for_2d_spatial_case,
- edges=graph_edges
- )
- ref_graph = build_graph(
- nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
- edges=ref_graph_edges
- )
- SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
- (flag, resp) = compare_graphs(graph, ref_graph, 'output')
- self.assertTrue(flag, resp)
-
- def test_spatial_2d_split_concat_2(self):
- graph = build_graph(
- nodes_attrs=graph_node_attrs_for_2d_spatial_case,
- edges=graph_edges,
- update_attributes={
- 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
- 'split_axis_const': {
- 'kind': 'op',
- 'value': np.array(2, dtype=np.int64),
- 'op': 'Const',
- 'type': 'Const'
- },
- 'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
- 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
- 'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
- 'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
- 'split_data_2': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
- 'concat_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
- }
- )
- ref_graph = build_graph(
- nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
- edges=ref_graph_edges
- )
- SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
- (flag, resp) = compare_graphs(graph, ref_graph, 'output')
- self.assertTrue(flag, resp)
-
- def test_spatial_3d_split_concat_1(self):
- graph = build_graph(
- nodes_attrs=graph_node_attrs_for_3d_spatial_case,
- edges=graph_edges
- )
- ref_graph = build_graph(
- nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
- edges=ref_graph_edges
- )
- SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
- (flag, resp) = compare_graphs(graph, ref_graph, 'output')
- self.assertTrue(flag, resp)
-
- def test_spatial_3d_split_concat_2(self):
- graph = build_graph(
- nodes_attrs=graph_node_attrs_for_3d_spatial_case,
- edges=graph_edges,
- update_attributes={
- 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
- 'split_axis_const': {
- 'kind': 'op',
- 'value': np.array(3, dtype=np.int64),
- 'op': 'Const',
- 'type': 'Const'
- },
- 'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
- 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
- 'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
- 'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
- 'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
- 'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
- 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
- }
- )
- ref_graph = build_graph(
- nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
- edges=ref_graph_edges
- )
- SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
- (flag, resp) = compare_graphs(graph, ref_graph, 'output')
- self.assertTrue(flag, resp)
--- /dev/null
+"""
+ Copyright (c) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import logging as log
+from typing import Optional
+
+from extensions.ops.elementwise import Mul
+from extensions.ops.interpolate import Interpolate
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, Node
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.const import Const
+from mo.ops.shape import Shape
+from mo.ops.strided_slice import StridedSlice
+
+
+def get_concat_after_split(split: Node) -> Optional[Node]:
+ # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
+ split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
+ names_of_split_outputs = set([n.name for n in split_outputs])
+ if len(names_of_split_outputs) != 1:
+ return
+
+ groups_of_inputs = [[d.idx for d in p.get_connection().get_destinations()] for _, p in split.out_ports().items()]
+ sizes_of_groups = set([len(g) for g in groups_of_inputs])
+ # If numbers of consumer ports are various for various output ports of 'split', then the transformation
+ # is not applicable.
+ if len(sizes_of_groups) != 1:
+ return
+ # The transformation is applicable iff output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
+ # output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
+ # goes to ports [i * m, ..., i * m + (m - 1)], and so on.
+ flatten_groups = [i for g in groups_of_inputs for i in g]
+ if flatten_groups != list(range(0, len(flatten_groups))):
+ return
+
+ dest = split.out_port(0).get_destinations()[0].node
+ # The transformation is applicable, only if next node is Concat.
+ return dest if dest.soft_get('type') == 'Concat' else None
+
+
+def get_interpolate_pattern(split: Node) -> dict:
+ split_shape = split.in_port(0).data.get_shape()
+ if len(split_shape) not in {4, 5}:
+ return {}
+ concat = get_concat_after_split(split)
+ if concat is None:
+ return {}
+ return {'split': split, 'concat': concat}
+
+
+def get_split_scale(split: Node) -> int:
+ split_dests = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
+ num_of_split_dests = len(split_dests)
+ num_of_split_out_ports = len(split.out_ports())
+ fractional_part = num_of_split_dests / num_of_split_out_ports - num_of_split_dests // num_of_split_out_ports
+ assert fractional_part == 0, "Number of output ports of Split must be multiple of number of inputs of Concat"
+ return len(split_dests) // len(split.out_ports())
+
+
+def replace_interpolate_pattern(graph: Graph, match: dict):
+ split = match['split']
+ scale = int64_array([get_split_scale(split)])
+ axis = int(split.in_port(1).get_connection().get_source().node.value)
+ split_node_name = split.name
+
+ shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
+ scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
+ mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
+ scales_node.out_port(0).connect(mul_node.in_port(1))
+
+ strided_slice_node = create_op_with_const_inputs(graph,
+ StridedSlice,
+ {1: int64_array([axis]), 2: int64_array([axis + 1])},
+ {
+ 'name': split_node_name + '/StridedSlice_',
+ 'begin_mask': int64_array([1]),
+ 'end_mask': int64_array([1]),
+ 'new_axis_mask': int64_array([0]),
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0])
+ })
+ shape_node.out_port(0).connect(strided_slice_node.in_port(0))
+
+ strided_slice_node.out_port(0).connect(mul_node.in_port(0))
+
+ interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
+ axes=int64_array([axis]),
+ mode='nearest')).create_node()
+ mul_node.out_port(0).connect(interp_node.in_port(1))
+
+ match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
+
+ split_connection = split.in_port(0).get_connection()
+ split_connection.set_destination(interp_node.in_port(0))
+ split_connection.get_source().connect(shape_node.in_port(0))
+
+
+class SplitConcatPairToInterpolate(MiddleReplacementPattern):
+ """
+ This transformation looks for Interpolation layer implemented using simple operations, i.e. Split and Concat,
+ and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
+
+ Found pattern:
+ nodes=[
+ ('split', dict(kind='op', op='Split')),
+ ('concat', dict(kind='op', op='Concat')),
+ ],
+ edges=[
+ ('split', 'concat'),
+ ]
+
+ Here we assume that
+ 1) 'split' is in NDHWC layout and is a 5D-tensor;
+ 2) split dimensions for 'split' belongs to {1, 2, 3};
+ 3) all outputs of 'split' go to only inputs of 'concat';
+ 4) 'concat' takes inputs only from 'split';
+ 5) split_dim of 'split' is equal to axis of 'concat'.
+
+ Found pattern will be replaced with
+ nodes=[
+ ('shape', dict(kind='op', op='Shape')),
+ ('strided_slice', dict(kind='op', op='StridedSlice')),
+ ('scales', dict(kind='op', op='Const')),
+ ('scaled_shape', dict(kind='op', op='Mul')),
+ ('interp', dict(kind='op', op='Interpolate'))
+ ],
+ edges=[
+ ('shape', 'strided_slice', {'in': 0}),
+ ('strided_slice', 'scaled_shape', {'in': 0}),
+ ('scales', 'scaled_shape', {'in': 1}),
+ ('scaled_shape', 'interp', {'in': 1}),
+ ]
+
+ Here scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
+ by number of output ports of 'split'.
+ """
+ enabled = True
+ force_clean_up = True
+
+ def run_before(self):
+ from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
+ return [InterpolateSequenceToInterpolate]
+
+ def find_and_replace_pattern(self, graph: Graph):
+ log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
+ splits = graph.get_op_nodes(op='Split')
+ patterns = []
+
+ for split_node in splits:
+ interpolate_pattern = get_interpolate_pattern(split_node)
+ if interpolate_pattern:
+ patterns.append(interpolate_pattern)
+
+ for pattern in patterns:
+ replace_interpolate_pattern(graph, pattern)
--- /dev/null
+"""
+ Copyright (c) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+
+import unittest
+
+import numpy as np
+
+from extensions.middle.SplitConcatPairToInterpolate import SplitConcatPairToInterpolate
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph
+
+graph_node_attrs_for_2d_spatial_case = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+ 'split_axis_const': {
+ 'kind': 'op',
+ 'value': np.array(3, dtype=np.int64),
+ 'op': 'Const',
+ 'type': 'Const'
+ },
+ 'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
+ 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
+ 'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+ 'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+ 'split_data_2': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+ 'concat_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+graph_node_attrs_for_3d_spatial_case = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 3, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+ 'split_axis_const': {
+ 'kind': 'op',
+ 'value': np.array(4, dtype=np.int64),
+ 'op': 'Const',
+ 'type': 'Const'
+ },
+ 'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
+ 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
+ 'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+ 'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+ 'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+ 'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+ }
+
+
+graph_edges = [
+ ('placeholder', 'placeholder_data'),
+ ('placeholder_data', 'split', {'in': 0}),
+ ('split_axis_const', 'split_axis_const_data'),
+ ('split_axis_const_data', 'split', {'in': 1}),
+ ('split', 'split_data_0', {'out': 0}),
+ ('split', 'split_data_1', {'out': 1}),
+ ('split', 'split_data_2', {'out': 2}),
+ ('split_data_0', 'concat', {'in': 0}),
+ ('split_data_0', 'concat', {'in': 1}),
+ ('split_data_1', 'concat', {'in': 2}),
+ ('split_data_1', 'concat', {'in': 3}),
+ ('split_data_2', 'concat', {'in': 4}),
+ ('split_data_2', 'concat', {'in': 5}),
+ ('concat', 'concat_data'),
+ ('concat_data', 'abs'),
+ ('abs', 'abs_data'),
+ ('abs_data', 'output')
+]
+
+
+ref_graph_edges = [
+ ('placeholder', 'placeholder_data'),
+ ('placeholder_data', 'interpolate', {'in': 0}),
+ ('placeholder_data', 'shape'),
+ ('shape', 'shape_data'),
+ ('shape_data', 'sslice', {'in': 0}),
+ ('slice_begin', 'slice_begin_data'),
+ ('slice_begin_data', 'sslice', {'in': 1}),
+ ('slice_end', 'slice_end_data'),
+ ('slice_end_data', 'sslice', {'in': 2}),
+ ('sslice', 'sslice_data'),
+ ('scales', 'scales_data'),
+ ('sslice_data', 'mul', {'in': 0}),
+ ('scales_data', 'mul', {'in': 1}),
+ ('mul', 'mul_data'),
+ ('mul_data', 'interpolate', {'in': 1}),
+ ('interpolate', 'interpolate_data'),
+ ('interpolate_data', 'abs'),
+ ('abs', 'abs_data'),
+ ('abs_data', 'output'),
+ ]
+
+
+ref_graph_node_attrs_for_2d_spatial_case_1 = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'interpolate': {
+ 'type': 'Interpolate',
+ 'kind': 'op',
+ 'op': 'Interpolate',
+ 'axes': int64_array([3]),
+ 'mode': 'nearest'
+ },
+ 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+ 'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_begin': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([3]),
+ 'shape': int64_array([1])
+ },
+ 'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4])},
+ 'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'sslice': {
+ 'kind': 'op',
+ 'type': 'StridedSlice',
+ 'op': 'StridedSlice',
+ 'begin_mask': int64_array([1]),
+ 'end_mask': int64_array([1]),
+ 'new_axis_mask': int64_array([0]),
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0]),
+ },
+ 'sslice_data': {'kind': 'data', 'shape': None},
+ 'scales': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([2]),
+ 'shape': int64_array([1])
+ },
+ 'scales_data': {'kind': 'data', 'shape': None},
+ 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+ 'mul_data': {'kind': 'data', 'shape': None},
+ 'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+}
+
+ref_graph_node_attrs_for_2d_spatial_case_2 = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'interpolate': {
+ 'type': 'Interpolate',
+ 'kind': 'op',
+ 'op': 'Interpolate',
+ 'axes': int64_array([2]),
+ 'mode': 'nearest'
+ },
+ 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+ 'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_begin': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([2]),
+ 'shape': int64_array([1])
+ },
+ 'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([3])},
+ 'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'sslice': {
+ 'kind': 'op',
+ 'type': 'StridedSlice',
+ 'op': 'StridedSlice',
+ 'begin_mask': int64_array([1]),
+ 'end_mask': int64_array([1]),
+ 'new_axis_mask': int64_array([0]),
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0]),
+ },
+ 'sslice_data': {'kind': 'data', 'shape': None},
+ 'scales': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([2]),
+ 'shape': int64_array([1])
+ },
+ 'scales_data': {'kind': 'data', 'shape': None},
+ 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+ 'mul_data': {'kind': 'data', 'shape': None},
+ 'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+ref_graph_node_attrs_for_3d_spatial_case_1 = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 3, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'interpolate': {
+ 'type': 'Interpolate',
+ 'kind': 'op',
+ 'op': 'Interpolate',
+ 'axes': int64_array([4]),
+ 'mode': 'nearest'
+ },
+ 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+ 'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_begin': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([4]),
+ 'shape': int64_array([1])
+ },
+ 'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
+ 'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'sslice': {
+ 'kind': 'op',
+ 'type': 'StridedSlice',
+ 'op': 'StridedSlice',
+ 'begin_mask': int64_array([1]),
+ 'end_mask': int64_array([1]),
+ 'new_axis_mask': int64_array([0]),
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0]),
+ },
+ 'sslice_data': {'kind': 'data', 'shape': None},
+ 'scales': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([2]),
+ 'shape': int64_array([1])
+ },
+ 'scales_data': {'kind': 'data', 'shape': None},
+ 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+ 'mul_data': {'kind': 'data', 'shape': None},
+ 'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+ref_graph_node_attrs_for_3d_spatial_case_2 = {
+ 'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+ 'placeholder_data': {
+ 'value': None,
+ 'shape': int64_array([1, 3, 100, 120, 150]),
+ 'kind': 'data',
+ 'data_type': None
+ },
+ 'interpolate': {
+ 'type': 'Interpolate',
+ 'kind': 'op',
+ 'op': 'Interpolate',
+ 'axes': int64_array([3]),
+ 'mode': 'nearest'
+ },
+ 'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+ 'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_begin': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([4]),
+ 'shape': int64_array([1])
+ },
+ 'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
+ 'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+ 'sslice': {
+ 'kind': 'op',
+ 'type': 'StridedSlice',
+ 'op': 'StridedSlice',
+ 'begin_mask': int64_array([1]),
+ 'end_mask': int64_array([1]),
+ 'new_axis_mask': int64_array([0]),
+ 'shrink_axis_mask': int64_array([0]),
+ 'ellipsis_mask': int64_array([0]),
+ },
+ 'sslice_data': {'kind': 'data', 'shape': None},
+ 'scales': {
+ 'type': 'Const',
+ 'op': 'Const',
+ 'kind': 'op',
+ 'value': int64_array([2]),
+ 'shape': int64_array([1])
+ },
+ 'scales_data': {'kind': 'data', 'shape': None},
+ 'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+ 'mul_data': {'kind': 'data', 'shape': None},
+ 'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+ 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+ 'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+class SplitConcatPairToInterpolateTest(unittest.TestCase):
+ def test_spatial_2d_split_concat_1(self):
+ graph = build_graph(
+ nodes_attrs=graph_node_attrs_for_2d_spatial_case,
+ edges=graph_edges
+ )
+ ref_graph = build_graph(
+ nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
+ edges=ref_graph_edges
+ )
+ SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+ self.assertTrue(flag, resp)
+
+ def test_spatial_2d_split_concat_2(self):
+ graph = build_graph(
+ nodes_attrs=graph_node_attrs_for_2d_spatial_case,
+ edges=graph_edges,
+ update_attributes={
+ 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+ 'split_axis_const': {
+ 'kind': 'op',
+ 'value': np.array(2, dtype=np.int64),
+ 'op': 'Const',
+ 'type': 'Const'
+ },
+ 'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
+ 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
+ 'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+ 'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+ 'split_data_2': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+ 'concat_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+ }
+ )
+ ref_graph = build_graph(
+ nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
+ edges=ref_graph_edges
+ )
+ SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+ self.assertTrue(flag, resp)
+
+ def test_spatial_3d_split_concat_1(self):
+ graph = build_graph(
+ nodes_attrs=graph_node_attrs_for_3d_spatial_case,
+ edges=graph_edges
+ )
+ ref_graph = build_graph(
+ nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
+ edges=ref_graph_edges
+ )
+ SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+ self.assertTrue(flag, resp)
+
+ def test_spatial_3d_split_concat_2(self):
+ graph = build_graph(
+ nodes_attrs=graph_node_attrs_for_3d_spatial_case,
+ edges=graph_edges,
+ update_attributes={
+ 'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+ 'split_axis_const': {
+ 'kind': 'op',
+ 'value': np.array(3, dtype=np.int64),
+ 'op': 'Const',
+ 'type': 'Const'
+ },
+ 'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
+ 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
+ 'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+ 'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+ 'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+ 'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+ 'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+ }
+ )
+ ref_graph = build_graph(
+ nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
+ edges=ref_graph_edges
+ )
+ SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+ (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+ self.assertTrue(flag, resp)