Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / ObjectDetectionAPI.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from math import sqrt
19
20 import numpy as np
21
22 from extensions.back.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
23     mark_input_as_in_correct_layout, \
24     mark_output_as_in_correct_layout
25 from extensions.front.Pack import Pack
26 from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
27 from extensions.front.div import Div
28 from extensions.front.standalone_const_eraser import StandaloneConstEraser
29 from extensions.front.sub import Sub
30 from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
31 from extensions.front.tf.Unpack import Unpack
32 from extensions.ops.DetectionOutput import DetectionOutput
33 from extensions.ops.ReduceOps import ReduceMean
34 from extensions.ops.activation_ops import Sigmoid
35 from extensions.ops.elementwise import Mul
36 from extensions.ops.parameter import Parameter
37 from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
38 from extensions.ops.proposal import ProposalOp
39 from extensions.ops.psroipooling import PSROIPoolingOp
40 from extensions.ops.transpose import Transpose
41 from mo.front.common.layout import get_batch_dim, get_height_dim, get_width_dim
42 from mo.front.common.partial_infer.utils import int64_array
43 from mo.front.common.weights import swap_weights_xy
44 from mo.front.extractor import output_user_data_repack, add_output_ops
45 from mo.front.subgraph_matcher import SubgraphMatch
46 from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
47     squeeze_reshape_and_concat, add_fake_background_loc, create_op_node_with_second_input
48 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
49 from mo.graph.graph import Graph, Node
50 from mo.ops.concat import Concat
51 from mo.ops.const import Const
52 from mo.ops.crop import Crop
53 from mo.ops.op import PermuteAttrs
54 from mo.ops.result import Result
55 from mo.ops.reshape import Reshape
56 from mo.ops.roipooling import ROIPooling
57 from mo.ops.shape import Shape
58 from mo.ops.softmax import Softmax
59 from mo.utils.error import Error
60 from mo.utils.graph import backward_bfs_for_operation, bfs_search
61 from mo.utils.pipeline_config import PipelineConfig
62
# Error text reported by the replacers below when the "--tensorflow_object_detection_api_pipeline_config" command line
# parameter is not specified.
missing_param_error = (
    'To convert the model specify path to the pipeline configuration file which was used to generate the model. '
    'Please use "--tensorflow_object_detection_api_pipeline_config" option:\n'
    '--tensorflow_object_detection_api_pipeline_config "<path_to_pipeline.config>"\n'
    'If you have downloaded the model file from the Object Detection Model zoo repository then this file is '
    'located in the archive with frozen model and called "pipeline.config".\n'
    'If you did not use this command line parameter before that means that you are using currently deprecated '
    'TensorFlow* Object Detection API models conversion mechanism.'
)
70
71
def _value_or_raise(match: SubgraphMatch, pipeline_config: PipelineConfig, key: str):
    """
    Looks up the value associated with the key 'key'. The 'custom_attributes' of the 'match' object take precedence
    over the values stored in the pipeline_config. If neither of them provides the value then an error is raised.
    :param match: SubgraphMatch object containing 'custom_attributes'. A falsy value (for example, None) makes the
    function use the pipeline_config only.
    :param pipeline_config: PipelineConfig object with parsed values.
    :param key: key to search for.
    :return: the requested value.
    """
    if match:
        custom_attributes = match.custom_replacement_desc.custom_attributes
        if key in custom_attributes:
            return custom_attributes[key]

    config_value = pipeline_config.get_param(key)
    if config_value is not None:
        return config_value

    raise Error('The sub-graph replacer "[REPLACEMENT_ID]" was not able to find the value for key "{}" in the '
                'pipeline configuration file specified with the --tensorflow_object_detection_api_pipeline_config '
                'command line parameter. Update the sub-graph replacement configuration file specified with the '
                '--tensorflow_use_custom_operations_config command line parameter by adding key "{}" with required '
                'value to the "custom_attributes" dictionary of the "[REPLACEMENT_ID]" replacer.'.format(key, key))
91
92
def _find_ssd_head_node(graph: Graph, ssd_head_index: int, head_type: str):
    """
    Looks for the SSD head node with index 'ssd_head_index' among the node names known for TF OD API SSD topologies.
    Two naming schemes are checked: "BoxPredictor_<index>/..." and "WeightSharedConvolutionalBoxPredictor[_<index>]/..."
    (the latter has no index suffix for the head with index 0).
    :param graph: graph with the topology.
    :param ssd_head_index: index of the SSD head.
    :param head_type: either 'box' or 'class' string specifying type of the SSD head node.
    :return: the requested Node or None if node is not found.
    """
    # predictor scope names for the regular and the weight shared naming schemes respectively
    predictor_scopes = {'box': ('BoxEncodingPredictor', 'BoxPredictor'),
                        'class': ('ClassPredictor', 'ClassPredictor')}
    if head_type not in predictor_scopes:
        raise Error('SSD heads can be of type "box" and "class" only.')
    regular_scope, shared_scope = predictor_scopes[head_type]

    if ssd_head_index == 0:
        shared_prefix = 'WeightSharedConvolutionalBoxPredictor'
    else:
        shared_prefix = 'WeightSharedConvolutionalBoxPredictor_%d' % ssd_head_index
    candidate_names = ['BoxPredictor_%d/' % ssd_head_index + regular_scope + '/BiasAdd',
                       shared_prefix + '/' + shared_scope + '/BiasAdd']

    existing_names = [name for name in candidate_names if name in graph.nodes()]
    assert len(existing_names) <= 1  # only one of the possible node names should exist in the graph
    if not existing_names:
        return None
    return Node(graph, existing_names[0])
119
120
def _variance_from_pipeline_config(pipeline_config: PipelineConfig):
    """
    Builds a numpy array with variances using the 'frcnn_variance_*' values of the pipeline_config object. Each
    element is the reciprocal of the corresponding value from the config. The order of the elements is the
    following: variance x, variance y, variance box width, variance box height.
    :param pipeline_config: pipeline_config object to get variances from.
    :return: the numpy array with variances.
    """
    variance_keys = ('frcnn_variance_x', 'frcnn_variance_y', 'frcnn_variance_width', 'frcnn_variance_height')
    config_values = np.array([pipeline_config.get_param(variance_key) for variance_key in variance_keys])
    return 1.0 / config_values
132
133
def _skip_node_of_type(node: Node, node_ops_to_skip: list):
    """
    Walks forward from the node 'node' while the current node has exactly one output edge and its 'op' is one of the
    'node_ops_to_skip'.
    :param node: node to start skipping from.
    :param node_ops_to_skip: list with 'op' values of the nodes which should be skipped.
    :return: the first node which has a number of output edges different from one or 'op' not from the list.
    """
    current = node
    while True:
        # the 'op' attribute is checked only for nodes with a single output edge
        if len(current.out_edges()) != 1 or current.op not in node_ops_to_skip:
            return current
        current = current.out_node()
144
145
def _relax_reshape_nodes(graph: Graph, pipeline_config: PipelineConfig):
    """
    Replaces the 'Reshape' operations following the SSD head nodes, which have hard-coded output dimensions, with new
    'Reshape' operations where one of the dimension sizes is equal to -1. This function is used to make TF OD API SSD
    models reshape-able.
    The number of SSD heads is taken from the 'ssd_anchor_generator_num_layers' parameter or, if it is not defined,
    calculated from the min/max levels of the multiscale anchor generator.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: None
    """
    def replace_head_reshape(head_index: int, head_type: str, target_shape: list):
        # substitutes the Reshape after the SSD head of the given type with a Reshape to the 'target_shape'
        head_node = _find_ssd_head_node(graph, head_index, head_type)
        assert (head_node is not None)
        hard_coded_reshape = _skip_node_of_type(head_node.out_node(), ['Identity', 'FakeQuantWithMinMaxVars'])
        assert hard_coded_reshape.op == 'Reshape'
        shape_const_node = Const(graph, {'value': int64_array(target_shape)}).create_node([])
        relaxed_reshape_op = Reshape(graph, {'name': head_node.id + '/Reshape'})
        hard_coded_reshape.replace_node(relaxed_reshape_op.create_node([head_node, shape_const_node]))

    num_classes = pipeline_config.get_param('num_classes')
    num_layers = pipeline_config.get_param('ssd_anchor_generator_num_layers')
    if num_layers is None:
        max_level = pipeline_config.get_param('multiscale_anchor_generator_max_level')
        min_level = pipeline_config.get_param('multiscale_anchor_generator_min_level')
        num_layers = max_level - min_level + 1

    for head_index in range(num_layers):
        replace_head_reshape(head_index, 'box', [0, -1, 1, 4])
        # fix hard-coded value for the number of items in tensor produced by the convolution to make topology
        # reshapable
        replace_head_reshape(head_index, 'class', [0, -1, num_classes + 1])
