380eaf2c38fdbaf5124113a411d69751dc0be78a
[platform/upstream/dldt.git] / model-optimizer / extensions / ops / interp.py
1 """
2  Copyright (c) 2017-2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import inspect
18 import logging as log
19
20 import networkx as nx
21
22 from extensions.ops.resize_factor_utils import factor_update
23 from mo.front.common.layout import get_batch_dim, get_features_dim, get_height_dim, get_width_dim, shape_for_layout
24 from mo.graph.graph import Node
25 from mo.ops.op import Op
26 from mo.utils.utils import refer_to_faq_msg
27
28
class InterpOp(Op):
    """Caffe-style Interp (bilinear interpolation / resize) operation.

    Shape inference supports two modes:
      * two inputs: input 0 is the data blob, input 1 carries the target
        [height, width] as a value; the second input edge is removed after
        inference (IE Interp is a single-input layer);
      * one input: output H/W are derived from the node attributes
        pad_beg/pad_end, shrink_factor, zoom_factor, height, width,
        following Caffe's InterpLayer reshape logic.
    """
    op = 'Interp'

    def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
        mandatory_props = {
            'type': __class__.op,
            'op': __class__.op,
            'factor': None,          # float resize factor for IE; may be deduced in infer
            'align_corners': 1,
            'infer': InterpOp.interp_infer
        }
        super().__init__(graph, mandatory_props, attrs)

    def supported_attrs(self):
        """Attributes serialized to the IR for this layer."""
        return [
            'height',
            'width',
            'zoom_factor',
            'shrink_factor',
            'factor',  # float factor required by IE shape inference
            'pad_beg',
            'pad_end',
            'align_corners'
        ]

    @staticmethod
    def interp_infer(node: Node):
        """Infer the output shape of an Interp node.

        Requires a 4D layout (NCHW or NHWC). Sets node.out_node().shape;
        on insufficient shape information, resets node.type to None so the
        node is not emitted as an IE layer.
        """
        layout = node.graph.graph['layout']
        assert len(layout) == 4
        if len(node.in_nodes()) == 2:
            # Target H/W come as the value of the second input.
            src_shape = node.in_node(0).shape
            dst_shape = node.in_node(1).value
            if src_shape is None or dst_shape is None or len(src_shape) != 4 or len(dst_shape) != 2:
                log.error(
                    'Node {} with op {} cannot be converted to Resample layer because there is no enough info about '
                    'src/dst shapes: src_shape = {}, dst_shape = {}'.format(node.name, node.op, src_shape, dst_shape))
                node.type = None  # prevent translation to a valid IE layer
                return
            in_height = src_shape[get_height_dim(layout, 4)]
            in_width = src_shape[get_width_dim(layout, 4)]
            out_height = dst_shape[0]
            out_width = dst_shape[1]

            # Try to express the resize as a single uniform float factor;
            # factor_update logs and returns None when H and W factors differ.
            node.factor = factor_update(
                node.factor,
                [float(out_height) / in_height, float(out_width) / in_width],
                [in_height, in_width],
                [out_height, out_width],
                node.soft_get('name')
            )

            if node.factor is None:
                # No uniform factor: fall back to explicit output size attributes.
                node['width'] = out_width
                node['height'] = out_height

            node.out_node().shape = shape_for_layout(layout,
                                                     batch=src_shape[get_batch_dim(layout, 4)],
                                                     features=src_shape[get_features_dim(layout, 4)],
                                                     height=out_height,
                                                     width=out_width)
            # IE Interp is single-input: drop the edge carrying the target size.
            node.graph.remove_edge(node.in_node(1).id, node.id)
        else:
            outn = node.out_node(0)

            in_shape = node.in_node(0)
            num_ = in_shape.shape[get_batch_dim(layout, 4)]
            channels_ = in_shape.shape[get_features_dim(layout, 4)]
            height_in_ = in_shape.shape[get_height_dim(layout, 4)]
            width_in_ = in_shape.shape[get_width_dim(layout, 4)]

            # Caffe applies the same begin/end padding to both spatial dims.
            height_out_ = height_in_ + node.pad_beg + node.pad_end
            width_out_ = width_in_ + node.pad_beg + node.pad_end

            if node.shrink_factor != 1 and node.zoom_factor == 1:
                shrink_factor = node.shrink_factor
                if shrink_factor < 1:
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None
                # Integer division to match Caffe's C++ InterpLayer reshape:
                # dim_out = (dim_eff - 1) / shrink_factor + 1 (truncating).
                # Python 3 '/' would produce a float shape entry here.
                height_out_ = (height_out_ - 1) // shrink_factor + 1
                width_out_ = (width_out_ - 1) // shrink_factor + 1
            elif node.shrink_factor == 1 and node.zoom_factor != 1:
                zoom_factor = node.zoom_factor
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None

                node['debug_message'] = 'Interp layer shape inference function may be wrong, please, try to update ' \
                                        'layer shape inference function in the file (extensions/ops/interp.op at the ' \
                                        'line {}).'.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)
                # Reshape methods can be different in some cases
                # Commented out section represents reshape that used in deeplab-caffe
                # Uncomment the following lines, if your model was trained with deeplab-caffe
                # or have the same reshape method
                # height_out_ = height_out_ + (height_out_ - 1) * (zoom_factor - 1)
                # width_out_ = width_out_ + (width_out_ - 1) * (zoom_factor - 1)

                # Comment out the following lines if you use the reshape method from previous section
                height_out_ = height_out_ * zoom_factor
                width_out_ = width_out_ * zoom_factor
            elif node.width != 0 and node.height != 0:
                # Explicit output size wins over combined shrink+zoom below.
                height_out_ = node.height
                width_out_ = node.width
            elif node.shrink_factor != 1 and node.zoom_factor != 1:
                shrink_factor = node.shrink_factor
                zoom_factor = node.zoom_factor
                if shrink_factor < 1:
                    log.error('Shrink factor should be positive in node {}'.format(node.id))
                    return None
                if zoom_factor < 1:
                    log.error('Zoom factor should be positive in node {}'.format(node.id))
                    return None
                # Shrink first (integer division, as in Caffe), then zoom.
                height_out_ = (height_out_ - 1) // shrink_factor + 1
                width_out_ = (width_out_ - 1) // shrink_factor + 1
                height_out_ = height_out_ + (height_out_ - 1) * (zoom_factor - 1)
                width_out_ = width_out_ + (width_out_ - 1) * (zoom_factor - 1)

            outn.shape = shape_for_layout(layout,
                                          batch=num_,
                                          features=channels_,
                                          height=height_out_,
                                          width=width_out_)