SplitConcatPairToInterpolate inserts Interpolate when input is 2D (#596)
authorVladimir Gavrilov <vladimir.gavrilov@intel.com>
Thu, 28 May 2020 15:08:24 +0000 (18:08 +0300)
committerGitHub <noreply@github.com>
Thu, 28 May 2020 15:08:24 +0000 (18:08 +0300)
* SplitConcatPairToInterpolate transformation was moved to middle stage and is applied only for 4D and 5D inputs.

model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate.py [deleted file]
model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate_test.py [deleted file]
model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py [new file with mode: 0644]
model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py [new file with mode: 0644]

index 5b5143859094e3a62813a5807327040aed8d4190..e4da215ad3cda9f8b18bd7c6bfc7185d86f1696a 100644 (file)
@@ -436,7 +436,6 @@ extensions/front/tf/sparse_segment_sum_ext.py
 extensions/front/tf/sparse_to_dense_ext.py
 extensions/front/tf/sparse_weighted_sum.py
 extensions/front/tf/split_ext.py
-extensions/front/tf/SplitConcatPairToInterpolate.py
 extensions/front/tf/ssd_support.json
 extensions/front/tf/ssd_support_api_v1.14.json
 extensions/front/tf/ssd_support_api_v1.15.json
@@ -568,6 +567,7 @@ extensions/middle/SliceConverter.py
 extensions/middle/SliceLikeToStridedSlice.py
 extensions/middle/space_to_depth.py
 extensions/middle/sparse_reshape.py
+extensions/middle/SplitConcatPairToInterpolate.py
 extensions/middle/ssd_anchors_to_const.py
 extensions/middle/SwapAxesMiddleReplacer.py
 extensions/middle/TensorIterator_utils.py
diff --git a/model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate.py b/model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate.py
deleted file mode 100644 (file)
index f46bb92..0000000
+++ /dev/null
@@ -1,161 +0,0 @@
-"""
- Copyright (c) 2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import logging as log
-from typing import Optional
-
-from extensions.ops.elementwise import Mul
-from extensions.ops.interpolate import Interpolate
-from mo.front.common.partial_infer.utils import int64_array
-from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import Graph, Node
-from mo.ops.const import Const
-from mo.ops.shape import Shape
-from mo.ops.strided_slice import StridedSlice
-
-
-def get_concat_after_split(split: Node) -> Optional[Node]:
-    # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
-    split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
-    names_of_split_outputs = set([n.name for n in split_outputs])
-    if len(names_of_split_outputs) != 1:
-        return
-
-    groups_of_inputs = [[d.idx for d in p.get_connection().get_destinations()] for _, p in split.out_ports().items()]
-    sizes_of_groups = set([len(g) for g in groups_of_inputs])
-    # If numbers of consumer ports are various for various output ports of 'split', then the transformation
-    # is not applicable.
-    if len(sizes_of_groups) != 1:
-        return
-    # The transformation is applicable iff output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
-    # output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
-    # goes to ports [i * m, ..., i * m + (m - 1)], and so on.
-    flatten_groups = [i for g in groups_of_inputs for i in g]
-    if flatten_groups != list(range(0, len(flatten_groups))):
-        return
-
-    dest = split.out_port(0).get_destinations()[0].node
-    # The transformation is applicable, only if next node is Concat.
-    return dest if dest.soft_get('type') == 'Concat' else None
-
-
-def get_interpolate_pattern(split: Node) -> dict:
-    concat = get_concat_after_split(split)
-    if concat is None:
-        return {}
-    return {'split': split, 'concat': concat}
-
-
-def get_split_scale(split: Node) -> int:
-    split_dests = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
-    num_of_split_dests = len(split_dests)
-    num_of_split_out_ports = len(split.out_ports())
-    fractional_part = num_of_split_dests / num_of_split_out_ports - num_of_split_dests // num_of_split_out_ports
-    assert fractional_part == 0, "Number of output ports of Split must be multiple of number of inputs of Concat"
-    return len(split_dests) // len(split.out_ports())
-
-
-def replace_interpolate_pattern(graph: Graph, match: dict):
-    split = match['split']
-    scale = int64_array([get_split_scale(split)])
-    axis = int(split.in_port(1).get_connection().get_source().node.value)
-    split_node_name = split.name
-
-    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
-    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
-    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
-    scales_node.out_port(0).connect(mul_node.in_port(1))
-
-    slice_begin = Const(graph, dict(name=split_node_name + '/slice_begin_',
-                                    value=int64_array([axis]))).create_node()
-    slice_end = Const(graph, dict(name=split_node_name + '/slice_end_',
-                                  value=int64_array([axis + 1]))).create_node()
-
-    strided_slice_node = StridedSlice(graph,
-                                      {'name': split_node_name + '/StridedSlice_',
-                                       'begin_mask': int64_array([1]),
-                                       'end_mask': int64_array([1]),
-                                       'new_axis_mask': int64_array([0]),
-                                       'shrink_axis_mask': int64_array([0]),
-                                       'ellipsis_mask': int64_array([0]),
-                                       }).create_node([shape_node, slice_begin, slice_end])
-    strided_slice_node.out_port(0).connect(mul_node.in_port(0))
-
-    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
-                                          axes=int64_array([axis]),
-                                          mode='nearest')).create_node()
-    mul_node.out_port(0).connect(interp_node.in_port(1))
-
-    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
-
-    split_connection = split.in_port(0).get_connection()
-    split_connection.set_destination(interp_node.in_port(0))
-    split_connection.get_source().connect(shape_node.in_port(0))
-
-
-class SplitConcatPairToInterpolate(FrontReplacementSubgraph):
-    """
-    This transformation looks for Interpolation layer implemented using simple operations, i.e. Split and Concat,
-    and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
-
-    Found pattern:
-        nodes=[
-            ('split', dict(kind='op', op='Split')),
-            ('concat', dict(kind='op', op='Concat')),
-        ],
-        edges=[
-            ('split', 'concat'),
-        ]
-
-    Here we assume that
-        1) 'split' is in NDHWC layout and is a 5D-tensor;
-        2) split dimensions for 'split' belongs to {1, 2, 3};
-        3) all outputs of 'split' go to only inputs of 'concat';
-        4) 'concat' takes inputs only from 'split';
-        5) split_dim of 'split' is equal to axis of 'concat'.
-
-    Found pattern will be replaced with
-        nodes=[
-            ('shape', dict(kind='op', op='Shape')),
-            ('strided_slice', dict(kind='op', op='StridedSlice')),
-            ('scales', dict(kind='op', op='Const')),
-            ('scaled_shape', dict(kind='op', op='Mul')),
-            ('interp', dict(kind='op', op='Interpolate'))
-        ],
-        edges=[
-            ('shape', 'strided_slice', {'in': 0}),
-            ('strided_slice', 'scaled_shape', {'in': 0}),
-            ('scales', 'scaled_shape', {'in': 1}),
-            ('scaled_shape', 'interp', {'in': 1}),
-        ]
-
-    Here scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
-    by number of output ports of 'split'.
-    """
-    enabled = True
-
-    def find_and_replace_pattern(self, graph: Graph):
-        log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
-        splits = graph.get_op_nodes(op='Split')
-        patterns = []
-
-        for split_node in splits:
-            interpolate_pattern = get_interpolate_pattern(split_node)
-            if interpolate_pattern:
-                patterns.append(interpolate_pattern)
-
-        for pattern in patterns:
-            replace_interpolate_pattern(graph, pattern)
diff --git a/model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate_test.py b/model-optimizer/extensions/front/tf/SplitConcatPairToInterpolate_test.py
deleted file mode 100644 (file)
index 6eb9e5f..0000000
+++ /dev/null
@@ -1,412 +0,0 @@
-"""
- Copyright (c) 2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-
-import unittest
-
-import numpy as np
-
-from extensions.front.tf.SplitConcatPairToInterpolate import SplitConcatPairToInterpolate
-from mo.front.common.partial_infer.utils import int64_array
-from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-graph_node_attrs_for_2d_spatial_case = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
-        'split_axis_const': {
-            'kind': 'op',
-            'value': np.array(3, dtype=np.int64),
-            'op': 'Const',
-            'type': 'Const'
-        },
-        'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
-        'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
-        'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
-        'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
-        'split_data_2': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
-        'concat_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-
-graph_node_attrs_for_3d_spatial_case = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 3, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
-        'split_axis_const': {
-            'kind': 'op',
-            'value': np.array(4, dtype=np.int64),
-            'op': 'Const',
-            'type': 'Const'
-        },
-        'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
-        'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
-        'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
-        'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
-        'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
-        'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-
-graph_edges = [
-        ('placeholder', 'placeholder_data'),
-        ('placeholder_data', 'split', {'in': 0}),
-        ('split_axis_const', 'split_axis_const_data'),
-        ('split_axis_const_data', 'split', {'in': 1}),
-        ('split', 'split_data_0', {'out': 0}),
-        ('split', 'split_data_1', {'out': 1}),
-        ('split', 'split_data_2', {'out': 2}),
-        ('split_data_0', 'concat', {'in': 0}),
-        ('split_data_0', 'concat', {'in': 1}),
-        ('split_data_1', 'concat', {'in': 2}),
-        ('split_data_1', 'concat', {'in': 3}),
-        ('split_data_2', 'concat', {'in': 4}),
-        ('split_data_2', 'concat', {'in': 5}),
-        ('concat', 'concat_data'),
-        ('concat_data', 'abs'),
-        ('abs', 'abs_data'),
-        ('abs_data', 'output')
-    ]
-
-
-ref_graph_edges = [
-        ('placeholder', 'placeholder_data'),
-        ('placeholder_data', 'interpolate', {'in': 0}),
-        ('placeholder_data', 'shape'),
-        ('shape', 'sslice', {'in': 0}),
-        ('slice_begin', 'sslice', {'in': 1}),
-        ('slice_end', 'sslice', {'in': 2}),
-        ('sslice', 'sslice_data'),
-        ('scales', 'scales_data'),
-        ('sslice_data', 'mul', {'in': 0}),
-        ('scales_data', 'mul', {'in': 1}),
-        ('mul', 'mul_data'),
-        ('mul_data', 'interpolate', {'in': 1}),
-        ('interpolate', 'interpolate_data'),
-        ('interpolate_data', 'abs'),
-        ('abs', 'abs_data'),
-        ('abs_data', 'output'),
-    ]
-
-
-ref_graph_node_attrs_for_2d_spatial_case_1 = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'interpolate': {
-            'type': 'Interpolate',
-            'kind': 'op',
-            'op': 'Interpolate',
-            'axes': int64_array([3]),
-            'mode': 'nearest'
-        },
-        'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
-        'slice_begin': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([3]),
-            'shape': int64_array([1])
-        },
-        'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4])},
-        'sslice': {
-            'kind': 'op',
-            'type': 'StridedSlice',
-            'op': 'StridedSlice',
-            'begin_mask': int64_array([1]),
-            'end_mask': int64_array([1]),
-            'new_axis_mask': int64_array([0]),
-            'shrink_axis_mask': int64_array([0]),
-            'ellipsis_mask': int64_array([0]),
-        },
-        'sslice_data': {'kind': 'data', 'shape': None},
-        'scales': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([2]),
-            'shape': int64_array([1])
-        },
-        'scales_data': {'kind': 'data', 'shape': None},
-        'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
-        'mul_data': {'kind': 'data', 'shape': None},
-        'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-ref_graph_node_attrs_for_2d_spatial_case_2 = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'interpolate': {
-            'type': 'Interpolate',
-            'kind': 'op',
-            'op': 'Interpolate',
-            'axes': int64_array([2]),
-            'mode': 'nearest'
-        },
-        'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
-        'slice_begin': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([2]),
-            'shape': int64_array([1])
-        },
-        'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([3])},
-        'sslice': {
-            'kind': 'op',
-            'type': 'StridedSlice',
-            'op': 'StridedSlice',
-            'begin_mask': int64_array([1]),
-            'end_mask': int64_array([1]),
-            'new_axis_mask': int64_array([0]),
-            'shrink_axis_mask': int64_array([0]),
-            'ellipsis_mask': int64_array([0]),
-        },
-        'sslice_data': {'kind': 'data', 'shape': None},
-        'scales': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([2]),
-            'shape': int64_array([1])
-        },
-        'scales_data': {'kind': 'data', 'shape': None},
-        'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
-        'mul_data': {'kind': 'data', 'shape': None},
-        'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-
-ref_graph_node_attrs_for_3d_spatial_case_1 = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 3, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'interpolate': {
-            'type': 'Interpolate',
-            'kind': 'op',
-            'op': 'Interpolate',
-            'axes': int64_array([4]),
-            'mode': 'nearest'
-        },
-        'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
-        'slice_begin': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([4]),
-            'shape': int64_array([1])
-        },
-        'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
-        'sslice': {
-            'kind': 'op',
-            'type': 'StridedSlice',
-            'op': 'StridedSlice',
-            'begin_mask': int64_array([1]),
-            'end_mask': int64_array([1]),
-            'new_axis_mask': int64_array([0]),
-            'shrink_axis_mask': int64_array([0]),
-            'ellipsis_mask': int64_array([0]),
-        },
-        'sslice_data': {'kind': 'data', 'shape': None},
-        'scales': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([2]),
-            'shape': int64_array([1])
-        },
-        'scales_data': {'kind': 'data', 'shape': None},
-        'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
-        'mul_data': {'kind': 'data', 'shape': None},
-        'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-
-ref_graph_node_attrs_for_3d_spatial_case_2 = {
-        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-        'placeholder_data': {
-            'value': None,
-            'shape': int64_array([1, 3, 100, 120, 150]),
-            'kind': 'data',
-            'data_type': None
-        },
-        'interpolate': {
-            'type': 'Interpolate',
-            'kind': 'op',
-            'op': 'Interpolate',
-            'axes': int64_array([3]),
-            'mode': 'nearest'
-        },
-        'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
-        'slice_begin': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([4]),
-            'shape': int64_array([1])
-        },
-        'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
-        'sslice': {
-            'kind': 'op',
-            'type': 'StridedSlice',
-            'op': 'StridedSlice',
-            'begin_mask': int64_array([1]),
-            'end_mask': int64_array([1]),
-            'new_axis_mask': int64_array([0]),
-            'shrink_axis_mask': int64_array([0]),
-            'ellipsis_mask': int64_array([0]),
-        },
-        'sslice_data': {'kind': 'data', 'shape': None},
-        'scales': {
-            'type': 'Const',
-            'op': 'Const',
-            'kind': 'op',
-            'value': int64_array([2]),
-            'shape': int64_array([1])
-        },
-        'scales_data': {'kind': 'data', 'shape': None},
-        'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
-        'mul_data': {'kind': 'data', 'shape': None},
-        'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
-        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
-        'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
-        'output': {'kind': 'op', 'op': 'Result'},
-    }
-
-
-class SplitConcatPairToInterpolateTest(unittest.TestCase):
-    def test_spatial_2d_split_concat_1(self):
-        graph = build_graph(
-            nodes_attrs=graph_node_attrs_for_2d_spatial_case,
-            edges=graph_edges
-        )
-        ref_graph = build_graph(
-            nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
-            edges=ref_graph_edges
-        )
-        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
-        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
-        self.assertTrue(flag, resp)
-
-    def test_spatial_2d_split_concat_2(self):
-        graph = build_graph(
-            nodes_attrs=graph_node_attrs_for_2d_spatial_case,
-            edges=graph_edges,
-            update_attributes={
-                'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
-                'split_axis_const': {
-                    'kind': 'op',
-                    'value': np.array(2, dtype=np.int64),
-                    'op': 'Const',
-                    'type': 'Const'
-                },
-                'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
-                'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
-                'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
-                'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
-                'split_data_2': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
-                'concat_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
-                'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
-            }
-        )
-        ref_graph = build_graph(
-            nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
-            edges=ref_graph_edges
-        )
-        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
-        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
-        self.assertTrue(flag, resp)
-
-    def test_spatial_3d_split_concat_1(self):
-        graph = build_graph(
-            nodes_attrs=graph_node_attrs_for_3d_spatial_case,
-            edges=graph_edges
-        )
-        ref_graph = build_graph(
-            nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
-            edges=ref_graph_edges
-        )
-        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
-        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
-        self.assertTrue(flag, resp)
-
-    def test_spatial_3d_split_concat_2(self):
-        graph = build_graph(
-            nodes_attrs=graph_node_attrs_for_3d_spatial_case,
-            edges=graph_edges,
-            update_attributes={
-                'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
-                'split_axis_const': {
-                    'kind': 'op',
-                    'value': np.array(3, dtype=np.int64),
-                    'op': 'Const',
-                    'type': 'Const'
-                },
-                'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
-                'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
-                'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
-                'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
-                'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
-                'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
-                'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
-            }
-        )
-        ref_graph = build_graph(
-            nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
-            edges=ref_graph_edges
-        )
-        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
-        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
-        self.assertTrue(flag, resp)
diff --git a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py
new file mode 100644 (file)
index 0000000..b55e2f1
--- /dev/null
@@ -0,0 +1,170 @@
+"""
+ Copyright (c) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import logging as log
+from typing import Optional
+
+from extensions.ops.elementwise import Mul
+from extensions.ops.interpolate import Interpolate
+from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs
+from mo.graph.graph import Graph, Node
+from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.const import Const
+from mo.ops.shape import Shape
+from mo.ops.strided_slice import StridedSlice
+
+
+def get_concat_after_split(split: Node) -> Optional[Node]:
+    # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
+    split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
+    names_of_split_outputs = set([n.name for n in split_outputs])
+    if len(names_of_split_outputs) != 1:
+        return
+
+    groups_of_inputs = [[d.idx for d in p.get_connection().get_destinations()] for _, p in split.out_ports().items()]
+    sizes_of_groups = set([len(g) for g in groups_of_inputs])
+    # If numbers of consumer ports are various for various output ports of 'split', then the transformation
+    # is not applicable.
+    if len(sizes_of_groups) != 1:
+        return
+    # The transformation is applicable iff output port 0 of 'split' goes to ports [0, ..., m-1] of next node,
+    # output port 1 of 'split' goes to ports [m, ..., m + (m-1)] of next node, ..., output port i of 'split'
+    # goes to ports [i * m, ..., i * m + (m - 1)], and so on.
+    flatten_groups = [i for g in groups_of_inputs for i in g]
+    if flatten_groups != list(range(0, len(flatten_groups))):
+        return
+
+    dest = split.out_port(0).get_destinations()[0].node
+    # The transformation is applicable, only if next node is Concat.
+    return dest if dest.soft_get('type') == 'Concat' else None
+
+
+def get_interpolate_pattern(split: Node) -> dict:
+    split_shape = split.in_port(0).data.get_shape()
+    if len(split_shape) not in {4, 5}:
+        return {}
+    concat = get_concat_after_split(split)
+    if concat is None:
+        return {}
+    return {'split': split, 'concat': concat}
+
+
+def get_split_scale(split: Node) -> int:
+    split_dests = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
+    num_of_split_dests = len(split_dests)
+    num_of_split_out_ports = len(split.out_ports())
+    fractional_part = num_of_split_dests / num_of_split_out_ports - num_of_split_dests // num_of_split_out_ports
+    assert fractional_part == 0, "Number of output ports of Split must be multiple of number of inputs of Concat"
+    return len(split_dests) // len(split.out_ports())
+
+
+def replace_interpolate_pattern(graph: Graph, match: dict):
+    split = match['split']
+    scale = int64_array([get_split_scale(split)])
+    axis = int(split.in_port(1).get_connection().get_source().node.value)
+    split_node_name = split.name
+
+    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
+    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
+    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
+    scales_node.out_port(0).connect(mul_node.in_port(1))
+
+    strided_slice_node = create_op_with_const_inputs(graph,
+                                                     StridedSlice,
+                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
+                                                     {
+                                                        'name': split_node_name + '/StridedSlice_',
+                                                        'begin_mask': int64_array([1]),
+                                                        'end_mask': int64_array([1]),
+                                                        'new_axis_mask': int64_array([0]),
+                                                        'shrink_axis_mask': int64_array([0]),
+                                                        'ellipsis_mask': int64_array([0])
+                                                     })
+    shape_node.out_port(0).connect(strided_slice_node.in_port(0))
+
+    strided_slice_node.out_port(0).connect(mul_node.in_port(0))
+
+    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
+                                          axes=int64_array([axis]),
+                                          mode='nearest')).create_node()
+    mul_node.out_port(0).connect(interp_node.in_port(1))
+
+    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
+
+    split_connection = split.in_port(0).get_connection()
+    split_connection.set_destination(interp_node.in_port(0))
+    split_connection.get_source().connect(shape_node.in_port(0))
+
+
+class SplitConcatPairToInterpolate(MiddleReplacementPattern):
+    """
+    This transformation looks for Interpolation layer implemented using simple operations, i.e. Split and Concat,
+    and replaces found pattern with a sequence of Shape, StridedSlice, Const, Mul, Interpolate.
+
+    Found pattern:
+        nodes=[
+            ('split', dict(kind='op', op='Split')),
+            ('concat', dict(kind='op', op='Concat')),
+        ],
+        edges=[
+            ('split', 'concat'),
+        ]
+
+    Here we assume that
+        1) 'split' is in NDHWC layout and is a 5D-tensor;
+        2) split dimensions for 'split' belongs to {1, 2, 3};
+        3) all outputs of 'split' go to only inputs of 'concat';
+        4) 'concat' takes inputs only from 'split';
+        5) split_dim of 'split' is equal to axis of 'concat'.
+
+    Found pattern will be replaced with
+        nodes=[
+            ('shape', dict(kind='op', op='Shape')),
+            ('strided_slice', dict(kind='op', op='StridedSlice')),
+            ('scales', dict(kind='op', op='Const')),
+            ('scaled_shape', dict(kind='op', op='Mul')),
+            ('interp', dict(kind='op', op='Interpolate'))
+        ],
+        edges=[
+            ('shape', 'strided_slice', {'in': 0}),
+            ('strided_slice', 'scaled_shape', {'in': 0}),
+            ('scales', 'scaled_shape', {'in': 1}),
+            ('scaled_shape', 'interp', {'in': 1}),
+        ]
+
+    Here scaling factor in Interpolate is equal to a quotient of dividing number of input ports of 'concat'
+    by number of output ports of 'split'.
+    """
+    enabled = True
+    force_clean_up = True
+
+    def run_before(self):
+        from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
+        return [InterpolateSequenceToInterpolate]
+
+    def find_and_replace_pattern(self, graph: Graph):
+        log.debug('Enabled replacement of a pair of Split and Concat with Interpolate.')
+        splits = graph.get_op_nodes(op='Split')
+        patterns = []
+
+        for split_node in splits:
+            interpolate_pattern = get_interpolate_pattern(split_node)
+            if interpolate_pattern:
+                patterns.append(interpolate_pattern)
+
+        for pattern in patterns:
+            replace_interpolate_pattern(graph, pattern)
diff --git a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py
new file mode 100644 (file)
index 0000000..b7f4fac
--- /dev/null
@@ -0,0 +1,427 @@
+"""
+ Copyright (c) 2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+
+import unittest
+
+import numpy as np
+
+from extensions.middle.SplitConcatPairToInterpolate import SplitConcatPairToInterpolate
+from mo.front.common.partial_infer.utils import int64_array
+from mo.utils.ir_engine.compare_graphs import compare_graphs
+from mo.utils.unittest.graph import build_graph
+
+graph_node_attrs_for_2d_spatial_case = {
+    'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_data': {
+        'value': None,
+        'shape': int64_array([1, 100, 120, 150]),
+        'kind': 'data',
+        'data_type': None
+    },
+    'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+    'split_axis_const': {
+        'kind': 'op',
+        'value': np.array(3, dtype=np.int64),
+        'op': 'Const',
+        'type': 'Const'
+    },
+    'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
+    'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
+    'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+    'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+    'split_data_2': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
+    'concat_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+    'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+    'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+graph_node_attrs_for_3d_spatial_case = {
+        'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+        'placeholder_data': {
+            'value': None,
+            'shape': int64_array([1, 3, 100, 120, 150]),
+            'kind': 'data',
+            'data_type': None
+        },
+        'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+        'split_axis_const': {
+            'kind': 'op',
+            'value': np.array(4, dtype=np.int64),
+            'op': 'Const',
+            'type': 'Const'
+        },
+        'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
+        'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
+        'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+        'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+        'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
+        'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+        'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+        'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+        'output': {'kind': 'op', 'op': 'Result'},
+    }
+
+
+graph_edges = [
+    ('placeholder', 'placeholder_data'),
+    ('placeholder_data', 'split', {'in': 0}),
+    ('split_axis_const', 'split_axis_const_data'),
+    ('split_axis_const_data', 'split', {'in': 1}),
+    ('split', 'split_data_0', {'out': 0}),
+    ('split', 'split_data_1', {'out': 1}),
+    ('split', 'split_data_2', {'out': 2}),
+    ('split_data_0', 'concat', {'in': 0}),
+    ('split_data_0', 'concat', {'in': 1}),
+    ('split_data_1', 'concat', {'in': 2}),
+    ('split_data_1', 'concat', {'in': 3}),
+    ('split_data_2', 'concat', {'in': 4}),
+    ('split_data_2', 'concat', {'in': 5}),
+    ('concat', 'concat_data'),
+    ('concat_data', 'abs'),
+    ('abs', 'abs_data'),
+    ('abs_data', 'output')
+]
+
+
+ref_graph_edges = [
+        ('placeholder', 'placeholder_data'),
+        ('placeholder_data', 'interpolate', {'in': 0}),
+        ('placeholder_data', 'shape'),
+        ('shape', 'shape_data'),
+        ('shape_data', 'sslice', {'in': 0}),
+        ('slice_begin', 'slice_begin_data'),
+        ('slice_begin_data', 'sslice', {'in': 1}),
+        ('slice_end', 'slice_end_data'),
+        ('slice_end_data', 'sslice', {'in': 2}),
+        ('sslice', 'sslice_data'),
+        ('scales', 'scales_data'),
+        ('sslice_data', 'mul', {'in': 0}),
+        ('scales_data', 'mul', {'in': 1}),
+        ('mul', 'mul_data'),
+        ('mul_data', 'interpolate', {'in': 1}),
+        ('interpolate', 'interpolate_data'),
+        ('interpolate_data', 'abs'),
+        ('abs', 'abs_data'),
+        ('abs_data', 'output'),
+    ]
+
+
+ref_graph_node_attrs_for_2d_spatial_case_1 = {
+    'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_data': {
+        'value': None,
+        'shape': int64_array([1, 100, 120, 150]),
+        'kind': 'data',
+        'data_type': None
+    },
+    'interpolate': {
+        'type': 'Interpolate',
+        'kind': 'op',
+        'op': 'Interpolate',
+        'axes': int64_array([3]),
+        'mode': 'nearest'
+    },
+    'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+    'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_begin': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([3]),
+        'shape': int64_array([1])
+    },
+    'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4])},
+    'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+    'sslice': {
+        'kind': 'op',
+        'type': 'StridedSlice',
+        'op': 'StridedSlice',
+        'begin_mask': int64_array([1]),
+        'end_mask': int64_array([1]),
+        'new_axis_mask': int64_array([0]),
+        'shrink_axis_mask': int64_array([0]),
+        'ellipsis_mask': int64_array([0]),
+    },
+    'sslice_data': {'kind': 'data', 'shape': None},
+    'scales': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([2]),
+        'shape': int64_array([1])
+    },
+    'scales_data': {'kind': 'data', 'shape': None},
+    'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+    'mul_data': {'kind': 'data', 'shape': None},
+    'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+    'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
+    'output': {'kind': 'op', 'op': 'Result'},
+}
+
+ref_graph_node_attrs_for_2d_spatial_case_2 = {
+    'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_data': {
+        'value': None,
+        'shape': int64_array([1, 100, 120, 150]),
+        'kind': 'data',
+        'data_type': None
+    },
+    'interpolate': {
+        'type': 'Interpolate',
+        'kind': 'op',
+        'op': 'Interpolate',
+        'axes': int64_array([2]),
+        'mode': 'nearest'
+    },
+    'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+    'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_begin': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([2]),
+        'shape': int64_array([1])
+    },
+    'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([3])},
+    'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+    'sslice': {
+        'kind': 'op',
+        'type': 'StridedSlice',
+        'op': 'StridedSlice',
+        'begin_mask': int64_array([1]),
+        'end_mask': int64_array([1]),
+        'new_axis_mask': int64_array([0]),
+        'shrink_axis_mask': int64_array([0]),
+        'ellipsis_mask': int64_array([0]),
+    },
+    'sslice_data': {'kind': 'data', 'shape': None},
+    'scales': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([2]),
+        'shape': int64_array([1])
+    },
+    'scales_data': {'kind': 'data', 'shape': None},
+    'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+    'mul_data': {'kind': 'data', 'shape': None},
+    'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+    'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+    'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+ref_graph_node_attrs_for_3d_spatial_case_1 = {
+    'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_data': {
+        'value': None,
+        'shape': int64_array([1, 3, 100, 120, 150]),
+        'kind': 'data',
+        'data_type': None
+    },
+    'interpolate': {
+        'type': 'Interpolate',
+        'kind': 'op',
+        'op': 'Interpolate',
+        'axes': int64_array([4]),
+        'mode': 'nearest'
+    },
+    'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+    'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_begin': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([4]),
+        'shape': int64_array([1])
+    },
+    'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
+    'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+    'sslice': {
+        'kind': 'op',
+        'type': 'StridedSlice',
+        'op': 'StridedSlice',
+        'begin_mask': int64_array([1]),
+        'end_mask': int64_array([1]),
+        'new_axis_mask': int64_array([0]),
+        'shrink_axis_mask': int64_array([0]),
+        'ellipsis_mask': int64_array([0]),
+    },
+    'sslice_data': {'kind': 'data', 'shape': None},
+    'scales': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([2]),
+        'shape': int64_array([1])
+    },
+    'scales_data': {'kind': 'data', 'shape': None},
+    'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+    'mul_data': {'kind': 'data', 'shape': None},
+    'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+    'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
+    'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+ref_graph_node_attrs_for_3d_spatial_case_2 = {
+    'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+    'placeholder_data': {
+        'value': None,
+        'shape': int64_array([1, 3, 100, 120, 150]),
+        'kind': 'data',
+        'data_type': None
+    },
+    'interpolate': {
+        'type': 'Interpolate',
+        'kind': 'op',
+        'op': 'Interpolate',
+        'axes': int64_array([3]),
+        'mode': 'nearest'
+    },
+    'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
+    'shape_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_begin': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([4]),
+        'shape': int64_array([1])
+    },
+    'slice_begin_data': {'kind': 'data', 'shape': None, 'value': None},
+    'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([5])},
+    'slice_end_data': {'kind': 'data', 'shape': None, 'value': None},
+    'sslice': {
+        'kind': 'op',
+        'type': 'StridedSlice',
+        'op': 'StridedSlice',
+        'begin_mask': int64_array([1]),
+        'end_mask': int64_array([1]),
+        'new_axis_mask': int64_array([0]),
+        'shrink_axis_mask': int64_array([0]),
+        'ellipsis_mask': int64_array([0]),
+    },
+    'sslice_data': {'kind': 'data', 'shape': None},
+    'scales': {
+        'type': 'Const',
+        'op': 'Const',
+        'kind': 'op',
+        'value': int64_array([2]),
+        'shape': int64_array([1])
+    },
+    'scales_data': {'kind': 'data', 'shape': None},
+    'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
+    'mul_data': {'kind': 'data', 'shape': None},
+    'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
+    'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+    'output': {'kind': 'op', 'op': 'Result'},
+}
+
+
+class SplitConcatPairToInterpolateTest(unittest.TestCase):
+    def test_spatial_2d_split_concat_1(self):
+        graph = build_graph(
+            nodes_attrs=graph_node_attrs_for_2d_spatial_case,
+            edges=graph_edges
+        )
+        ref_graph = build_graph(
+            nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
+            edges=ref_graph_edges
+        )
+        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+        self.assertTrue(flag, resp)
+
+    def test_spatial_2d_split_concat_2(self):
+        graph = build_graph(
+            nodes_attrs=graph_node_attrs_for_2d_spatial_case,
+            edges=graph_edges,
+            update_attributes={
+                'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+                'split_axis_const': {
+                    'kind': 'op',
+                    'value': np.array(2, dtype=np.int64),
+                    'op': 'Const',
+                    'type': 'Const'
+                },
+                'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
+                'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
+                'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+                'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+                'split_data_2': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
+                'concat_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+                'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
+            }
+        )
+        ref_graph = build_graph(
+            nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
+            edges=ref_graph_edges
+        )
+        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+        self.assertTrue(flag, resp)
+
+    def test_spatial_3d_split_concat_1(self):
+        graph = build_graph(
+            nodes_attrs=graph_node_attrs_for_3d_spatial_case,
+            edges=graph_edges
+        )
+        ref_graph = build_graph(
+            nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
+            edges=ref_graph_edges
+        )
+        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+        self.assertTrue(flag, resp)
+
+    def test_spatial_3d_split_concat_2(self):
+        graph = build_graph(
+            nodes_attrs=graph_node_attrs_for_3d_spatial_case,
+            edges=graph_edges,
+            update_attributes={
+                'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
+                'split_axis_const': {
+                    'kind': 'op',
+                    'value': np.array(3, dtype=np.int64),
+                    'op': 'Const',
+                    'type': 'Const'
+                },
+                'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
+                'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
+                'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+                'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+                'split_data_2': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
+                'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+                'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
+            }
+        )
+        ref_graph = build_graph(
+            nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
+            edges=ref_graph_edges
+        )
+        SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, ref_graph, 'output')
+        self.assertTrue(flag, resp)