179
180
def _create_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
    """
    Creates one PriorBoxClustered node per SSD head using the 'ssd_anchor_generator_*' and 'resizer_image_*' values
    from the pipeline configuration file. Each PriorBoxClustered node gets input data from the corresponding SSD 'box'
    head and from the consumer of the 'image_tensor' placeholder (just to get input image size). If there are several
    PriorBoxClustered nodes then they are joined with a Concat node.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: node generating prior boxes.
    """
    get_param = pipeline_config.get_param

    min_scale = get_param('ssd_anchor_generator_min_scale')
    max_scale = get_param('ssd_anchor_generator_max_scale')
    num_layers = get_param('ssd_anchor_generator_num_layers')
    aspect_ratios = get_param('ssd_anchor_generator_aspect_ratios')
    if not isinstance(aspect_ratios, list):
        aspect_ratios = [aspect_ratios]

    # prior boxes have to be generated using the image size used for training
    image_height = get_param('resizer_image_height')
    image_width = get_param('resizer_image_width')
    shortest_side = min(image_height, image_width)
    base_anchor_height = get_param('ssd_anchor_generator_base_anchor_height')
    base_anchor_width = get_param('ssd_anchor_generator_base_anchor_width')
    height_factor = shortest_side / image_height * base_anchor_height
    width_factor = shortest_side / image_width * base_anchor_width

    # the boxes of the lowest layer are reduced unless the configuration file defines the opposite
    reduce_boxes_in_lowest_layer = get_param('ssd_anchor_generator_reduce_lowest')
    if reduce_boxes_in_lowest_layer is None:
        reduce_boxes_in_lowest_layer = True

    # one extra scale equal to 1.0 is appended to be used for the interpolated box of the last layer
    explicit_scales = get_param('ssd_anchor_generator_scales')
    if explicit_scales is None:
        scales = [min_scale + (max_scale - min_scale) * i / (num_layers - 1) for i in range(num_layers)] + [1.0]
    else:
        scales = explicit_scales + [1.0]

    prior_boxes = []
    for head_index in range(num_layers):
        head_node = _find_ssd_head_node(graph, head_index, 'box')
        assert (head_node is not None)

        if head_index == 0 and reduce_boxes_in_lowest_layer:
            relative_widths = [0.1, min_scale * sqrt(2.0), min_scale * sqrt(0.5)]
            relative_heights = [0.1, min_scale / sqrt(2.0), min_scale / sqrt(0.5)]
        else:
            layer_scale = scales[head_index]
            relative_widths = [layer_scale * sqrt(ratio) for ratio in aspect_ratios]
            relative_heights = [layer_scale / sqrt(ratio) for ratio in aspect_ratios]

            interpolated_scale_ar = get_param('ssd_anchor_generator_interpolated_scale_aspect_ratio')
            if interpolated_scale_ar > 0.0:
                # additional box with the scale equal to geometric mean of the current and the next layer scales
                interpolated_scale = sqrt(layer_scale * scales[head_index + 1])
                relative_widths.append(interpolated_scale * interpolated_scale_ar)
                relative_heights.append(interpolated_scale / interpolated_scale_ar)

        prior_box_attrs = {
            'width': np.array([w * image_width * width_factor for w in relative_widths]),
            'height': np.array([h * image_height * height_factor for h in relative_heights]),
            'clip': 0,
            'flip': 0,
            'variance': _variance_from_pipeline_config(pipeline_config),
            'offset': 0.5,
        }
        prior_box_op = PriorBoxClusteredOp(graph, prior_box_attrs)
        # connect the PriorBoxClustered node with the "Cast" node of the Placeholder node because the pass that removes
        # Cast operations is executed in the middle phase and it will fail when there are several consumers of the
        # Placeholder
        image_consumer_node = Node(graph, 'image_tensor').out_node(0)
        prior_boxes.append(prior_box_op.create_node([head_node, image_consumer_node],
                                                    {'name': 'PriorBoxClustered_{}'.format(head_index)}))

    if len(prior_boxes) == 1:
        return prior_boxes[0]
    concat_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_boxes)})
    return concat_op.create_node(prior_boxes, {'name': 'ConcatPriorBoxesClustered'})
248
249
def _create_multiscale_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
    """
    Creates one PriorBoxClustered node per level of the multiscale anchor generator using the
    'multiscale_anchor_generator_*' values from the pipeline configuration file. Each PriorBoxClustered node gets input
    data from the corresponding SSD 'box' head and from the consumer of the 'image_tensor' placeholder (just to get
    input image size). If there are several PriorBoxClustered nodes then they are joined with a Concat node.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: node generating prior boxes.
    """
    min_level = pipeline_config.get_param('multiscale_anchor_generator_min_level')
    max_level = pipeline_config.get_param('multiscale_anchor_generator_max_level')
    anchor_scale = pipeline_config.get_param('multiscale_anchor_generator_anchor_scale')
    aspect_ratios = pipeline_config.get_param('multiscale_anchor_generator_aspect_ratios')
    scales_per_octave = pipeline_config.get_param('multiscale_anchor_generator_scales_per_octave')

    octave_scales = [2 ** (float(octave_step) / scales_per_octave) for octave_step in range(scales_per_octave)]

    prior_boxes = []
    # the SSD head index is the offset of the level from the minimal one
    for head_index, level in enumerate(range(min_level, max_level + 1)):
        base_anchor_size = 2 ** level * anchor_scale

        head_node = _find_ssd_head_node(graph, head_index, 'box')
        assert (head_node is not None)

        # boxes are ordered by the aspect ratio first and by the scale inside the octave second
        widths = []
        heights = []
        for ratio in aspect_ratios:
            ratio_root = sqrt(ratio)
            for octave_scale in octave_scales:
                widths.append(base_anchor_size * octave_scale * ratio_root)
                heights.append(base_anchor_size * octave_scale / ratio_root)

        prior_box_attrs = {
            'width': np.array(widths),
            'height': np.array(heights),
            'clip': 0,
            'flip': 0,
            'variance': _variance_from_pipeline_config(pipeline_config),
            'offset': 0.5,
        }
        prior_box_op = PriorBoxClusteredOp(graph, prior_box_attrs)
        # connect the PriorBoxClustered node with the "Cast" node of the Placeholder node because the pass that removes
        # Cast operations is executed in the middle phase and it will fail when there are several consumers of the
        # Placeholder
        image_consumer_node = Node(graph, 'image_tensor').out_node(0)
        prior_boxes.append(prior_box_op.create_node([head_node, image_consumer_node],
                                                    {'name': 'PriorBoxClustered_{}'.format(head_index)}))

    if len(prior_boxes) == 1:
        return prior_boxes[0]
    concat_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_boxes)})
    return concat_op.create_node(prior_boxes, {'name': 'ConcatPriorBoxesClustered'})
293
294
def skip_nodes_by_condition(current_node: Node, condition: callable):
    """
    Walks backward (from a node to its input node) starting from the 'current_node' while the 'condition' is satisfied.
    :param current_node: node to start from.
    :param condition: callable taking a node and returning a truthy value if the node should be skipped.
    :return: the first node for which the 'condition' is not satisfied.
    """
    candidate = current_node
    while True:
        if not condition(candidate):
            return candidate
        candidate = candidate.in_node()
