2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from extensions.ops.upsample import UpsampleOp
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.front.common.replacement import FrontReplacementSubgraph
23 from mo.graph.graph import Graph, Node
class BatchToSpaceNDToUpsample(FrontReplacementSubgraph):
    """
    The transformation looks for a pattern that performs NX upscale of the input image specified in the NHWC layout.

    The matched chain Transpose -> Unsqueeze -> Tile -> BatchToSpaceND -> StridedSlice -> Transpose is replaced with
    a single Upsample node, provided that the constant inputs of the matched nodes have the expected values.
    """

    @staticmethod
    def pattern(**kwargs) -> dict:
        """
        Describe the sub-graph to look for.

        :param kwargs: not used.
        :return: dictionary with the list of nodes to match and the list of edges chaining them together.
        """
        return dict(
            nodes=[
                ('transpose', dict(op='Transpose')),
                ('expand_dims', dict(op='Unsqueeze')),
                ('tile', dict(op='Tile')),
                ('batch_to_space_nd', dict(op='BatchToSpaceND')),
                ('strided_slice', dict(op='StridedSlice')),
                ('transpose_back', dict(op='Transpose')),
            ],
            edges=[
                ('transpose', 'expand_dims', {'out': 0}),
                ('expand_dims', 'tile', {'out': 0}),
                ('tile', 'batch_to_space_nd', {'out': 0}),
                ('batch_to_space_nd', 'strided_slice', {'out': 0}),
                ('strided_slice', 'transpose_back', {'out': 0}),
            ]
        )

    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
        """
        Replace the matched sub-graph with an Upsample node when all constant parameters have the expected values.

        If any of the checks below fails, the reason is written to the debug log and the graph is left unchanged.

        :param graph: graph to operate on.
        :param match: dictionary mapping the node names from the pattern to the matched nodes.
        :param kwargs: not used.
        :return: None
        """
        def _input_node_value(node: Node, port_ind: int):
            # Value of the producer of the given input port if it is a constant, None otherwise.
            input_node = node.in_port(port_ind).get_source().node
            return input_node.value if input_node.op == 'Const' else None

        transpose = match['transpose']
        transpose_order = _input_node_value(transpose, 1)
        # np.array_equal also copes with an order of unexpected length (no broadcasting error)
        if transpose_order is None or not np.array_equal(transpose_order, int64_array([1, 2, 3, 0])):
            log.debug('The transpose order {} for node {} is not equal to [1, 2, 3, 0]. Cannot apply '
                      'BatchToSpaceNDToUpsample transformation.'.format(transpose_order, transpose.name))
            return

        expand_axis = match['expand_dims']
        expand_axis_value = _input_node_value(expand_axis, 1)
        # explicit None check plus np.all so that neither a missing nor a multi-element value raises
        if expand_axis_value is None or not np.all(np.equal(expand_axis_value, 0)):
            log.debug('The expand axis {} for node {} is not equal to 0. Cannot apply BatchToSpaceNDToUpsample '
                      'transformation.'.format(expand_axis_value, expand_axis.name))
            return

        tile = match['tile']
        tile_value = _input_node_value(tile, 1)
        if tile_value is None:
            log.debug('The tile value is not defined for node {}. Cannot apply BatchToSpaceNDToUpsample '
                      'transformation.'.format(tile.name))
            return

        # Exactly one tile multiplier may differ from 1. Count the elements: np.where() returns a tuple with one
        # index array per dimension, so taking len() of its result does not give the number of such elements.
        if np.count_nonzero(tile_value != 1) != 1:
            log.debug('The number of tile values not equal to 1 is not equal to 1 for node {}. Cannot apply '
                      'BatchToSpaceNDToUpsample transformation.'.format(tile.name))
            return
        tile_batch = tile_value[0]

        batch_to_space_nd = match['batch_to_space_nd']
        block_shape = _input_node_value(batch_to_space_nd, 1)
        # the number of copies produced by Tile must match the number of elements in one BatchToSpaceND block
        if block_shape is None or tile_batch != np.prod(block_shape):
            log.debug('The block shape {} for node {} is not defined or inconsistent with the tile size. Cannot apply '
                      'BatchToSpaceNDToUpsample transformation.'.format(block_shape, batch_to_space_nd.name))
            return
        if len(block_shape) != 2:
            log.debug('The block shape len is not equal to 2 for node {}. Cannot apply BatchToSpaceNDToUpsample '
                      'transformation.'.format(batch_to_space_nd.name))
            return

        transpose_back = match['transpose_back']
        transpose_back_order = _input_node_value(transpose_back, 1)
        if transpose_back_order is None or not np.array_equal(transpose_back_order, int64_array([3, 0, 1, 2])):
            log.debug('The transpose order {} for node {} is not equal to [3, 0, 1, 2]. Cannot apply '
                      'BatchToSpaceNDToUpsample transformation.'.format(transpose_back_order, transpose_back.name))
            return

        # NOTE(review): no interpolation mode is passed to UpsampleOp here -- confirm that its default matches the
        # element replication performed by the matched Tile + BatchToSpaceND sub-graph.
        upsample_node = UpsampleOp(graph, {'height_scale': block_shape[0], 'width_scale': block_shape[1],
                                           'name': transpose.name + '/upsample'}).create_node()

        # feed the Upsample with the data that entered the first Transpose and make it the producer for all
        # consumers of the last Transpose
        match['transpose'].in_port(0).get_connection().set_destination(upsample_node.in_port(0))
        match['transpose_back'].out_port(0).get_connection().set_source(upsample_node.out_port(0))