2 Copyright (c) 2017-2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
22 from extensions.ops.resize_factor_utils import factor_update
23 from mo.front.common.layout import get_batch_dim, get_features_dim, get_height_dim, get_width_dim, shape_for_layout
24 from mo.graph.graph import Node
25 from mo.ops.op import Op
26 from mo.utils.utils import refer_to_faq_msg
32 def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
38 'infer': InterpOp.interp_infer
40 super().__init__(graph, mandatory_props, attrs)
42 def supported_attrs(self):
48 'factor', # float factor required by IE shape inference
55 def interp_infer(node: Node):
56 layout = node.graph.graph['layout']
57 assert len(layout) == 4
58 if len(node.in_nodes()) == 2:
59 src_shape = node.in_node(0).shape
60 dst_shape = node.in_node(1).value
61 if src_shape is None or dst_shape is None or len(src_shape) != 4 or len(dst_shape) != 2:
63 'Node {} with op {} cannot be converted to Resample layer because there is no enough info about '
64 'src/dst shapes: src_shape = {}, dst_shape = {}'.format(node.name, node.op, src_shape, dst_shape))
65 node.type = None # prevent translation to a valid IE layer
67 in_height = src_shape[get_height_dim(layout, 4)]
68 in_width = src_shape[get_width_dim(layout, 4)]
69 out_height = dst_shape[0]
70 out_width = dst_shape[1]
72 node.factor = factor_update(
74 [float(out_height) / in_height, float(out_width) / in_width],
75 [in_height, in_width],
76 [out_height, out_width],
80 if node.factor is None:
81 node['width'] = out_width
82 node['height'] = out_height
84 node.out_node().shape = shape_for_layout(layout,
85 batch=src_shape[get_batch_dim(layout, 4)],
86 features=src_shape[get_features_dim(layout, 4)],
89 node.graph.remove_edge(node.in_node(1).id, node.id)
91 outn = node.out_node(0)
93 in_shape = node.in_node(0)
94 num_ = in_shape.shape[get_batch_dim(layout, 4)]
95 channels_ = in_shape.shape[get_features_dim(layout, 4)]
96 height_in_ = in_shape.shape[get_height_dim(layout, 4)]
97 width_in_ = in_shape.shape[get_width_dim(layout, 4)]
99 height_out_ = height_in_ + node.pad_beg + node.pad_end
100 width_out_ = width_in_ + node.pad_beg + node.pad_end
102 if node.shrink_factor != 1 and node.zoom_factor == 1:
103 shrink_factor = node.shrink_factor
104 if shrink_factor < 1:
105 log.error('Shrink factor should be positive in node {}'.format(node.id))
107 height_out_ = (height_out_ - 1) / shrink_factor + 1
108 width_out_ = (width_out_ - 1) / shrink_factor + 1
109 elif node.shrink_factor == 1 and node.zoom_factor != 1:
110 zoom_factor = node.zoom_factor
112 log.error('Zoom factor should be positive in node {}'.format(node.id))
115 node['debug_message'] = 'Interp layer shape inference function may be wrong, please, try to update ' \
116 'layer shape inference function in the file (extensions/ops/interp.op at the ' \
117 'line {}).'.format(inspect.currentframe().f_lineno) + refer_to_faq_msg(100)
118 # Reshape methods can be different in some cases
119 # Commented out section represents reshape that used in deeplab-caffe
120 # Uncomment the following lines, if your model was trained with deeplab-caffe
121 # or have the same reshape method
122 # height_out_ = height_out_ + (height_out_ - 1) * (zoom_factor - 1)
123 # width_out_ = width_out_ + (width_out_ - 1) * (zoom_factor - 1)
125 # Comment out the following lines if you use the reshape method from previous section
126 height_out_ = height_out_ * zoom_factor
127 width_out_ = width_out_ * zoom_factor
128 elif node.width != 0 and node.height != 0:
129 height_out_ = node.height
130 width_out_ = node.width
131 elif node.shrink_factor != 1 and node.zoom_factor != 1:
132 shrink_factor = node.shrink_factor
133 zoom_factor = node.zoom_factor
134 if shrink_factor < 1:
135 log.error('Shrink factor should be positive in node {}'.format(node.id))
138 log.error('Zoom factor should be positive in node {}'.format(node.id))
140 height_out_ = (height_out_ - 1) / shrink_factor + 1
141 width_out_ = (width_out_ - 1) / shrink_factor + 1
142 height_out_ = height_out_ + (height_out_ - 1) * (zoom_factor - 1)
143 width_out_ = width_out_ + (width_out_ - 1) * (zoom_factor - 1)
145 outn.shape = shape_for_layout(layout,