299
300
def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int, max_size: int):
    """
    Scales the spatial sizes of the image with a single ratio for both sides. The ratio is the smaller of two values:
    the one making the shorter side equal to 'min_size' and the one making the longer side equal to 'max_size'.
    The behavior of this function is equivalent to the output shape calculation of the Preprocessor block of TensorFlow
    Object Detection API models with keep aspect ratio resizer.
    :param height: input height.
    :param width: input width.
    :param min_size: size limit for the shorter side.
    :param max_size: size limit for the longer side.
    :return: the tuple with scaled image height, width (rounded to integers).
    """
    shorter_side = min(height, width)
    ratio_for_shorter = min_size / shorter_side
    longer_side = max(height, width)
    ratio_for_longer = max_size / longer_side
    scale = ratio_for_longer if ratio_for_longer < ratio_for_shorter else ratio_for_shorter
    scaled_height = int(round(height * scale))
    scaled_width = int(round(width * scale))
    return scaled_height, scaled_width
316
317
def calculate_placeholder_spatial_shape(graph: Graph, match: SubgraphMatch, pipeline_config: PipelineConfig):
    """
    The function calculates the preprocessed shape of the input image for a TensorFlow Object Detection API model.
    It uses various sources to calculate it:
    1. The shape passed using the '--input_shape' command line parameter.
    2. The values from the pipeline configuration file describing Preprocessor block of the topology:
        a. If the fixed size resizer is used then size passed via '--input_shape' can override them, but Model Optimizer
           prints warning. If the '--input_shape' is not defined then use values from the pipeline configuration file.
        b. If the keep aspect ratio resizer is used then scale the size passed via '--input_shape' using the provided
           limits. If the '--input_shape' is not defined then use shape as (min_dimension_size, min_dimension_size)
           defined in the pipeline configuration file.
    :param graph: graph with the topology. The 'user_shapes' graph attribute is read to get the '--input_shape' value.
    :param match: the object containing matching sub-graph and custom attributes from the sub-graph replacement file.
    :param pipeline_config: the object contain information from the pipeline configuration file.
    :return: tuple (height, width) of the placeholder shape.
    :raises Error: if the shape cannot be determined from the sources listed above.
    """
    height = None
    width = None
    user_shapes = graph.graph['user_shapes']

    custom_attributes = match.custom_replacement_desc.custom_attributes
    if 'preprocessed_image_height' in custom_attributes or 'preprocessed_image_width' in custom_attributes:
        log.error('The "preprocessed_image_height" or "preprocessed_image_width" is specified in the sub-graph '
                  'replacement configuration file but they are ignored. Please, specify desired input shape using the '
                  '"--input_shape" command line parameter.', extra={'is_warning': True})

    # the user defined shape is indexed as [batch, height, width, channels]
    user_defined_height = None
    user_defined_width = None
    if user_shapes and 'image_tensor' in user_shapes and user_shapes['image_tensor']:
        user_defined_shape = user_shapes['image_tensor'][0]['shape']
        if user_defined_shape is not None:
            user_defined_height = user_defined_shape[1]
            user_defined_width = user_defined_shape[2]

    resizer_height = pipeline_config.get_param('resizer_image_height')
    resizer_width = pipeline_config.get_param('resizer_image_width')
    if resizer_height and resizer_width:
        log.debug('The model resizes image to a fixed shape: ({}, {})'.format(resizer_height, resizer_width))

    resizer_min_dimension = pipeline_config.get_param('resizer_min_dimension')
    resizer_max_dimension = pipeline_config.get_param('resizer_max_dimension')
    if resizer_min_dimension and resizer_max_dimension:
        log.debug('The model resizes image using keep aspect ratio with minimum size {}, maximum size {}.'.format(
            resizer_min_dimension, resizer_max_dimension))

    # if the model is created with an input image resizer to a fixed shape
    if resizer_width and resizer_height:
        if user_defined_height and user_defined_width:
            # warn if any of the spatial dimensions differs from the one expected by the model (both the height and
            # the width must be checked)
            if user_defined_height != resizer_height or user_defined_width != resizer_width:
                log.error('The model expects that the input image is resized to a fixed shape ({}, {}), but the shape '
                          'provided with the "--input_shape" command line parameter is different ({}, {}).'.format(
                    resizer_height, resizer_width, user_defined_height, user_defined_width), extra={'is_warning': True})
            height = user_defined_height
            width = user_defined_width
        else:
            height = resizer_height
            width = resizer_width

    # if the model is created with an input image resizer keeping aspect ratio
    if resizer_min_dimension and resizer_max_dimension:
        print('[ WARNING ] Model Optimizer removes pre-processing block of the model which resizes image keeping '
              'aspect ratio. The Inference Engine does not support dynamic image size so the Intermediate '
              'Representation file is generated with the input image size of a fixed size.')
        if user_defined_height and user_defined_width:
            scaled_height, scaled_width = calculate_shape_keeping_aspect_ratio(user_defined_height,
                                                                               user_defined_width,
                                                                               resizer_min_dimension,
                                                                               resizer_max_dimension)
            # warn if the resizer changes any of the spatial dimensions provided by the user
            if scaled_height != user_defined_height or scaled_width != user_defined_width:
                log.error('The model resizes the input image keeping aspect ratio with min dimension {}, max '
                          'dimension {}. The provided input height {}, width {} is transformed to height {}, width '
                          '{}.'.format(resizer_min_dimension, resizer_max_dimension, user_defined_height,
                                       user_defined_width, scaled_height, scaled_width), extra={'is_warning': True})
            height = scaled_height
            width = scaled_width
        else:
            height = width = resizer_min_dimension
            print('Specify the "--input_shape" command line parameter to override the default shape which is equal to '
                  '({}, {}).'.format(height, width))

    if height is None or width is None:
        raise Error('Failed to determine the placeholder shape.')
    return height, width
