Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / BatchToSpaceNDToUpsample.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import numpy as np
19
20 from extensions.ops.upsample import UpsampleOp
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.front.common.replacement import FrontReplacementSubgraph
23 from mo.graph.graph import Graph, Node
24
25
26 class BatchToSpaceNDToUpsample(FrontReplacementSubgraph):
27     """
28     The transformation looks for pattern that performs NX upscale of the input image specified in the NHWC layout.
29     """
30     enabled = True
31
32     @staticmethod
33     def pattern(**kwargs):
34         return dict(
35             nodes=[
36                 ('transpose', dict(op='Transpose')),
37                 ('expand_dims', dict(op='Unsqueeze')),
38                 ('tile', dict(op='Tile')),
39                 ('batch_to_space_nd', dict(op='BatchToSpaceND')),
40                 ('strided_slice', dict(op='StridedSlice')),
41                 ('transpose_back', dict(op='Transpose')),
42             ],
43             edges=[
44                 ('transpose', 'expand_dims', {'out': 0}),
45                 ('expand_dims', 'tile', {'out': 0}),
46                 ('tile', 'batch_to_space_nd', {'out': 0}),
47                 ('batch_to_space_nd', 'strided_slice', {'out': 0}),
48                 ('strided_slice', 'transpose_back', {'out': 0})
49             ]
50         )
51
52     @staticmethod
53     def replace_sub_graph(graph: Graph, match: dict, **kwargs):
54         def _input_node_value(node: Node, port_ind: int):
55             input_node = node.in_port(port_ind).get_source().node
56             return input_node.value if input_node.op == 'Const' else None
57
58         transpose = match['transpose']
59         transpose_order = _input_node_value(transpose, 1)
60         if transpose_order is None or not np.all(np.equal(transpose_order, int64_array([1, 2, 3, 0]))):
61             log.debug('The transpose order {} for node {} is not equal to [1, 2, 3, 0]. Cannot apply '
62                       'BatchToSpaceNDToUpsample transformation.'.format(transpose_order, transpose.name))
63             return
64
65         expand_axis = match['expand_dims']
66         expand_axis_value = _input_node_value(expand_axis, 1)
67         if expand_axis_value != 0:
68             log.debug('The expand axis {} for node {} is not equal to 0. Cannot apply BatchToSpaceNDToUpsample '
69                       'transformation.'.format(expand_axis_value, expand_axis.name))
70             return
71
72         tile = match['tile']
73         tile_value = _input_node_value(tile, 1)
74         if tile_value is None:
75             log.debug('The tile value is not defined for node {}. Cannot apply BatchToSpaceNDToUpsample '
76                       'transformation.'.format(tile.name))
77             return
78
79         if len(np.where(tile_value != 1)) != 1:
80             log.debug('The number of tiles not equal to 1 not equal to 1. Cannot apply BatchToSpaceNDToUpsample '
81                       'transformation.')
82             return
83         tile_batch = tile_value[0]
84
85         batch_to_space_nd = match['batch_to_space_nd']
86         block_shape = _input_node_value(batch_to_space_nd, 1)
87         if block_shape is None or tile_batch != np.prod(block_shape):
88             log.debug('The block shape {} for node {} is not defined or inconsistent with the tile size. Cannot apply '
89                       'BatchToSpaceNDToUpsample transformation.'.format(block_shape, batch_to_space_nd.name))
90             return
91         if len(block_shape) != 2:
92             log.debug('The block shape len is not equal to 2 for node {}. Cannot apply BatchToSpaceNDToUpsample '
93                       'transformation.'.format(batch_to_space_nd.name))
94             return
95
96         transpose_back = match['transpose_back']
97         transpose_back_order = _input_node_value(transpose_back, 1)
98         if transpose_back_order is None or not np.all(np.equal(transpose_back_order, int64_array([3, 0, 1, 2]))):
99             log.debug('The transpose order {} for node {} is not equal to [3, 0, 1, 2]. Cannot apply '
100                       'BatchToSpaceNDToUpsample transformation.'.format(transpose_back_order, transpose_back.name))
101             return
102
103         upsample_node = UpsampleOp(graph, {'height_scale': block_shape[0], 'width_scale': block_shape[1],
104                                            'mode': 'nearest',
105                                            'name': transpose.name + '/upsample'}).create_node()
106
107         match['transpose'].in_port(0).get_connection().set_destination(upsample_node.in_port(0))
108         match['transpose_back'].out_port(0).get_connection().set_source(upsample_node.out_port(0))