401
402
class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    The replacer removes the "Preprocessor" block which resizes the input image and applies mean/scale values. Only
    nodes related to applying mean/scaling values are kept and connected to the consumer of the 'image_tensor'
    placeholder. The spatial sizes of the placeholder are updated and saved to the graph attributes
    'preprocessed_image_height' and 'preprocessed_image_width'.
    """
    replacement_id = 'ObjectDetectionAPIPreprocessorReplacement'

    def run_before(self):
        return [Pack, Sub, TransposeOrderNormalizer]

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        Returns names of the matched nodes except the nodes performing scaling and mean value subtraction.
        :param graph: graph with the topology.
        :param match: the object containing matching sub-graph.
        :return: the list with node names to remove (the list returned by the 'match' object is updated in place).
        """
        removal_candidates = match.matched_nodes_names()
        # do not remove nodes that perform input image scaling and mean value subtraction
        preserved_names = ('Preprocessor/sub', 'Preprocessor/sub/y', 'Preprocessor/mul', 'Preprocessor/mul/x')
        for preserved_name in preserved_names:
            if preserved_name in removal_candidates:
                removal_candidates.remove(preserved_name)
        return removal_candidates

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Updates the shape of the 'image_tensor' placeholder and connects its consumer ("Cast" operation) directly with
        the node performing scaling (if it exists) or mean value subtraction.
        :param graph: graph with the topology.
        :param match: the object containing matching sub-graph and custom attributes from the sub-graph replacement
        file.
        :return: empty dictionary because no new nodes are created.
        """
        cmd_params = graph.graph['cmd_params']
        layout = graph.graph['layout']
        pipeline_config_path = cmd_params.tensorflow_object_detection_api_pipeline_config
        if pipeline_config_path is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(pipeline_config_path)

        sub_node = match.output_node(0)[0]
        if not (sub_node.has('op') and sub_node.op == 'Sub'):
            raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
                        'not created with TensorFlow Object Detection API.')

        # the scaling node is optional and, if exists, produces the 0-th input of the "Sub" node
        scale_candidate = sub_node.in_node(0)
        mul_node = None
        if scale_candidate.has('op') and scale_candidate.op == 'Mul':
            log.info('There is image scaling node in the Preprocessor block.')
            mul_node = scale_candidate

        input_name = 'image_tensor'
        if input_name not in graph.nodes():
            raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
                        '"--input" command line parameter.'.format(input_name))
        placeholder_node = Node(graph, input_name)

        # set default value of the batch size to 1 if user didn't specify batch size and input shape
        batch_dim = get_batch_dim(layout, 4)
        if cmd_params.batch is None and placeholder_node.shape[batch_dim] == -1:
            placeholder_node.shape[batch_dim] = 1
        height, width = calculate_placeholder_spatial_shape(graph, match, pipeline_config)
        height_dim = get_height_dim(layout, 4)
        width_dim = get_width_dim(layout, 4)
        placeholder_node.shape[height_dim] = height
        placeholder_node.shape[width_dim] = width

        # save the pre-processed image spatial sizes to be used in the other replacers
        graph.graph['preprocessed_image_height'] = placeholder_node.shape[height_dim]
        graph.graph['preprocessed_image_width'] = placeholder_node.shape[width_dim]

        to_float_node = placeholder_node.out_node(0)
        if not (to_float_node.has('op') and to_float_node.op == 'Cast'):
            raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
                input_name))

        # connect to_float_node directly with node performing scale or mean value subtraction
        if mul_node is not None:
            consumer_node, consumer_in_port = mul_node, 1
        else:
            consumer_node, consumer_in_port = sub_node, 0
        graph.create_edge(to_float_node, consumer_node, 0, consumer_in_port)

        print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
              ' applicable) are kept.')
        return {}
470
471
472 class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFileSubGraph):
473     """
474     Replaces the sub-graph that is equal to the DetectionOutput layer from Inference Engine. This replacer is used for
475     Faster R-CNN, R-FCN and Mask R-CNN topologies conversion.
476     The replacer uses a value of the custom attribute 'coordinates_swap_method' from the sub-graph replacement
477     configuration file to choose how to swap box coordinates of the 0-th input of the generated DetectionOutput layer.
478     Refer to the code for more details.
479     """
480     replacement_id = 'ObjectDetectionAPIDetectionOutputReplacement'
481
482     def run_before(self):
483         return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement, Unpack, Sub, TransposeOrderNormalizer]
484
485     def run_after(self):
486         return [ObjectDetectionAPIProposalReplacement, CropAndResizeReplacement]
487
488     def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
489         new_nodes_to_remove = match.matched_nodes_names().copy()
490         outputs = ['detection_boxes', 'detection_scores', 'num_detections']
491         for output in outputs:
492             children = Node(graph, output).out_nodes()
493             if len(children) != 1:
494                 log.warning('Output {} has {} children. It should have only one output: with op==`Result`'
495                             ''.format(output, len(children)))
496             elif children[list(children.keys())[0]].op == 'Result':
497                 new_nodes_to_remove.append(children[list(children.keys())[0]].id)
498             else:
499                 continue
500         new_nodes_to_remove.extend(outputs)
501         return new_nodes_to_remove
502
503     def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
504         # the DetectionOutput in IE produces single tensor, but in TF it produces four tensors, so we need to create
505         # only one output edge match
506         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
507
508     def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
509         argv = graph.graph['cmd_params']
510         if argv.tensorflow_object_detection_api_pipeline_config is None:
511             raise Error(missing_param_error)
512         pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
513
514         num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
515         max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
516         activation_function = _value_or_raise(match, pipeline_config, 'postprocessing_score_converter')
517
518         activation_conf_node = add_activation_function_after_node(graph, match.single_input_node(1)[0].in_node(0),
519                                                                   activation_function)
520
521         # IE DetectionOutput layer consumes flattened tensors so need add a Reshape layer.
522         # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
523         # "0" value in the Reshape layer attribute to refer to the batch size, but we know how to
524         # calculate the second dimension so the batch value will be deduced from it with help of "-1".
525         reshape_conf_node = create_op_node_with_second_input(graph, Reshape,
526                                                              int64_array([-1, (num_classes + 1) * max_proposals]),
527                                                              dict(name='do_reshape_conf'), activation_conf_node)
528
529         mark_as_correct_data_layout(reshape_conf_node)
530
531         # Workaround for TransposeForReshape pass.
532         # We looking for first not Reshape-typed node before match.single_input_node(0)[0].in_node(0).
533         # And add  reshape_loc node after this first not Reshape-typed node.
534         current_node = skip_nodes_by_condition(match.single_input_node(0)[0].in_node(0),
535                                                lambda x: x['kind'] == 'op' and x.has_and_set('reinterp_shape'))
536
537         reshape_loc_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, num_classes, 1, 4]),
538                                                             dict(name='reshape_loc'), current_node)
539         mark_as_correct_data_layout(reshape_loc_node)
540
541         # constant node with variances
542         variances_const_op = Const(graph, dict(value=_variance_from_pipeline_config(pipeline_config)))
543         variances_const_node = variances_const_op.create_node([])
544
545         # TF produces locations tensor without boxes for background.
546         # Inference Engine DetectionOutput layer requires background boxes so we generate them
547         loc_node = add_fake_background_loc(graph, reshape_loc_node)
548         PermuteAttrs.set_permutation(reshape_loc_node, loc_node, None)
549
550         # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
551         reshape_loc_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4]),
552                                                                dict(name='reshape_locs_2d'), loc_node)
553         mark_as_correct_data_layout(reshape_loc_2d_node)
554
555         # element-wise multiply locations with variances
556         eltwise_locs_op = Mul(graph, dict())
557         eltwise_locs_node = eltwise_locs_op.create_node([reshape_loc_2d_node, variances_const_node],
558                                                         dict(name='scale_locs'))
559
560         # IE DetectionOutput layer consumes flattened tensors so need add a Reshape layer.
561         # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
562         # "0" value in the Reshape layer attribute to refer to the batch size, but we know how to
563         # calculate the second dimension so the batch value will be deduced from it with help of "-1".
564         reshape_loc_do_op = Reshape(graph, dict(name='do_reshape_locs'))
565
566         custom_attributes = match.custom_replacement_desc.custom_attributes
567         coordinates_swap_method = 'add_convolution'
568         if 'coordinates_swap_method' not in custom_attributes:
569             log.error('The ObjectDetectionAPIDetectionOutputReplacement sub-graph replacement configuration file '
570                       'must contain "coordinates_swap_method" in the "custom_attributes" dictionary. Two values are '
571                       'supported: "swap_weights" and "add_convolution". The first one should be used when there is '
572                       'a MatMul or Conv2D node before the "SecondStagePostprocessor" block in the topology. With this '
573                       'solution the weights of the MatMul or Conv2D nodes are permutted, simulating the swap of XY '
574                       'coordinates in the tensor. The second could be used in any other cases but it is worse in terms '
575                       'of performance because it adds the Conv2D node which performs permutting of data. Since the '
576                       'attribute is not defined the second approach is used by default.')
577         else:
578             coordinates_swap_method = custom_attributes['coordinates_swap_method']
579         supported_swap_methods = ['swap_weights', 'add_convolution']
580         if coordinates_swap_method not in supported_swap_methods:
581             raise Error('Unsupported "coordinates_swap_method" defined in the sub-graph replacement configuration '
582                         'file. Supported methods are: {}'.format(', '.join(supported_swap_methods)))
583
584         if coordinates_swap_method == 'add_convolution':
585             swapped_locs_node = add_convolution_to_swap_xy_coordinates(graph, eltwise_locs_node, 4)
586             reshape_loc_do_node = reshape_loc_do_op.create_node([swapped_locs_node])
587         else:
588             reshape_loc_do_node = reshape_loc_do_op.create_node([eltwise_locs_node])
589
590         reshape_loc_do_dims = Const(graph, {'value': int64_array([-1, (num_classes + 1) * max_proposals * 4]),
591                                             'name': reshape_loc_do_node.name + '/Dim'}).create_node()
592         reshape_loc_do_dims.out_port(0).connect(reshape_loc_do_node.in_port(1))
593
594         mark_as_correct_data_layout(reshape_loc_do_node)
595
596         # find Proposal output which has the data layout as in TF: YXYX coordinates without batch indices.
597         proposal_nodes_ids = [node_id for node_id, attrs in graph.nodes(data=True)
598                               if 'name' in attrs and attrs['name'] == 'crop_proposals']
599         if len(proposal_nodes_ids) != 1:
600             raise Error("Found the following nodes '{}' with name 'crop_proposals' but there should be exactly 1. "
601                         "Looks like ObjectDetectionAPIProposalReplacement replacement didn't work.".
602                         format(proposal_nodes_ids))
603         proposal_node = Node(graph, proposal_nodes_ids[0])
604
605         # check whether it is necessary to permute proposals coordinates before passing them to the DetectionOutput
606         # currently this parameter is set for the RFCN topologies
607         if 'swap_proposals' in custom_attributes and custom_attributes['swap_proposals']:
608             proposal_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 4)
609
610         # reshape priors boxes as Detection Output expects
611         reshape_priors_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 1, max_proposals * 4]),
612                                                                dict(name='DetectionOutput_reshape_priors_'),
613                                                                proposal_node)
614         mark_as_correct_data_layout(reshape_priors_node)
615
616         detection_output_op = DetectionOutput(graph, {})
617         if coordinates_swap_method == 'swap_weights':
618             # update infer function to re-pack weights
619             detection_output_op.attrs['old_infer'] = detection_output_op.attrs['infer']
620             detection_output_op.attrs['infer'] = __class__.do_infer
621         for key in ('clip_before_nms', 'clip_after_nms'):
622             if key in match.custom_replacement_desc.custom_attributes:
623                 detection_output_op.attrs[key] = int(match.custom_replacement_desc.custom_attributes[key])
624
625         detection_output_node = detection_output_op.create_node(
626             [reshape_loc_do_node, reshape_conf_node, reshape_priors_node],
627             dict(name=detection_output_op.attrs['type'], share_location=0, variance_encoded_in_target=1,
628                  code_type='caffe.PriorBoxParameter.CENTER_SIZE', pad_mode='caffe.ResizeParameter.CONSTANT',
629                  resize_mode='caffe.ResizeParameter.WARP',
630                  num_classes=num_classes,
631                  confidence_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_score_threshold'),
632                  top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_detections_per_class'),
633                  keep_top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_total_detections'),
634                  nms_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_iou_threshold')))
635         # sets specific name to the node so we can find it in other replacers
636         detection_output_node.name = 'detection_output'
637
638         output_op = Result(graph, dict(name='do_OutputOp'))
639         output_op.create_node([detection_output_node])
640
641         print('The graph output nodes "num_detections", "detection_boxes", "detection_classes", "detection_scores" '
642               'have been replaced with a single layer of type "Detection Output". Refer to IR catalogue in the '
643               'documentation for information about this layer.')
644
645         return {'detection_output_node': detection_output_node}
646
647     @staticmethod
648     def do_infer(node):
649         node.old_infer(node)
650         # compared to the IE's DetectionOutput, the TF keeps the locations in YXYX, need to get back to the XYXY
651         # for last matmul/Conv2D that operate the locations need to swap the X and Y for output feature weights & biases
652         swap_weights_xy(backward_bfs_for_operation(node.in_node(0), ['MatMul', 'Conv2D']))
653
654
class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    Replaces the matched sub-graph with a ROIPooling node which takes boxes from the node named 'detection_output'
    and additionally exposes the reshaped detections as an output of the topology.
    """
    replacement_id = 'ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement'

    def run_after(self):
        """Lists the transformations this one is scheduled after."""
        predecessors = [ObjectDetectionAPIProposalReplacement]
        return predecessors

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        """Maps the first output of the matched sub-graph to the created ROIPooling node."""
        matched_output_id = match.output_node(0)[0].id
        roi_pooling_id = new_sub_graph['roi_pooling_node'].id
        return {matched_output_id: roi_pooling_id}

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Creates the ROIPooling node consuming batch indices and box coordinates cut from the DetectionOutput node.

        :param graph: the graph to operate on. The 'cmd_params' stored in it must refer to the pipeline config file.
        :param match: the matched sub-graph.
        :return: dictionary with the created ROIPooling node stored under the 'roi_pooling_node' key.
        :raises Error: if the pipeline config is not specified or the graph does not contain exactly one node named
            'detection_output'.
        """
        cmd_params = graph.graph['cmd_params']
        config_path = cmd_params.tensorflow_object_detection_api_pipeline_config
        if config_path is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(config_path)
        roi_pool_size = _value_or_raise(match, pipeline_config, 'initial_crop_size')

        detection_output_nodes_ids = []
        for node_id, attrs in graph.nodes(data=True):
            if 'name' in attrs and attrs['name'] == 'detection_output':
                detection_output_nodes_ids.append(node_id)
        if len(detection_output_nodes_ids) != 1:
            raise Error("Found the following nodes '{}' with 'detection_output' but there should be exactly 1.".
                        format(detection_output_nodes_ids))
        detection_output_node = Node(graph, detection_output_nodes_ids[0])

        # 2D view of the DetectionOutput result so that it can be an output of the topology
        detections_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 7]),
                                                              {'name': 'reshape_do_2d'}, detection_output_node)
        mark_as_correct_data_layout(detections_2d_node)

        # the Result node is a marker of an output of the topology
        detections_result_node = Result(graph, {'name': 'do_reshaped_OutputOp'}).create_node([detections_2d_node])

        # the 'output_sort_order' attribute is used as a key to sort output nodes before generation of the IR
        detections_result_node.in_edge()['data_attrs'].append('output_sort_order')
        detections_result_node.in_edge()['output_sort_order'] = [('detection_boxes', 0)]

        # Two Crop operations take the DetectionOutput result and cut off the slices with class ids and
        # probabilities; concatenated together they give a tensor with batch ids and bounding boxes only (as it is
        # expected by the ROIPooling layer)
        batch_ids_crop_op = Crop(graph, {'axis': int64_array([3]), 'offset': int64_array([0]),
                                         'dim': int64_array([1]), 'nchw_layout': True})
        batch_ids_node = batch_ids_crop_op.create_node([detection_output_node], {'name': 'crop_do_batch_ids'})

        coordinates_crop_op = Crop(graph, {'axis': int64_array([3]), 'offset': int64_array([3]),
                                           'dim': int64_array([4]), 'nchw_layout': True})
        coordinates_node = coordinates_crop_op.create_node([detection_output_node], {'name': 'crop_do_coords'})

        batch_and_coords_node = Concat(graph, {'axis': 3}).create_node(
            [batch_ids_node, coordinates_node], {'name': 'batch_and_coords', 'nchw_layout': True})

        # bounding boxes in the form required by the ROIPooling
        rois_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                     {'name': 'reshape_do'}, batch_and_coords_node)
        mark_as_correct_data_layout(rois_node)

        roi_pooling_op = ROIPooling(graph, {'method': "bilinear", 'spatial_scale': 1, 'pooled_h': roi_pool_size,
                                            'pooled_w': roi_pool_size})
        features_node = match.single_input_node(0)[0].in_node()
        roi_pooling_node = roi_pooling_op.create_node([features_node, rois_node], {'name': 'ROI_pooling_2'})
        return {'roi_pooling_node': roi_pooling_node}
717
718
class ObjectDetectionAPIMaskRCNNSigmoidReplacement(FrontReplacementFromConfigFileGeneral):
    """
    This replacer is used to convert Mask R-CNN topologies only.
    Adds activation with sigmoid function to the end of the network producing masks tensors.
    """
    replacement_id = 'ObjectDetectionAPIMaskRCNNSigmoidReplacement'

    def run_after(self):
        """Lists the transformations this one is scheduled after."""
        predecessors = [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]
        return predecessors

    def transform_graph(self, graph: Graph, replacement_descriptions):
        """
        Appends a Sigmoid named 'masks' followed by a Result node to the output produced by a node whose name
        starts with 'SecondStageBoxPredictor'.

        :param graph: the graph to operate on.
        :param replacement_descriptions: custom replacement attributes (not used here).
        :raises Error: if more than one suitable output is found.
        """
        masks_result_node = None
        # collect the ids up front because new Result nodes are added to the graph inside the loop
        result_ids = [node_id for node_id, attrs in graph.nodes(data=True)
                      if 'op' in attrs and attrs['op'] == 'Result']
        for result_id in result_ids:
            producer = Node(graph, result_id).in_node(0)
            if not producer.name.startswith('SecondStageBoxPredictor'):
                continue

            sigmoid_node = Sigmoid(graph, {'name': producer.id + '/sigmoid'}).create_node([producer])
            sigmoid_node.name = 'masks'

            if masks_result_node is not None:
                raise Error('Identified two possible outputs from the topology. Cannot proceed.')
            # the Result node is a marker of an output of the topology
            masks_result_op = Result(graph, {'name': sigmoid_node.name + '/OutputOp'})
            masks_result_node = masks_result_op.create_node([sigmoid_node])

        print('The predicted masks are produced by the "masks" layer for each bounding box generated with a '
              '"detection_output" layer.\n Refer to IR catalogue in the documentation for information '
              'about the DetectionOutput layer and Inference Engine documentation about output data interpretation.\n'
              'The topology can be inferred using dedicated demo "mask_rcnn_demo".')
748
749
class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    This class replaces sub-graph of operations with Proposal layer and additional layers transforming
    tensors from layout of TensorFlow to layout required by Inference Engine.
    Refer to comments inside the function for more information about performed actions.
    """
    replacement_id = 'ObjectDetectionAPIProposalReplacement'

    def run_after(self):
        """Lists the transformations this one is scheduled after."""
        return [ObjectDetectionAPIPreprocessorReplacement]

    def run_before(self):
        """Lists the transformations this one is scheduled before."""
        return [Sub, CropAndResizeReplacement, TransposeOrderNormalizer]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        """Maps the first output of the matched sub-graph to the node stored under the 'proposal_node' key."""
        return {match.output_node(0)[0].id: new_sub_graph['proposal_node'].id}

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        Returns the matched nodes to be removed from the graph.

        The two input nodes of the match are excluded from the list because they produce box predictions and class
        predictions which are still used.
        """
        new_list = match.matched_nodes_names().copy()
        # do not remove nodes that produce box predictions and class predictions
        new_list.remove(match.single_input_node(0)[0].id)
        new_list.remove(match.single_input_node(1)[0].id)
        return new_list

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Creates the sub-graph with the Proposal node and the nodes adapting its inputs and outputs.

        :param graph: the graph to operate on. The 'cmd_params' stored in it must refer to the pipeline config file.
        :param match: the matched sub-graph together with the custom attributes of the replacement description.
        :return: dictionary with the output node of the new sub-graph stored under the 'proposal_node' key.
        :raises Error: if the pipeline config is not specified or no CropAndResize node is found.
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)

        custom_attributes = match.custom_replacement_desc.custom_attributes
        # evaluated once and used both for the coordinates swap and for the CropAndResize re-connection below
        do_not_swap_proposals = 'do_not_swap_proposals' in custom_attributes and \
                                bool(custom_attributes['do_not_swap_proposals'])

        max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
        proposal_ratios = _value_or_raise(match, pipeline_config, 'anchor_generator_aspect_ratios')
        proposal_scales = _value_or_raise(match, pipeline_config, 'anchor_generator_scales')
        anchors_count = len(proposal_ratios) * len(proposal_scales)

        # Convolution/matmul node that produces classes predictions
        # Transpose result of the tensor with class predictions so it will be in a correct layout for Softmax
        predictions_node = backward_bfs_for_operation(match.single_input_node(1)[0], ['Add'])[0]

        reshape_classes_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, anchors_count, 2, -1]),
                                                                dict(name='predictions/Reshape'))
        predictions_node.insert_node_after(reshape_classes_node, 0)
        mark_as_correct_data_layout(reshape_classes_node)

        softmax_conf_op = Softmax(graph, dict(axis=2, nchw_layout=True, name=reshape_classes_node.id + '/Softmax'))
        softmax_conf_node = softmax_conf_op.create_node([reshape_classes_node])

        order_const = Const(graph, dict(value=int64_array([0, 2, 1, 3]),
                                        name=softmax_conf_node.name + '/TransposeOrder')).create_node()
        permute_reshape_softmax_op = Transpose(graph, dict())
        permute_reshape_softmax_node = permute_reshape_softmax_op.create_node([softmax_conf_node, order_const], dict(
            name=softmax_conf_node.name + '/Transpose'))
        mark_input_as_in_correct_layout(permute_reshape_softmax_node, 1)
        mark_output_as_in_correct_layout(permute_reshape_softmax_node, 0)

        # the transposed tensor is reshaped back to the shape of the original predictions
        initial_shape_op = Shape(graph, dict(name=predictions_node.id + '/Shape'))
        initial_shape_node = initial_shape_op.create_node([predictions_node])

        reshape_permute_op = Reshape(graph, dict(name='Reshape_Transpose_Class'))
        reshape_permute_node = reshape_permute_op.create_node([permute_reshape_softmax_node, initial_shape_node])
        mark_input_as_in_correct_layout(reshape_permute_node, 0)
        mark_output_as_in_correct_layout(reshape_permute_node, 0)

        variance_height = pipeline_config.get_param('frcnn_variance_height')
        variance_width = pipeline_config.get_param('frcnn_variance_width')
        variance_x = pipeline_config.get_param('frcnn_variance_x')
        variance_y = pipeline_config.get_param('frcnn_variance_y')
        anchor_generator_height_stride = pipeline_config.get_param('anchor_generator_height_stride')
        anchor_generator_width_stride = pipeline_config.get_param('anchor_generator_width_stride')
        anchor_generator_height = pipeline_config.get_param('anchor_generator_height')
        anchor_generator_width = pipeline_config.get_param('anchor_generator_width')

        # The Proposal operation below is configured with a single value per pair (the height or the x one), so a
        # mismatch inside a pair is reported to the user.
        if variance_height != variance_width:
            log.error('The values for variance for height "{}" is not equal to variance for width "{}". The detection '
                      'results will be inaccurate.'.format(variance_height, variance_width))
        if variance_x != variance_y:
            log.error('The values for variance for x "{}" is not equal to variance for y "{}". The detection '
                      'results will be inaccurate.'.format(variance_x, variance_y))
        if anchor_generator_height_stride != anchor_generator_width_stride:
            log.error('The values for the anchor generator height stride "{}" is not equal to the anchor generator '
                      'width stride "{}". The detection results will be inaccurate.'.format(
                anchor_generator_height_stride, anchor_generator_width_stride))
        if anchor_generator_height != anchor_generator_width:
            log.error('The values for the anchor generator height "{}" is not equal to the anchor generator width '
                      '"{}". The detection results will be inaccurate.'.format(anchor_generator_height,
                                                                               anchor_generator_width))

        proposal_op = ProposalOp(graph, dict(min_size=1,
                                             framework='tensorflow',
                                             pre_nms_topn=2 ** 31 - 1,
                                             box_size_scale=variance_height,
                                             box_coordinate_scale=variance_x,
                                             post_nms_topn=max_proposals,
                                             feat_stride=anchor_generator_height_stride,
                                             ratio=proposal_ratios,
                                             scale=proposal_scales,
                                             normalize=1,
                                             base_size=anchor_generator_height,
                                             nms_thresh=_value_or_raise(match, pipeline_config,
                                                                        'first_stage_nms_iou_threshold')))
        for key in ('clip_before_nms', 'clip_after_nms'):
            if key in custom_attributes:
                proposal_op.attrs[key] = int(custom_attributes[key])

        anchors_node = backward_bfs_for_operation(match.single_input_node(0)[0], ['Add'])[0]

        # creates input to store input image height, width and scales (usually 1.0s)
        # the batch size for this input is fixed because it is allowed to pass images of the same size only as input
        input_op_with_image_size = Parameter(graph, dict(shape=int64_array([1, 3]), fixed_batch=True))
        input_with_image_size_node = input_op_with_image_size.create_node([], dict(name='image_info'))

        proposal_node = proposal_op.create_node([reshape_permute_node, anchors_node, input_with_image_size_node],
                                                dict(name='proposals'))

        if do_not_swap_proposals:
            swapped_proposals_node = proposal_node
        else:
            swapped_proposals_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)

        proposal_reshape_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                                    dict(name="reshape_swap_proposals_2d"),
                                                                    swapped_proposals_node)
        mark_input_as_in_correct_layout(proposal_reshape_2d_node, 0)

        # feed the CropAndResize node with a correct boxes information produced with the Proposal layer
        # find the first CropAndResize node in the BFS order
        crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
                                     graph.node[node_id]['op'] == 'CropAndResize']
        # an explicit exception instead of "assert": asserts are removed when Python runs with the -O option
        if not crop_and_resize_nodes_ids:
            raise Error("Didn't find any CropAndResize nodes in the graph.")
        if not do_not_swap_proposals:
            crop_and_resize_node = Node(graph, crop_and_resize_nodes_ids[0])
            # set a marker that the input with box coordinates has been pre-processed so the CropAndResizeReplacement
            # transform doesn't try to merge the second and the third inputs
            crop_and_resize_node['inputs_preprocessed'] = True
            graph.remove_edge(crop_and_resize_node.in_node(1).id, crop_and_resize_node.id)
            graph.create_edge(proposal_reshape_2d_node, crop_and_resize_node, out_port=0, in_port=1)

        tf_proposal_reshape_4d_node = create_op_node_with_second_input(graph, Reshape,
                                                                       int64_array([-1, 1, max_proposals, 5]),
                                                                       dict(name="reshape_proposal_4d"),
                                                                       swapped_proposals_node)

        # cut off the first of the 5 values of each proposal; the node is named 'crop_proposals'
        crop_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([1]), dim=int64_array([4]),
                                   nchw_layout=True))
        crop_node = crop_op.create_node([tf_proposal_reshape_4d_node], dict(name='crop_proposals'))

        mark_as_correct_data_layout(tf_proposal_reshape_4d_node)

        tf_proposals_crop_reshape_3d_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1, 4]),
                                                                             dict(name="reshape_crop_3d"), crop_node)
        mark_input_as_in_correct_layout(tf_proposals_crop_reshape_3d_node, 0)
        return {'proposal_node': tf_proposals_crop_reshape_3d_node}
904
905
906 class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
907     replacement_id = 'ObjectDetectionAPISSDPostprocessorReplacement'
908
909     def run_after(self):
910         return [ObjectDetectionAPIPreprocessorReplacement]
911
912     def run_before(self):
913         # the replacer uses node of type "RealDiv" as one of the start points, but Model Optimizer replaces nodes of
914         # type "RealDiv" with a new ones, so it is necessary to replace the sub-graph before replacing the "RealDiv"
915         # nodes
916         return [Div, StandaloneConstEraser, TransposeOrderNormalizer]
917
918     def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
919         # the DetectionOutput in IE produces single tensor, but in TF it produces two tensors, so create only one output
920         # edge match
921         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
922
923     def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
924         argv = graph.graph['cmd_params']
925         if argv.tensorflow_object_detection_api_pipeline_config is None:
926             raise Error(missing_param_error)
927         pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
928         num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
929
930         # reshapes confidences to 4D before applying activation function and do not convert from NHWC to NCHW this node
931         expand_dims_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, 1, -1, num_classes + 1]),
932                                                             {'name': 'do_ExpandDims_conf'})
933         expand_dims_node.in_port(0).connect(match.input_nodes(1)[0][0].in_node(0).out_port(0))
934
935         mark_as_correct_data_layout(expand_dims_node)
936
937         activation_function = _value_or_raise(match, pipeline_config, 'postprocessing_score_converter')
938         activation_conf_node = add_activation_function_after_node(graph, expand_dims_node, activation_function)
939
940         # IE DetectionOutput layer consumes flattened tensors
941         # reshape operation to flatten locations tensor
942         reshape_loc_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
943                                                             {'name': 'do_reshape_loc'})
944
945         current_node = skip_nodes_by_condition(match.input_nodes(0)[0][0].in_node(0),
946                                                lambda x: x.op == 'Identity' or x.has_and_set('reinterp_shape'))
947         reshape_loc_node.in_port(0).connect(current_node.out_port(0))
948         mark_as_correct_data_layout(reshape_loc_node)
949
950         # IE DetectionOutput layer consumes flattened tensors
951         # reshape operation to flatten confidence tensor
952         reshape_conf_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
953                                                              {'name': 'do_reshape_conf'}, activation_conf_node)
954         mark_as_correct_data_layout(reshape_conf_node)
955
956         custom_attributes = match.custom_replacement_desc.custom_attributes
957         if ('disable_prior_boxes_layers_generator' not in custom_attributes or
958             not custom_attributes['disable_prior_boxes_layers_generator']) and \
959             (pipeline_config.get_param('ssd_anchor_generator_num_layers') is not None or
960                 pipeline_config.get_param('multiscale_anchor_generator_min_level') is not None):
961             # change the Reshape operations with hardcoded number of output elements of the convolution nodes to be
962             # reshapable
963             _relax_reshape_nodes(graph, pipeline_config)
964
965             # create PriorBoxClustered nodes instead of a constant value with prior boxes so the model could be reshaped
966             if pipeline_config.get_param('ssd_anchor_generator_num_layers') is not None:
967                 priors_node = _create_prior_boxes_node(graph, pipeline_config)
968             elif pipeline_config.get_param('multiscale_anchor_generator_min_level') is not None:
969                 priors_node = _create_multiscale_prior_boxes_node(graph, pipeline_config)
970         else:
971             log.info('The anchor generator is not known. Save constant with prior-boxes to IR.')
972             priors_node = match.input_nodes(2)[0][0].in_node(0)
973
974         # creates DetectionOutput Node object from Op class
975         detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
976         detection_output_op.attrs['old_infer'] = detection_output_op.attrs['infer']
977         detection_output_op.attrs['infer'] = __class__.do_infer
978         detection_output_node = detection_output_op.create_node(
979             [reshape_loc_node, reshape_conf_node, priors_node],
980             dict(name=detection_output_op.attrs['type'],
981                  clip=1,
982                  confidence_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_score_threshold'),
983                  top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_detections_per_class'),
984                  keep_top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_total_detections'),
985                  nms_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_iou_threshold')))
986
987         return {'detection_output_node': detection_output_node}
988
989     @staticmethod
990     def do_infer(node: Node):
991         prior_boxes = node.in_node(2).value
992         if prior_boxes is not None:
993             argv = node.graph.graph['cmd_params']
994             if argv.tensorflow_object_detection_api_pipeline_config is None:
995                 raise Error(missing_param_error)
996             pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
997             variance = _variance_from_pipeline_config(pipeline_config)
998             # replicating the variance values for all prior-boxes
999             variances = np.tile(variance, [prior_boxes.shape[-2], 1])
1000             # DetectionOutput Inference Engine expects the prior-boxes in the following layout: (values, variances)
1001             prior_boxes = prior_boxes.reshape([-1, 4])
1002             prior_boxes = np.concatenate((prior_boxes, variances), 0)
1003             # compared to the IE's DetectionOutput, the TF keeps the prior-boxes in YXYX, need to get back to the XYXY
1004             prior_boxes = np.concatenate((prior_boxes[:, 1:2], prior_boxes[:, 0:1],
1005                                           prior_boxes[:, 3:4], prior_boxes[:, 2:3]), 1)
1006             #  adding another dimensions, as the prior-boxes are expected as 3d tensors
1007             prior_boxes = prior_boxes.reshape((1, 2, -1))
1008             node.in_node(2).shape = int64_array(prior_boxes.shape)
1009             node.in_node(2).value = prior_boxes
1010
1011         node.old_infer(node)
1012         # compared to the IE's DetectionOutput, the TF keeps the locations in YXYX, need to get back to the XYXY
1013         # for last convolutions that operate the locations need to swap the X and Y for output feature weights & biases
1014         conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
1015         swap_weights_xy(conv_nodes)
1016         squeeze_reshape_and_concat(conv_nodes)
1017
1018         for node_name in node.graph.nodes():
1019             node = Node(node.graph, node_name)
1020             if node.has_and_set('swap_xy_count') and len(node.out_nodes()) != node['swap_xy_count']:
1021                 raise Error('The weights were swapped for node "{}", but this weight was used in other nodes.'.format(
1022                     node.name))
1023
1024
class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
    """
    Cuts off the network at the specified nodes for models generated with the Object Detection API.

    The replacement description contains a single string under the key "outputs". The string is a comma separated
    list of output alternatives, and each alternative is a '|' separated list of node names which could be outputs.
    For every alternative the first node that exists in the graph is selected and the remaining ones are ignored.
    For example, for the following "outputs" value:

        "Reshape_16,SecondStageBoxPredictor_1/Conv_3/BiasAdd|SecondStageBoxPredictor_1/Conv_1/BiasAdd"

    "Reshape_16" becomes an output if it exists in the graph. The second output is
    SecondStageBoxPredictor_1/Conv_3/BiasAdd if it exists in the graph, otherwise it is
    SecondStageBoxPredictor_1/Conv_1/BiasAdd, provided that node exists in the graph.
    """
    replacement_id = 'ObjectDetectionAPIOutputReplacement'

    def run_before(self):
        """Return the transformations which must be executed after this one."""
        return [ObjectDetectionAPIPreprocessorReplacement, TransposeOrderNormalizer]

    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
        """
        Select the output nodes described in the "outputs" string and add output operations for them.

        Nothing is done (except for a warning) when the user has specified the output nodes explicitly.

        :param graph: the graph to cut off.
        :param replacement_descriptions: dictionary with the "outputs" string described in the class docstring.
        """
        if graph.graph['cmd_params'].output is not None:
            log.warning('User defined output nodes are specified. Skip the graph cut-off by the '
                        'ObjectDetectionAPIOutputReplacement.')
            return

        selected_outputs = []
        for alternatives in replacement_descriptions['outputs'].split(','):
            for candidate in alternatives.split('|'):
                if not graph.has_node(candidate):
                    log.debug('A node "{}" does not exist in the graph. Do not add it as output'.format(candidate))
                    continue
                # only the first existing node of each alternative is used
                selected_outputs.append(candidate)
                break

        add_output_ops(graph, output_user_data_repack(graph, selected_outputs), graph.graph['inputs'])
1060
1061
class ObjectDetectionAPIPSROIPoolingReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    Replaces the matched sub-graph with a PSROIPooling operation followed by a ReduceMean operation.

    The PSROIPooling attributes are taken from the pipeline configuration file ('num_classes', 'crop_height',
    'crop_width', 'num_spatial_bins_height', 'num_spatial_bins_width').
    """
    replacement_id = 'ObjectDetectionAPIPSROIPoolingReplacement'

    def run_after(self):
        """Return the transformations which must be executed before this one."""
        return [ObjectDetectionAPIProposalReplacement, TransposeOrderNormalizer]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        """Map the first output node of the matched sub-graph to the 'output_node' of the new sub-graph."""
        return {match.output_node(0)[0].id: new_sub_graph['output_node'].id}

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Create the PSROIPooling and ReduceMean nodes which replace the matched sub-graph.

        :param graph: the graph to operate on.
        :param match: the matched sub-graph description.
        :return: dictionary with the key 'output_node' referring to the created ReduceMean node.
        :raises Error: if the pipeline configuration file is not specified or if the "crop_height" and "crop_width"
            parameters are different.
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
        num_classes = _value_or_raise(match, pipeline_config, 'num_classes')

        input_node = match.input_nodes(0)[0][0].in_node(0)
        # the output dimension depends on whether the input node id contains 'class_predictions'
        if 'class_predictions' in input_node.id:
            psroipooling_output_dim = num_classes + 1
        else:
            psroipooling_output_dim = num_classes * 4

        num_spatial_bins_height = pipeline_config.get_param('num_spatial_bins_height')
        num_spatial_bins_width = pipeline_config.get_param('num_spatial_bins_width')
        crop_height = pipeline_config.get_param('crop_height')
        crop_width = pipeline_config.get_param('crop_width')
        if crop_height != crop_width:
            raise Error('Different "crop_height" and "crop_width" parameters from the pipeline config are not '
                        'supported: {} vs {}'.format(crop_height, crop_width))

        # the node swapping the proposals coordinates is created once and then re-used by the following matches
        if 'reshape_swap_proposals_2d' in graph.nodes():
            reshape_swap_proposals_node = Node(graph, 'reshape_swap_proposals_2d')
        else:
            swap_proposals_node = add_convolution_to_swap_xy_coordinates(graph, Node(graph, 'proposals'), 5)
            reshape_swap_proposals_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                                           {'name': 'reshape_swap_proposals_2d'},
                                                                           swap_proposals_node)
            mark_input_as_in_correct_layout(reshape_swap_proposals_node, 0)

        psroipooling_node = PSROIPoolingOp(graph, {'name': input_node.soft_get('name') + '/PSROIPooling',
                                                   'output_dim': psroipooling_output_dim,
                                                   # "group_size" is an integer attribute: the floor division keeps
                                                   # it integer while the "/" operator always produces a float value
                                                   'group_size': crop_width // num_spatial_bins_width,
                                                   'spatial_bins_x': num_spatial_bins_width,
                                                   'spatial_bins_y': num_spatial_bins_height,
                                                   'mode': 'bilinear',
                                                   'spatial_scale': 1,
                                                   }).create_node([input_node, reshape_swap_proposals_node])

        # ReduceMean gets the constant [1, 2] as the second input
        reduce_node = create_op_node_with_second_input(graph, ReduceMean, int64_array([1, 2]),
                                                       {'name': 'mean', 'keep_dims': True}, psroipooling_node)

        output_node = match.output_node(0)[0].out_node()
        if len(output_node.in_ports()) == 2 and not output_node.in_port(1).disconnected():
            output_node.in_port(1).disconnect()  # disconnect the second input to make "erase_node" function work
        graph.erase_node(output_node)

        return {'output_node': reduce_node}
1119
1120
class ObjectDetectionAPIConstValueOverride(FrontReplacementFromConfigFileGeneral):
    """
    The transformation overrides specific constant values in the topology. The replacement description configuration
    file contains a list of tuples describing the desired replacements in the "replacements" key of the
    "custom_attributes". The first element of a tuple is the name of the graph node with a constant value. The second
    element is the name of the pipeline configuration file parameter which stores the new value.

    Usage example. The Faster-RCNN topologies have a constant node with the maximum number of generated proposals.
    This value is specified in the pipeline configuration file in the parameter 'first_stage_max_proposals' and is
    also saved as a constant node in the generated topology. If the parameter is modified from its original value then
    the topology becomes incorrect because the 'first_stage_max_proposals' used in the transforms of this file is no
    longer equal to the 'first_stage_max_proposals' saved as a constant.
    """
    replacement_id = 'ObjectDetectionAPIConstValueOverride'

    def run_before(self):
        """Return the transformations which must be executed after this one."""
        return [ObjectDetectionAPIPreprocessorReplacement, TransposeOrderNormalizer]

    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
        """
        Overwrite the values of the nodes listed in "replacements" with the values from the pipeline config.

        Nodes which do not exist in the graph or do not have a value are skipped with a debug message.

        :param graph: the graph to update.
        :param replacement_descriptions: dictionary with the "replacements" list of (node id, parameter name) pairs.
        :raises Error: if the pipeline configuration file is not specified.
        """
        config_path = graph.graph['cmd_params'].tensorflow_object_detection_api_pipeline_config
        if config_path is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(config_path)

        for node_id, param_name in replacement_descriptions['replacements']:
            if node_id not in graph.nodes():
                log.debug('Node with id {} does not exist in the graph'.format(node_id))
            else:
                node = Node(graph, node_id)
                if node.has_valid('value'):
                    # the new value is reshaped to the shape attribute of the node
                    node.value = np.array(pipeline_config.get_param(param_name)).reshape(node.shape)
                else:
                    log.debug('Node with id {} does not have value'.format(node_id))