Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / ObjectDetectionAPI.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from math import sqrt
19
20 import numpy as np
21
22 from extensions.front.Pack import Pack
23 from extensions.front.div import Div
24 from extensions.front.standalone_const_eraser import StandaloneConstEraser
25 from extensions.front.sub import Sub
26 from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
27 from extensions.front.tf.Unpack import Unpack
28 from extensions.ops.DetectionOutput import DetectionOutput
29 from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
30 from extensions.ops.proposal import ProposalOp
31 from extensions.ops.psroipooling import PSROIPoolingOp
32 from mo.front.common.layout import get_batch_dim, get_height_dim, get_width_dim
33 from mo.front.common.partial_infer.utils import int64_array
34 from mo.front.common.weights import swap_weights_xy
35 from mo.front.extractor import output_user_data_repack, add_output_ops, update_attrs
36 from mo.front.subgraph_matcher import SubgraphMatch
37 from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
38     squeeze_reshape_and_concat, add_fake_background_loc
39 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
40 from mo.graph.graph import Graph, Node
41 from mo.ops.activation import Activation
42 from mo.ops.concat import Concat
43 from mo.ops.const import Const
44 from mo.ops.crop import Crop
45 from mo.ops.eltwise import Eltwise
46 from mo.ops.input import Input
47 from mo.ops.op import PermuteAttrs
48 from mo.ops.output import Output
49 from mo.ops.permute import Permute
50 from mo.ops.reduce import Reduce
51 from mo.ops.reshape import Reshape
52 from mo.ops.roipooling import ROIPooling
53 from mo.ops.shape import Shape
54 from mo.ops.softmax import Softmax
55 from mo.utils.error import Error
56 from mo.utils.graph import backward_bfs_for_operation, bfs_search
57 from mo.utils.pipeline_config import PipelineConfig
58
# Message raised by the replacers in this file when the
# "--tensorflow_object_detection_api_pipeline_config" command line parameter is not specified.
missing_param_error = (
    'To convert the model specify path to the pipeline configuration file which was used to generate the model. '
    'Please use "--tensorflow_object_detection_api_pipeline_config" option:\n'
    '--tensorflow_object_detection_api_pipeline_config "<path_to_pipeline.config>"\n'
    'If you have downloaded the model file from the Object Detection Model zoo repository then this file is located '
    'in the archive with frozen model and called "pipeline.config".\n'
    'If you did not use this command line parameter before that means that you are using currently deprecated '
    'TensorFlow* Object Detection API models conversion mechanism.'
)
66
67
def _value_or_raise(match: SubgraphMatch, pipeline_config: PipelineConfig, key: str):
    """
    Looks up the value for the key 'key'. The 'custom_attributes' of the replacement description attached to 'match'
    take priority; the pipeline configuration is consulted only when the key is not defined there.
    :param match: SubgraphMatch object containing 'custom_attributes'. A falsy value skips this lookup.
    :param pipeline_config: PipelineConfig object with parsed values.
    :param key: key to search for.
    :return: the requested value.
    :raises Error: if the key is not in the custom attributes and the pipeline configuration returns None for it.
    """
    if match:
        custom_attributes = match.custom_replacement_desc.custom_attributes
        if key in custom_attributes:
            return custom_attributes[key]

    config_value = pipeline_config.get_param(key)
    if config_value is not None:
        return config_value

    # the "[REPLACEMENT_ID]" placeholder is emitted verbatim by this function
    raise Error('The sub-graph replacer "[REPLACEMENT_ID]" was not able to find the value for key "{}" in the '
                'pipeline configuration file specified with the --tensorflow_object_detection_api_pipeline_config '
                'command line parameter. Update the sub-graph replacement configuration file specified with the '
                '--tensorflow_use_custom_operations_config command line parameter by adding key "{}" with required '
                'value to the "custom_attributes" dictionary of the "[REPLACEMENT_ID]" replacer.'.format(key, key))
87
88
def _find_ssd_head_node(graph: Graph, ssd_head_index: int, head_type: str):
    """
    Looks for the SSD head node number 'ssd_head_index' among the node names known for the requested 'head_type'.
    Two naming schemes are probed: 'BoxPredictor_<index>/...' and 'WeightSharedConvolutionalBoxPredictor[_<index>]/...'
    (the latter has no index suffix for the head number 0).
    :param graph: graph with the topology.
    :param ssd_head_index: index of the SSD head.
    :param head_type: either 'box' or 'class' string specifying type of the SSD head node.
    :return: the requested Node or None if node is not found.
    :raises Error: if 'head_type' is neither 'box' nor 'class'.
    """
    # head type -> (scope used by the regular predictor, scope used by the weight shared predictor)
    predictor_scopes = {
        'box': ('BoxEncodingPredictor', 'BoxPredictor'),
        'class': ('ClassPredictor', 'ClassPredictor'),
    }
    if head_type not in predictor_scopes:
        raise Error('SSD heads can be of type "box" and "class" only.')
    regular_scope, shared_scope = predictor_scopes[head_type]

    if ssd_head_index == 0:
        shared_prefix = 'WeightSharedConvolutionalBoxPredictor'
    else:
        shared_prefix = 'WeightSharedConvolutionalBoxPredictor_%d' % ssd_head_index
    candidate_names = ['BoxPredictor_%d/%s/BiasAdd' % (ssd_head_index, regular_scope),
                       '%s/%s/BiasAdd' % (shared_prefix, shared_scope)]

    found_node = None
    for candidate_name in candidate_names:
        if candidate_name not in graph.nodes():
            continue
        # only one of the candidate names is expected to exist in the graph
        assert found_node is None
        found_node = Node(graph, candidate_name)
    return found_node
115
116
def _variance_from_pipeline_config(pipeline_config: PipelineConfig):
    """
    Builds a numpy array from the four 'frcnn_variance_*' parameters of the pipeline_config object. Each element of
    the result is the reciprocal (1.0 / value) of the corresponding parameter. The order of the elements is the
    following: x, y, box width, box height.
    :param pipeline_config: pipeline_config object to get variances from.
    :return: the numpy array with four elements.
    """
    variance_keys = ('frcnn_variance_x', 'frcnn_variance_y', 'frcnn_variance_width', 'frcnn_variance_height')
    configured_values = [pipeline_config.get_param(variance_key) for variance_key in variance_keys]
    return 1.0 / np.array(configured_values)
128
129
def _skip_node_of_type(node: Node, node_ops_to_skip: list):
    """
    Walks forward from the node 'node' while the current node has exactly one output edge and its 'op' is listed in
    'node_ops_to_skip'.
    :param node: node to start skipping from.
    :param node_ops_to_skip: list of the 'op' attribute values of the nodes to step over.
    :return: the first node that has a different op or does not have exactly one output edge.
    """
    current = node
    while True:
        if len(current.out_edges()) != 1 or current.op not in node_ops_to_skip:
            return current
        current = current.out_node()
140
141
def _relax_reshape_nodes(graph: Graph, pipeline_config: PipelineConfig):
    """
    Replaces the 'Reshape' operations following the SSD head nodes (both 'box' and 'class' ones) with new 'Reshape'
    nodes whose target shape contains -1 instead of a hard-coded number of items. This function is used to make
    TF OD API SSD models reshapable.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: None
    """
    num_classes = pipeline_config.get_param('num_classes')
    num_layers = pipeline_config.get_param('ssd_anchor_generator_num_layers')
    if num_layers is None:
        # fall back to the level range of the multiscale anchor generator
        max_level = pipeline_config.get_param('multiscale_anchor_generator_max_level')
        min_level = pipeline_config.get_param('multiscale_anchor_generator_min_level')
        num_layers = max_level - min_level + 1

    for head_index in range(num_layers):
        for head_type in ('box', 'class'):
            head_node = _find_ssd_head_node(graph, head_index, head_type)
            assert head_node is not None
            hard_coded_reshape = _skip_node_of_type(head_node.out_node(), ['Identity'])
            assert hard_coded_reshape.op == 'Reshape'

            # 0 keeps the first dimension, -1 lets the number of items be deduced
            if head_type == 'box':
                relaxed_dims = [0, -1, 1, 4]
            else:
                relaxed_dims = [0, -1, num_classes + 1]
            dims_node = Const(graph, {'value': int64_array(relaxed_dims)}).create_node([])
            relaxed_reshape_op = Reshape(graph, {'name': head_node.id + '/Reshape', 'correct_data_layout': True})
            relaxed_reshape = relaxed_reshape_op.create_node([head_node, dims_node])
            hard_coded_reshape.replace_node(relaxed_reshape)
176
177
def _create_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
    """
    Creates one PriorBoxClustered node per SSD head using the 'ssd_anchor_generator_*' and 'resizer_image_*' values of
    the pipeline configuration file. Every PriorBoxClustered node is fed with the 'box' SSD head node and with the
    output of the 'image_tensor' node. Several nodes are joined with a Concat node.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: node generating prior boxes: the single PriorBoxClustered node or the Concat of all of them.
    """
    min_scale = pipeline_config.get_param('ssd_anchor_generator_min_scale')
    max_scale = pipeline_config.get_param('ssd_anchor_generator_max_scale')
    num_layers = pipeline_config.get_param('ssd_anchor_generator_num_layers')
    aspect_ratios = pipeline_config.get_param('ssd_anchor_generator_aspect_ratios')
    # prior boxes have to be generated using the image size used for training
    image_height = pipeline_config.get_param('resizer_image_height')
    image_width = pipeline_config.get_param('resizer_image_width')
    shortest_side = min(image_height, image_width)
    base_anchor_height = pipeline_config.get_param('ssd_anchor_generator_base_anchor_height')
    base_anchor_width = pipeline_config.get_param('ssd_anchor_generator_base_anchor_width')
    anchor_height_factor = shortest_side / image_height * base_anchor_height
    anchor_width_factor = shortest_side / image_width * base_anchor_width

    # the lowest layer gets the reduced set of boxes unless the configuration file says otherwise
    reduce_lowest = pipeline_config.get_param('ssd_anchor_generator_reduce_lowest')
    if reduce_lowest is None:
        reduce_lowest = True

    # one scale per layer plus the trailing 1.0 used to interpolate the scale of the last layer
    scales = [min_scale + (max_scale - min_scale) * layer / (num_layers - 1) for layer in range(num_layers)]
    scales.append(1.0)

    prior_box_nodes = []
    for head_index in range(num_layers):
        head_node = _find_ssd_head_node(graph, head_index, 'box')
        assert head_node is not None

        if reduce_lowest and head_index == 0:
            relative_widths = [0.1, min_scale * sqrt(2.0), min_scale * sqrt(0.5)]
            relative_heights = [0.1, min_scale / sqrt(2.0), min_scale / sqrt(0.5)]
        else:
            layer_scale = scales[head_index]
            # extra square box with the scale being geometric mean of this and the next layer scales
            interpolated_scale = sqrt(layer_scale * scales[head_index + 1])
            relative_widths = [layer_scale * sqrt(ratio) for ratio in aspect_ratios] + [interpolated_scale]
            relative_heights = [layer_scale / sqrt(ratio) for ratio in aspect_ratios] + [interpolated_scale]
        box_widths = [rel_width * image_width * anchor_width_factor for rel_width in relative_widths]
        box_heights = [rel_height * image_height * anchor_height_factor for rel_height in relative_heights]

        prior_box_attrs = {
            'width': np.array(box_widths),
            'height': np.array(box_heights),
            'clip': 0,
            'flip': 0,
            'variance': _variance_from_pipeline_config(pipeline_config),
            'offset': 0.5,
        }
        # connect the PriorBoxClustered node with the "Cast" node of the Placeholder node because the pass that removes
        # Cast operations is executed in the middle phase and it will fail when there are several consumers of the
        # Placeholder
        image_consumer = Node(graph, 'image_tensor').out_node(0)
        prior_box_nodes.append(
            PriorBoxClusteredOp(graph, prior_box_attrs).create_node(
                [head_node, image_consumer], {'name': 'PriorBoxClustered_{}'.format(head_index)}))

    if len(prior_box_nodes) == 1:
        return prior_box_nodes[0]
    concat_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_box_nodes)})
    return concat_op.create_node(prior_box_nodes, {'name': 'ConcatPriorBoxesClustered'})
236
237
def _create_multiscale_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
    """
    Creates one PriorBoxClustered node per level of the multiscale anchor generator using the
    'multiscale_anchor_generator_*' values of the pipeline configuration file. Every PriorBoxClustered node is fed with
    the 'box' SSD head node and with the output of the 'image_tensor' node. Several nodes are joined with a Concat
    node.
    :param graph: graph with the topology.
    :param pipeline_config: PipelineConfig object with parsed values.
    :return: node generating prior boxes: the single PriorBoxClustered node or the Concat of all of them.
    """
    min_level = pipeline_config.get_param('multiscale_anchor_generator_min_level')
    max_level = pipeline_config.get_param('multiscale_anchor_generator_max_level')
    anchor_scale = pipeline_config.get_param('multiscale_anchor_generator_anchor_scale')
    aspect_ratios = pipeline_config.get_param('multiscale_anchor_generator_aspect_ratios')
    scales_per_octave = pipeline_config.get_param('multiscale_anchor_generator_scales_per_octave')

    octave_scales = [2 ** (float(octave_step) / scales_per_octave) for octave_step in range(scales_per_octave)]

    prior_box_nodes = []
    for head_index, level in enumerate(range(min_level, max_level + 1)):
        level_anchor_size = 2 ** level * anchor_scale

        # heads are numbered from 0 starting with the 'min_level'
        head_node = _find_ssd_head_node(graph, head_index, 'box')
        assert head_node is not None

        box_widths = [level_anchor_size * octave_scale * sqrt(ratio)
                      for ratio in aspect_ratios for octave_scale in octave_scales]
        box_heights = [level_anchor_size * octave_scale / sqrt(ratio)
                       for ratio in aspect_ratios for octave_scale in octave_scales]

        prior_box_attrs = {
            'width': np.array(box_widths),
            'height': np.array(box_heights),
            'clip': 0,
            'flip': 0,
            'variance': _variance_from_pipeline_config(pipeline_config),
            'offset': 0.5,
        }
        # connect the PriorBoxClustered node with the "Cast" node of the Placeholder node because the pass that removes
        # Cast operations is executed in the middle phase and it will fail when there are several consumers of the
        # Placeholder
        image_consumer = Node(graph, 'image_tensor').out_node(0)
        prior_box_nodes.append(
            PriorBoxClusteredOp(graph, prior_box_attrs).create_node(
                [head_node, image_consumer], {'name': 'PriorBoxClustered_{}'.format(head_index)}))

    if len(prior_box_nodes) == 1:
        return prior_box_nodes[0]
    concat_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_box_nodes)})
    return concat_op.create_node(prior_box_nodes, {'name': 'ConcatPriorBoxesClustered'})
281
282
def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int, max_size: int):
    """
    Scales the spatial sizes of the image with a single factor so the aspect ratio is kept. The factor is the smaller
    of 'min_size' divided by the shorter side and 'max_size' divided by the longer side.
    The behavior of this function is equivalent to the output shape calculation of the Preprocessor block of TensorFlow
    Object Detection API models with keep aspect ratio resizer.
    :param height: input height.
    :param width: input width.
    :param min_size: size limit applied to the shorter side.
    :param max_size: size limit applied to the longer side.
    :return: the tuple with scaled image height, width rounded to integers.
    """
    shorter_side, longer_side = sorted((height, width))
    scale = min(min_size / shorter_side, max_size / longer_side)
    return tuple(int(round(side * scale)) for side in (height, width))
298
299
def calculate_placeholder_spatial_shape(graph: Graph, match: SubgraphMatch, pipeline_config: PipelineConfig):
    """
    The function calculates the preprocessed shape of the input image for a TensorFlow Object Detection API model.
    It uses various sources to calculate it:
    1. The shape passed using the '--input_shape' command line parameter.
    2. The values from the pipeline configuration file describing Preprocessor block of the topology:
        a. If the fixed size resizer is used then size passed via '--input_shape' can override them, but Model Optimizer
           prints warning. If the '--input_shape' is not defined then use values from the pipeline configuration file.
        b. If the keep aspect ratio resizer is used then scale the size passed via '--input_shape' using the provided
           limits. If the '--input_shape' is not defined then use shape as (min_dimension_size, min_dimension_size)
           defined in the pipeline configuration file.
    :param graph: graph with the topology.
    :param match: the object containing matching sub-graph and custom attributes from the sub-graph replacement file.
    :param pipeline_config: the object contain information from the pipeline configuration file.
    :return: tuple (height, width) of the placeholder shape.
    :raises Error: if neither resizer type is described in the pipeline configuration file.
    """
    height = None
    width = None
    user_shapes = graph.graph['user_shapes']

    custom_attributes = match.custom_replacement_desc.custom_attributes
    if 'preprocessed_image_height' in custom_attributes or 'preprocessed_image_width' in custom_attributes:
        log.error('The "preprocessed_image_height" or "preprocessed_image_width" is specified in the sub-graph '
                  'replacement configuration file but they are ignored. Please, specify desired input shape using the '
                  '"--input_shape" command line parameter.', extra={'is_warning': True})

    user_defined_height = None
    user_defined_width = None
    if user_shapes and 'image_tensor' in user_shapes and user_shapes['image_tensor']:
        user_defined_shape = user_shapes['image_tensor'][0]['shape']
        if user_defined_shape is not None:
            # the height and width are read from the indices 1 and 2 of the user defined shape
            user_defined_height = user_defined_shape[1]
            user_defined_width = user_defined_shape[2]

    resizer_height = pipeline_config.get_param('resizer_image_height')
    resizer_width = pipeline_config.get_param('resizer_image_width')
    if resizer_height and resizer_width:
        log.debug('The model resizes image to a fixed shape: ({}, {})'.format(resizer_height, resizer_width))

    resizer_min_dimension = pipeline_config.get_param('resizer_min_dimension')
    resizer_max_dimension = pipeline_config.get_param('resizer_max_dimension')
    if resizer_min_dimension and resizer_max_dimension:
        log.debug('The model resizes image using keep aspect ratio with minimum size {}, maximum size {}.'.format(
            resizer_min_dimension, resizer_max_dimension))

    # if the model is created with an input image resizer to a fixed shape
    if resizer_width and resizer_height:
        if user_defined_height and user_defined_width:
            # warn when either of the user defined spatial sizes differs from the resizer one
            if user_defined_height != resizer_height or user_defined_width != resizer_width:
                log.error('The model expects that the input image is resized to a fixed shape ({}, {}), but the shape '
                          'provided with the "--input_shape" command line parameter is different ({}, {}).'.format(
                    resizer_height, resizer_width, user_defined_height, user_defined_width), extra={'is_warning': True})
            height = user_defined_height
            width = user_defined_width
        else:
            height = resizer_height
            width = resizer_width

    # if the model is created with an input image resizer keeping aspect ratio
    if resizer_min_dimension and resizer_max_dimension:
        print('[ WARNING ] Model Optimizer removes pre-processing block of the model which resizes image keeping '
              'aspect ratio. The Inference Engine does not support dynamic image size so the Intermediate '
              'Representation file is generated with the input image size of a fixed size.')
        if user_defined_height and user_defined_width:
            scaled_height, scaled_width = calculate_shape_keeping_aspect_ratio(user_defined_height,
                                                                               user_defined_width,
                                                                               resizer_min_dimension,
                                                                               resizer_max_dimension)
            # warn when the scaling changed either of the user defined spatial sizes
            if scaled_height != user_defined_height or scaled_width != user_defined_width:
                log.error('The model resizes the input image keeping aspect ratio with min dimension {}, max '
                          'dimension {}. The provided input height {}, width {} is transformed to height {}, width '
                          '{}.'.format(resizer_min_dimension, resizer_max_dimension, user_defined_height,
                                       user_defined_width, scaled_height, scaled_width), extra={'is_warning': True})
            height = scaled_height
            width = scaled_width
        else:
            height = width = resizer_min_dimension
            print('Specify the "--input_shape" command line parameter to override the default shape which is equal to '
                  '({}, {}).'.format(height, width))

    if height is None or width is None:
        raise Error('Failed to determine the placeholder shape.')
    return height, width
383
384
class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
    to applying mean/scaling values are kept.
    """
    replacement_id = 'ObjectDetectionAPIPreprocessorReplacement'

    def run_before(self):
        """
        :return: list of the replacers which must be executed after this one.
        """
        return [Pack, Sub]

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        Returns names of the matched nodes except the ones performing scaling and mean value subtraction.
        :param graph: graph with the topology.
        :param match: the object containing the matched sub-graph.
        :return: list of names of the nodes to remove.
        """
        # work on a copy so the list owned by the 'match' object is not modified
        new_nodes_to_remove = list(match.matched_nodes_names())
        # do not remove nodes that perform input image scaling and mean value subtraction
        for node_to_keep in ('Preprocessor/sub', 'Preprocessor/sub/y', 'Preprocessor/mul', 'Preprocessor/mul/x'):
            if node_to_keep in new_nodes_to_remove:
                new_nodes_to_remove.remove(node_to_keep)
        return new_nodes_to_remove

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Updates the shape of the 'image_tensor' node with the calculated pre-processed image size and connects the node
        following it directly to the node performing scaling (if exists) or mean value subtraction.
        :param graph: graph with the topology.
        :param match: the object containing the matched sub-graph.
        :return: empty dictionary because no new nodes are created.
        :raises Error: if the pipeline configuration file is not specified, the output of the matched sub-graph is not
        'Sub', the 'image_tensor' node is absent or it is not followed by the 'Cast' operation.
        """
        argv = graph.graph['cmd_params']
        layout = graph.graph['layout']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)

        sub_node = match.output_node(0)[0]
        if not sub_node.has('op') or sub_node.op != 'Sub':
            raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
                        'not created with TensorFlow Object Detection API.')

        mul_node = None
        if sub_node.in_node(0).has('op') and sub_node.in_node(0).op == 'Mul':
            log.info('There is image scaling node in the Preprocessor block.')
            mul_node = sub_node.in_node(0)

        initial_input_node_name = 'image_tensor'
        if initial_input_node_name not in graph.nodes():
            raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
                        '"--input" command line parameter.'.format(initial_input_node_name))
        placeholder_node = Node(graph, initial_input_node_name)

        # set default value of the batch size to 1 if user didn't specify batch size and input shape
        batch_dim = get_batch_dim(layout, 4)
        if argv.batch is None and placeholder_node.shape[batch_dim] == -1:
            placeholder_node.shape[batch_dim] = 1
        height, width = calculate_placeholder_spatial_shape(graph, match, pipeline_config)
        placeholder_node.shape[get_height_dim(layout, 4)] = height
        placeholder_node.shape[get_width_dim(layout, 4)] = width

        # save the pre-processed image spatial sizes to be used in the other replacers
        graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
        graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]

        to_float_node = placeholder_node.out_node(0)
        if not to_float_node.has('op') or to_float_node.op != 'Cast':
            raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
                initial_input_node_name))

        # connect to_float_node directly with node performing scale or mean value subtraction
        if mul_node is None:
            graph.create_edge(to_float_node, sub_node, 0, 0)
        else:
            graph.create_edge(to_float_node, mul_node, 0, 1)

        print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
              ' applicable) are kept.')
        return {}
452
453
454 class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFileSubGraph):
455     """
456     Replaces the sub-graph that is equal to the DetectionOutput layer from Inference Engine. This replacer is used for
457     Faster R-CNN, R-FCN and Mask R-CNN topologies conversion.
458     The replacer uses a value of the custom attribute 'coordinates_swap_method' from the sub-graph replacement
459     configuration file to choose how to swap box coordinates of the 0-th input of the generated DetectionOutput layer.
460     Refer to the code for more details.
461     """
462     replacement_id = 'ObjectDetectionAPIDetectionOutputReplacement'
463
464     def run_before(self):
465         return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement, Unpack, Sub]
466
467     def run_after(self):
468         return [ObjectDetectionAPIProposalReplacement, CropAndResizeReplacement]
469
470     def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
471         new_nodes_to_remove = match.matched_nodes_names().copy()
472         outputs = ['detection_boxes', 'detection_scores', 'num_detections']
473         for output in outputs:
474             children = Node(graph, output).out_nodes()
475             if len(children) != 1:
476                 log.warning('Output {} has {} children. It should have only one output: with op==`OpOutput`'
477                             ''.format(output, len(children)))
478             elif children[list(children.keys())[0]].op == 'OpOutput':
479                 new_nodes_to_remove.append(children[list(children.keys())[0]].id)
480             else:
481                 continue
482         new_nodes_to_remove.extend(outputs)
483         return new_nodes_to_remove
484
485     def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
486         # the DetectionOutput in IE produces single tensor, but in TF it produces four tensors, so we need to create
487         # only one output edge match
488         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
489
490     @staticmethod
491     def skip_nodes_by_condition(current_node: Node, condition: callable):
492         while condition(current_node):
493             current_node = current_node.in_node()
494         return current_node
495
    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Creates the sub-graph ending with a DetectionOutput node which is used instead of the matched sub-graph.

        The DetectionOutput node gets three inputs built here: the locations (reshaped, extended with fake background
        boxes, multiplied by variances and flattened), the confidences (with the activation function applied and
        flattened) and the proposals taken from the single node named 'crop_proposals' (reshaped to 3D). An Output
        node is attached to the created DetectionOutput node.

        :param graph: graph to operate on.
        :param match: object describing the matched sub-graph.
        :return: dictionary with the single key 'detection_output_node' referring to the created DetectionOutput node.
        :raises Error: if the pipeline configuration file is not specified in the command line parameters, if the
            "coordinates_swap_method" custom attribute has an unsupported value or if the number of nodes named
            'crop_proposals' in the graph is not equal to 1.
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)

        num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
        max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
        activation_function = _value_or_raise(match, pipeline_config, 'postprocessing_score_converter')

        # the activation function is attached to the producer of the second input of the matched sub-graph
        activation_conf_node = add_activation_function_after_node(graph, match.single_input_node(1)[0].in_node(0),
                                                                  activation_function)

        # IE DetectionOutput layer consumes flattened tensors so need add a Reshape layer.
        # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
        # "0" value in the Reshape layer attribute to refer to the batch size, but we know how to
        # calculate the second dimension so the batch value will be deduced from it with help of "-1".
        reshape_conf_op = Reshape(graph, dict(dim=int64_array([-1, (num_classes + 1) * max_proposals])))
        reshape_conf_node = reshape_conf_op.create_node([activation_conf_node], dict(name='do_reshape_conf'))

        # Workaround for PermuteForReshape pass.
        # We are looking for the first not Reshape-typed node before match.single_input_node(0)[0].in_node(0)
        # and add the reshape_loc node after this first not Reshape-typed node.
        current_node = self.skip_nodes_by_condition(match.single_input_node(0)[0].in_node(0),
                                                    lambda x: x['kind'] == 'op' and x.soft_get('type') == 'Reshape')

        reshape_loc_op = Reshape(graph, dict(dim=int64_array([-1, num_classes, 1, 4])))
        reshape_loc_node = reshape_loc_op.create_node([current_node], dict(name='reshape_loc', nchw_layout=True))
        update_attrs(reshape_loc_node, 'shape_attrs', 'dim')

        # constant node with variances
        variances_const_op = Const(graph, dict(value=_variance_from_pipeline_config(pipeline_config)))
        variances_const_node = variances_const_op.create_node([])

        # TF produces locations tensor without boxes for background.
        # Inference Engine DetectionOutput layer requires background boxes so we generate them
        loc_node = add_fake_background_loc(graph, reshape_loc_node)
        PermuteAttrs.set_permutation(reshape_loc_node, loc_node, None)

        # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
        reshape_loc_2d_op = Reshape(graph, dict(dim=int64_array([-1, 4])))
        reshape_loc_2d_node = reshape_loc_2d_op.create_node([loc_node], dict(name='reshape_locs_2d', nchw_layout=True))
        PermuteAttrs.set_permutation(loc_node, reshape_loc_2d_node, None)

        # element-wise multiply locations with variances
        eltwise_locs_op = Eltwise(graph, dict(operation='mul'))
        eltwise_locs_node = eltwise_locs_op.create_node([reshape_loc_2d_node, variances_const_node],
                                                        dict(name='scale_locs'))

        # IE DetectionOutput layer consumes flattened tensors so need add a Reshape layer.
        # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
        # "0" value in the Reshape layer attribute to refer to the batch size, but we know how to
        # calculate the second dimension so the batch value will be deduced from it with help of "-1".
        reshape_loc_do_op = Reshape(graph, dict(dim=int64_array([-1, (num_classes + 1) * max_proposals * 4])))

        # choose the way the XY coordinates of the locations are swapped; the absence of the attribute is reported
        # with an error message only and the 'add_convolution' method is used in this case
        custom_attributes = match.custom_replacement_desc.custom_attributes
        coordinates_swap_method = 'add_convolution'
        if 'coordinates_swap_method' not in custom_attributes:
            log.error('The ObjectDetectionAPIDetectionOutputReplacement sub-graph replacement configuration file '
                      'must contain "coordinates_swap_method" in the "custom_attributes" dictionary. Two values are '
                      'supported: "swap_weights" and "add_convolution". The first one should be used when there is '
                      'a MatMul or Conv2D node before the "SecondStagePostprocessor" block in the topology. With this '
                      'solution the weights of the MatMul or Conv2D nodes are permutted, simulating the swap of XY '
                      'coordinates in the tensor. The second could be used in any other cases but it is worse in terms '
                      'of performance because it adds the Conv2D node which performs permutting of data. Since the '
                      'attribute is not defined the second approach is used by default.')
        else:
            coordinates_swap_method = custom_attributes['coordinates_swap_method']
        supported_swap_methods = ['swap_weights', 'add_convolution']
        if coordinates_swap_method not in supported_swap_methods:
            raise Error('Unsupported "coordinates_swap_method" defined in the sub-graph replacement configuration '
                        'file. Supported methods are: {}'.format(', '.join(supported_swap_methods)))

        if coordinates_swap_method == 'add_convolution':
            swapped_locs_node = add_convolution_to_swap_xy_coordinates(graph, eltwise_locs_node, 4)
            reshape_loc_do_node = reshape_loc_do_op.create_node([swapped_locs_node], dict(name='do_reshape_locs'))
        else:
            # 'swap_weights' method: nothing is inserted here, the infer function is replaced further below instead
            reshape_loc_do_node = reshape_loc_do_op.create_node([eltwise_locs_node], dict(name='do_reshape_locs'))

        # find Proposal output which has the data layout as in TF: YXYX coordinates without batch indices.
        proposal_nodes_ids = [node_id for node_id, attrs in graph.nodes(data=True)
                              if 'name' in attrs and attrs['name'] == 'crop_proposals']
        if len(proposal_nodes_ids) != 1:
            raise Error("Found the following nodes '{}' with name 'crop_proposals' but there should be exactly 1. "
                        "Looks like ObjectDetectionAPIProposalReplacement replacement didn't work.".
                        format(proposal_nodes_ids))
        proposal_node = Node(graph, proposal_nodes_ids[0])

        # check whether it is necessary to permute proposals coordinates before passing them to the DetectionOutput
        # currently this parameter is set for the RFCN topologies
        if 'swap_proposals' in custom_attributes and custom_attributes['swap_proposals']:
            proposal_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 4)

        # reshape priors boxes as Detection Output expects
        reshape_priors_op = Reshape(graph, dict(dim=int64_array([-1, 1, max_proposals * 4])))
        reshape_priors_node = reshape_priors_op.create_node([proposal_node],
                                                            dict(name='DetectionOutput_reshape_priors_'))

        detection_output_op = DetectionOutput(graph, {})
        if coordinates_swap_method == 'swap_weights':
            # update infer function to re-pack weights
            detection_output_op.attrs['old_infer'] = detection_output_op.attrs['infer']
            detection_output_op.attrs['infer'] = __class__.do_infer
        # optional clipping attributes are copied from the custom attributes as integers
        for key in ('clip_before_nms', 'clip_after_nms'):
            if key in match.custom_replacement_desc.custom_attributes:
                detection_output_op.attrs[key] = int(match.custom_replacement_desc.custom_attributes[key])

        detection_output_node = detection_output_op.create_node(
            [reshape_loc_do_node, reshape_conf_node, reshape_priors_node],
            dict(name=detection_output_op.attrs['type'], share_location=0, variance_encoded_in_target=1,
                 code_type='caffe.PriorBoxParameter.CENTER_SIZE', pad_mode='caffe.ResizeParameter.CONSTANT',
                 resize_mode='caffe.ResizeParameter.WARP',
                 num_classes=num_classes,
                 confidence_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_score_threshold'),
                 top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_detections_per_class'),
                 keep_top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_total_detections'),
                 nms_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_iou_threshold')))
        # sets specific name to the node so we can find it in other replacers
        detection_output_node.name = 'detection_output'

        # mark the DetectionOutput node as an output of the topology
        output_op = Output(graph, dict(name='do_OutputOp'))
        output_op.create_node([detection_output_node])

        print('The graph output nodes "num_detections", "detection_boxes", "detection_classes", "detection_scores" '
              'have been replaced with a single layer of type "Detection Output". Refer to IR catalogue in the '
              'documentation for information about this layer.')

        return {'detection_output_node': detection_output_node}
624
625     @staticmethod
626     def do_infer(node):
627         node.old_infer(node)
628         # compared to the IE's DetectionOutput, the TF keeps the locations in YXYX, need to get back to the XYXY
629         # for last matmul/Conv2D that operate the locations need to swap the X and Y for output feature weights & biases
630         swap_weights_xy(backward_bfs_for_operation(node.in_node(0), ['MatMul', 'Conv2D']))
631
632
class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    Replaces the matched sub-graph with a ROIPooling node which takes boxes from the node named 'detection_output'.
    Additionally exposes the DetectionOutput result reshaped to 2D as an output of the topology.
    """
    replacement_id = 'ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement'

    def run_after(self):
        return [ObjectDetectionAPIProposalReplacement]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        # the only output of the matched sub-graph is re-connected to the created ROIPooling node
        matched_output_id = match.output_node(0)[0].id
        return {matched_output_id: new_sub_graph['roi_pooling_node'].id}

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        cmd_params = graph.graph['cmd_params']
        if cmd_params.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(cmd_params.tensorflow_object_detection_api_pipeline_config)
        pooled_size = _value_or_raise(match, pipeline_config, 'initial_crop_size')

        # look for the node which was given the name 'detection_output'; exactly one is expected
        do_ids = []
        for node_id, attrs in graph.nodes(data=True):
            if attrs.get('name') == 'detection_output':
                do_ids.append(node_id)
        if len(do_ids) != 1:
            raise Error("Found the following nodes '{}' with 'detection_output' but there should be exactly 1.".
                        format(do_ids))
        do_node = Node(graph, do_ids[0])

        # 2D view of the Detection Output so it can be an output of the topology
        do_2d_node = Reshape(graph, {'dim': int64_array([-1, 7])}).create_node([do_node], {'name': 'reshape_do_2d'})

        # special node of type "Output" marking the output nodes of the topology
        marker_node = Output(graph, {'name': 'do_reshaped_OutputOp'}).create_node([do_2d_node])

        # the 'output_sort_order' attribute is used as a key to sort output nodes before generation of IR
        marker_node.in_edge()['data_attrs'].append('output_sort_order')
        marker_node.in_edge()['output_sort_order'] = [('detection_boxes', 0)]

        # two Crop operations take data from the DetectionOutput layer and cut off the slices with class ids and
        # probabilities, so the concatenated tensor keeps batch ids and bounding boxes only (as it is expected by
        # the ROIPooling layer)
        batch_ids_crop = Crop(graph, {'axis': int64_array([3]), 'offset': int64_array([0]), 'dim': int64_array([1]),
                                      'nchw_layout': True})
        batch_ids_node = batch_ids_crop.create_node([do_node], {'name': 'crop_do_batch_ids'})

        coords_crop = Crop(graph, {'axis': int64_array([3]), 'offset': int64_array([3]), 'dim': int64_array([4]),
                                   'nchw_layout': True})
        coords_node = coords_crop.create_node([do_node], {'name': 'crop_do_coords'})

        joined_node = Concat(graph, {'axis': 3}).create_node([batch_ids_node, coords_node],
                                                             {'name': 'batch_and_coords', 'nchw_layout': True})

        # bounding boxes in the form required by ROIPooling
        boxes_node = Reshape(graph, {'dim': int64_array([-1, 5])}).create_node([joined_node], {'name': 'reshape_do'})

        pooling_op = ROIPooling(graph, {'method': "bilinear", 'spatial_scale': 1,
                                        'pooled_h': pooled_size, 'pooled_w': pooled_size})
        pooling_node = pooling_op.create_node([match.single_input_node(0)[0].in_node(), boxes_node],
                                              {'name': 'ROI_pooling_2'})
        return {'roi_pooling_node': pooling_node}
693
694
class ObjectDetectionAPIMaskRCNNSigmoidReplacement(FrontReplacementFromConfigFileGeneral):
    """
    Replacer for Mask R-CNN topologies only.
    Appends a sigmoid activation to the end of the network branch producing the masks tensors.
    """
    replacement_id = 'ObjectDetectionAPIMaskRCNNSigmoidReplacement'

    def run_after(self):
        return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]

    def transform_graph(self, graph: Graph, replacement_descriptions):
        masks_output_node = None
        # collect the ids beforehand because new nodes are added to the graph inside the loop below
        op_output_ids = []
        for node_id, attrs in graph.nodes(data=True):
            if attrs.get('op') == 'OpOutput':
                op_output_ids.append(node_id)

        for op_output_id in op_output_ids:
            producer = Node(graph, op_output_id).in_node(0)
            if not producer.name.startswith('SecondStageBoxPredictor'):
                continue

            sigmoid_node = Activation(graph, {'operation': 'sigmoid'}).create_node(
                [producer], {'name': producer.id + '/sigmoid'})
            sigmoid_node.name = 'masks'

            # only one output producing masks is supported
            if masks_output_node is not None:
                raise Error('Identified two possible outputs from the topology. Cannot proceed.')
            # special node of type "Output" that is a marker for the output nodes of the topology
            masks_output_node = Output(graph, {'name': sigmoid_node.name + '/OutputOp'}).create_node([sigmoid_node])

        print('The predicted masks are produced by the "masks" layer for each bounding box generated with a '
              '"detection_output" layer.\n Refer to IR catalogue in the documentation for information '
              'about the DetectionOutput layer and Inference Engine documentation about output data interpretation.\n'
              'The topology can be inferred using dedicated demo "mask_rcnn_demo".')
725
726
class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    This class replaces sub-graph of operations with Proposal layer and additional layers transforming
    tensors from layout of TensorFlow to layout required by Inference Engine.
    Refer to comments inside the function for more information about performed actions.
    """
    replacement_id = 'ObjectDetectionAPIProposalReplacement'

    def run_after(self):
        return [ObjectDetectionAPIPreprocessorReplacement]

    def run_before(self):
        return [Sub, CropAndResizeReplacement]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        """
        Maps the first output of the matched sub-graph to the node stored under the 'proposal_node' key.
        """
        return {match.output_node(0)[0].id: new_sub_graph['proposal_node'].id}

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        Returns names of the matched nodes except the two input nodes of the sub-graph which must be kept.
        """
        new_list = match.matched_nodes_names().copy()
        # do not remove nodes that produce box predictions and class predictions
        new_list.remove(match.single_input_node(0)[0].id)
        new_list.remove(match.single_input_node(1)[0].id)
        return new_list

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Creates the Proposal node with the accompanying Reshape/Softmax/Permute/Crop nodes.

        :param graph: graph to operate on.
        :param match: object describing the matched sub-graph.
        :return: dictionary with the single key 'proposal_node' referring to the last created Reshape node.
        :raises Error: if the pipeline configuration file is not specified in the command line parameters.
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)

        custom_attributes = match.custom_replacement_desc.custom_attributes
        # evaluated once: when set, the proposals are used as is and the CropAndResize node is not re-connected
        do_not_swap_proposals = bool('do_not_swap_proposals' in custom_attributes and
                                     custom_attributes['do_not_swap_proposals'])

        max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
        proposal_ratios = _value_or_raise(match, pipeline_config, 'anchor_generator_aspect_ratios')
        proposal_scales = _value_or_raise(match, pipeline_config, 'anchor_generator_scales')
        anchors_count = len(proposal_ratios) * len(proposal_scales)

        # Convolution/matmul node that produces classes predictions
        # Permute result of the tensor with classes predictions so it will be in a correct layout for Softmax
        predictions_node = backward_bfs_for_operation(match.single_input_node(1)[0], ['Add'])[0]

        reshape_classes_op = Reshape(graph, dict(dim=int64_array([0, anchors_count, 2, -1])))
        reshape_classes_node = reshape_classes_op.create_node([], dict(name='predictions/Reshape', nchw_layout=True))
        predictions_node.insert_node_after(reshape_classes_node, 0)

        softmax_conf_op = Softmax(graph, dict(axis=2, nchw_layout=True, name=reshape_classes_node.id + '/Softmax'))
        softmax_conf_node = softmax_conf_op.create_node([reshape_classes_node])
        permute_reshape_softmax_op = Permute(graph, dict(order=int64_array([0, 2, 1, 3]), nchw_layout=True))
        permute_reshape_softmax_node = permute_reshape_softmax_op.create_node([softmax_conf_node], dict(
            name=softmax_conf_node.name + '/Permute'))

        initial_shape_op = Shape(graph, dict(name=predictions_node.id + '/Shape'))
        initial_shape_node = initial_shape_op.create_node([predictions_node])

        # implement custom reshape infer function because we need to know the input convolution node output dimension
        # sizes but we can know it only after partial infer
        reshape_permute_op = Reshape(graph, dict())
        reshape_permute_node = reshape_permute_op.create_node([permute_reshape_softmax_node, initial_shape_node],
                                                              dict(name='Reshape_Permute_Class'))

        variance_height = pipeline_config.get_param('frcnn_variance_height')
        variance_width = pipeline_config.get_param('frcnn_variance_width')
        variance_x = pipeline_config.get_param('frcnn_variance_x')
        variance_y = pipeline_config.get_param('frcnn_variance_y')
        anchor_generator_height_stride = pipeline_config.get_param('anchor_generator_height_stride')
        anchor_generator_width_stride = pipeline_config.get_param('anchor_generator_width_stride')
        anchor_generator_height = pipeline_config.get_param('anchor_generator_height')
        anchor_generator_width = pipeline_config.get_param('anchor_generator_width')

        # the Proposal node created below gets a single value for each pair of the parameters (the height/x one), so
        # a mismatch inside a pair is reported; the conversion is not interrupted
        if variance_height != variance_width:
            log.error('The values for variance for height "{}" is not equal to variance for width "{}". The detection '
                      'results will be inaccurate.'.format(variance_height, variance_width))
        if variance_x != variance_y:
            log.error('The values for variance for x "{}" is not equal to variance for y "{}". The detection '
                      'results will be inaccurate.'.format(variance_x, variance_y))
        if anchor_generator_height_stride != anchor_generator_width_stride:
            log.error('The values for the anchor generator height stride "{}" is not equal to the anchor generator '
                      'width stride "{}". The detection results will be inaccurate.'.format(
                anchor_generator_height_stride, anchor_generator_width_stride))
        if anchor_generator_height != anchor_generator_width:
            # the message refers to the anchor width (not to the width stride which is checked above)
            log.error('The values for the anchor generator height "{}" is not equal to the anchor generator width '
                      '"{}". The detection results will be inaccurate.'.format(anchor_generator_height,
                                                                               anchor_generator_width))

        proposal_op = ProposalOp(graph, dict(min_size=1,
                                             framework='tensorflow',
                                             pre_nms_topn=2 ** 31 - 1,
                                             box_size_scale=variance_height,
                                             box_coordinate_scale=variance_x,
                                             post_nms_topn=max_proposals,
                                             feat_stride=anchor_generator_height_stride,
                                             ratio=proposal_ratios,
                                             scale=proposal_scales,
                                             normalize=1,
                                             base_size=anchor_generator_height,
                                             nms_thresh=_value_or_raise(match, pipeline_config,
                                                                        'first_stage_nms_iou_threshold')))
        # optional clipping attributes are copied from the custom attributes as integers
        for key in ('clip_before_nms', 'clip_after_nms'):
            if key in custom_attributes:
                proposal_op.attrs[key] = int(custom_attributes[key])

        anchors_node = backward_bfs_for_operation(match.single_input_node(0)[0], ['Add'])[0]

        # creates input to store input image height, width and scales (usually 1.0s)
        # the batch size for this input is fixed because it is allowed to pass images of the same size only as input
        input_op_with_image_size = Input(graph, dict(shape=int64_array([1, 3]), fixed_batch=True))
        input_with_image_size_node = input_op_with_image_size.create_node([], dict(name='image_info'))

        proposal_node = proposal_op.create_node([reshape_permute_node, anchors_node, input_with_image_size_node],
                                                dict(name='proposals'))

        if do_not_swap_proposals:
            swapped_proposals_node = proposal_node
        else:
            swapped_proposals_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)

        proposal_reshape_2d_op = Reshape(graph, dict(dim=int64_array([-1, 5]), nchw_layout=True))
        proposal_reshape_2d_node = proposal_reshape_2d_op.create_node([swapped_proposals_node],
                                                                      dict(name="reshape_swap_proposals_2d"))

        # feed the CropAndResize node with a correct boxes information produced with the Proposal layer
        # find the first CropAndResize node in the BFS order
        crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
                                     graph.node[node_id]['op'] == 'CropAndResize']
        assert len(crop_and_resize_nodes_ids) != 0, "Didn't find any CropAndResize nodes in the graph."
        if not do_not_swap_proposals:
            crop_and_resize_node = Node(graph, crop_and_resize_nodes_ids[0])
            # set a marker that the input with box coordinates has been pre-processed so the CropAndResizeReplacement
            # transform doesn't try to merge the second and the third inputs
            crop_and_resize_node['inputs_preprocessed'] = True
            graph.remove_edge(crop_and_resize_node.in_node(1).id, crop_and_resize_node.id)
            graph.create_edge(proposal_reshape_2d_node, crop_and_resize_node, out_port=0, in_port=1)

        tf_proposal_reshape_4d_op = Reshape(graph, dict(dim=int64_array([-1, 1, max_proposals, 5]), nchw_layout=True))
        tf_proposal_reshape_4d_node = tf_proposal_reshape_4d_op.create_node([swapped_proposals_node],
                                                                            dict(name="reshape_proposal_4d"))

        # cut off the first of 5 values along the last axis; the node name 'crop_proposals' is the one looked for by
        # the DetectionOutput replacer
        crop_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([1]), dim=int64_array([4]),
                                   nchw_layout=True))
        crop_node = crop_op.create_node([tf_proposal_reshape_4d_node], dict(name='crop_proposals'))

        tf_proposals_crop_reshape_3d_op = Reshape(graph, dict(dim=int64_array([0, -1, 4]), nchw_layout=True))
        tf_proposals_crop_reshape_3d_node = tf_proposals_crop_reshape_3d_op.create_node([crop_node],
                                                                                        dict(name="reshape_crop_3d"))

        return {'proposal_node': tf_proposals_crop_reshape_3d_node}
873
874
class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    Replaces the matched post-processing sub-graph of the SSD topologies with a single DetectionOutput node.
    """
    replacement_id = 'ObjectDetectionAPISSDPostprocessorReplacement'

    def run_after(self):
        return [ObjectDetectionAPIPreprocessorReplacement]

    def run_before(self):
        # one of the start points of the replacer is a node of type "RealDiv" and Model Optimizer substitutes such
        # nodes with new ones, so the sub-graph must be replaced before the "RealDiv" nodes are replaced
        return [Div, StandaloneConstEraser]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        # TF produces two tensors while the DetectionOutput in IE produces a single one, so only one output edge
        # match is created
        matched_output_id = match.output_node(0)[0].id
        return {matched_output_id: new_sub_graph['detection_output_node'].id}

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        cmd_params = graph.graph['cmd_params']
        if cmd_params.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(cmd_params.tensorflow_object_detection_api_pipeline_config)
        num_classes = _value_or_raise(match, pipeline_config, 'num_classes')

        # confidences are reshaped to 4D before the activation function is applied
        conf_4d_op = Reshape(graph, dict(dim=int64_array([0, 1, -1, num_classes + 1])))
        conf_4d_node = conf_4d_op.create_node([match.input_nodes(1)[0][0].in_node(0)],
                                              {'name': 'do_ExpandDims_conf'})

        score_converter = _value_or_raise(match, pipeline_config, 'postprocessing_score_converter')
        activated_conf_node = add_activation_function_after_node(graph, conf_4d_node, score_converter)
        # the shape of this node must not be converted from NHWC to NCHW
        PermuteAttrs.set_permutation(conf_4d_node, conf_4d_node.out_node(), None)

        # IE DetectionOutput layer consumes flattened tensors: flatten the locations tensor
        flat_loc_op = Reshape(graph, dict(dim=int64_array([0, -1])))
        flat_loc_node = flat_loc_op.create_node([match.input_nodes(0)[0][0].in_node(0)],
                                                {'name': 'do_reshape_loc'})

        # IE DetectionOutput layer consumes flattened tensors: flatten the confidences tensor
        flat_conf_op = Reshape(graph, dict(dim=int64_array([0, -1])))
        flat_conf_node = flat_conf_op.create_node([activated_conf_node], {'name': 'do_reshape_conf'})

        uses_ssd_anchors = pipeline_config.get_param('ssd_anchor_generator_num_layers') is not None
        if uses_ssd_anchors or pipeline_config.get_param('multiscale_anchor_generator_min_level') is not None:
            # make the Reshape operations with hardcoded number of output elements of the convolution nodes
            # reshapable
            _relax_reshape_nodes(graph, pipeline_config)

            # PriorBoxClustered nodes are created instead of a constant with prior boxes so the model could be
            # reshaped
            if uses_ssd_anchors:
                priors_node = _create_prior_boxes_node(graph, pipeline_config)
            else:
                priors_node = _create_multiscale_prior_boxes_node(graph, pipeline_config)
        else:
            log.info('The anchor generator is not known. Save constant with prior-boxes to IR.')
            priors_node = match.input_nodes(2)[0][0].in_node(0)

        # DetectionOutput operation with the infer function wrapped into the class specific one
        detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
        op_attrs = detection_output_op.attrs
        op_attrs['old_infer'] = op_attrs['infer']
        op_attrs['infer'] = __class__.do_infer
        detection_output_node = detection_output_op.create_node(
            [flat_loc_node, flat_conf_node, priors_node],
            {
                'name': detection_output_op.attrs['type'],
                'clip': 1,
                'confidence_threshold': _value_or_raise(match, pipeline_config, 'postprocessing_score_threshold'),
                'top_k': _value_or_raise(match, pipeline_config, 'postprocessing_max_detections_per_class'),
                'keep_top_k': _value_or_raise(match, pipeline_config, 'postprocessing_max_total_detections'),
                'nms_threshold': _value_or_raise(match, pipeline_config, 'postprocessing_iou_threshold'),
            })

        return {'detection_output_node': detection_output_node}

    @staticmethod
    def do_infer(node: Node):
        priors_input = node.in_node(2)
        prior_boxes = priors_input.value
        if prior_boxes is not None:
            cmd_params = node.graph.graph['cmd_params']
            if cmd_params.tensorflow_object_detection_api_pipeline_config is None:
                raise Error(missing_param_error)
            pipeline_config = PipelineConfig(cmd_params.tensorflow_object_detection_api_pipeline_config)
            variance = _variance_from_pipeline_config(pipeline_config)
            # the variance values are replicated for all prior-boxes
            variances = np.tile(variance, [prior_boxes.shape[-2], 1])
            # DetectionOutput Inference Engine expects the prior-boxes in the following layout: (values, variances)
            boxes_2d = prior_boxes.reshape([-1, 4])
            stacked = np.concatenate((boxes_2d, variances), axis=0)
            # TF keeps the prior-boxes in YXYX while the IE's DetectionOutput needs XYXY, so columns are re-ordered
            reordered_columns = (stacked[:, 1:2], stacked[:, 0:1], stacked[:, 3:4], stacked[:, 2:3])
            swapped = np.concatenate(reordered_columns, axis=1)
            # the prior-boxes are expected as 3D tensors
            prior_boxes = swapped.reshape((1, 2, -1))
            priors_input.shape = int64_array(prior_boxes.shape)
            priors_input.value = prior_boxes

        node.old_infer(node)
        # TF keeps the locations in YXYX while the IE's DetectionOutput needs XYXY, so X and Y are swapped in the
        # output feature weights & biases of the convolutions found by the backward search from the locations input
        conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
        swap_weights_xy(conv_nodes)
        squeeze_reshape_and_concat(conv_nodes)

        # every node marked with 'swap_xy_count' must have exactly that many consumers
        graph = node.graph
        for node_id in graph.nodes():
            candidate = Node(graph, node_id)
            if candidate.has_and_set('swap_xy_count') and len(candidate.out_nodes()) != candidate['swap_xy_count']:
                raise Error('The weights were swapped for node "{}", but this weight was used in other nodes.'.format(
                    candidate.name))
984
985
class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
    """
    This replacer cuts off the network at the specified nodes for models generated with Object Detection API.
    The custom attributes of the replacer contain one value for the key "outputs". This string is a comma separated
    list of output alternatives. Each output alternative is a '|' separated list of node names which could be outputs.
    The first node from each alternative that exists in the graph is chosen. Others are ignored.
    For example, if the "outputs" is equal to the following string:

        "Reshape_16,SecondStageBoxPredictor_1/Conv_3/BiasAdd|SecondStageBoxPredictor_1/Conv_1/BiasAdd"

    then the "Reshape_16" will be an output if it exists in the graph. The second output will be
    SecondStageBoxPredictor_1/Conv_3/BiasAdd if it exists in the graph, otherwise
    SecondStageBoxPredictor_1/Conv_1/BiasAdd will be an output if it exists in the graph.
    """
    replacement_id = 'ObjectDetectionAPIOutputReplacement'

    def run_before(self):
        return [ObjectDetectionAPIPreprocessorReplacement]

    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
        """
        Adds output operations for the first existing node of each alternative listed in the "outputs" attribute.
        Does nothing except printing a warning if the output nodes are specified in the command line parameters.

        :param graph: the graph to operate on.
        :param replacement_descriptions: the dictionary with custom attributes; the "outputs" key is used.
        :return: None
        """
        user_outputs = graph.graph['cmd_params'].output
        if user_outputs is not None:
            log.warning('User defined output nodes are specified. Skip the graph cut-off by the '
                        'ObjectDetectionAPIOutputReplacement.')
            return

        chosen_outputs = []
        for alternative_group in replacement_descriptions['outputs'].split(','):
            for candidate_name in alternative_group.split('|'):
                if not graph.has_node(candidate_name):
                    log.debug('A node "{}" does not exist in the graph. Do not add it as output'.format(
                        candidate_name))
                    continue
                # only the first existing node of the alternative becomes an output
                chosen_outputs.append(candidate_name)
                break

        repacked_outputs = output_user_data_repack(graph, chosen_outputs)
        add_output_ops(graph, repacked_outputs, graph.graph['inputs'])
1021
1022
class ObjectDetectionAPIPSROIPoolingReplacement(FrontReplacementFromConfigFileSubGraph):
    """
    Replaces the matched sub-graph with the PSROIPooling operation followed by the Reduce operation of type "mean".
    The attributes of the PSROIPooling are calculated from the parameters of the pipeline configuration file.
    """
    replacement_id = 'ObjectDetectionAPIPSROIPoolingReplacement'

    def run_after(self):
        return [ObjectDetectionAPIProposalReplacement]

    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
        return {match.output_node(0)[0].id: new_sub_graph['output_node'].id}

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Creates the PSROIPooling and Reduce nodes which replace the matched sub-graph.

        :param graph: the graph to operate on.
        :param match: the object describing the matched sub-graph.
        :return: the dictionary with the key 'output_node' holding the created Reduce node.
        :raises Error: if the pipeline configuration file is not specified or if the "crop_height" and "crop_width"
        parameters of the pipeline configuration file are different.
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_object_detection_api_pipeline_config is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
        num_classes = _value_or_raise(match, pipeline_config, 'num_classes')

        input_node = match.input_nodes(0)[0][0].in_node(0)
        # the output dimension depends on whether the input node id contains 'class_predictions'
        if 'class_predictions' in input_node.id:
            psroipooling_output_dim = num_classes + 1
        else:
            psroipooling_output_dim = num_classes * 4

        num_spatial_bins_height = pipeline_config.get_param('num_spatial_bins_height')
        num_spatial_bins_width = pipeline_config.get_param('num_spatial_bins_width')
        crop_height = pipeline_config.get_param('crop_height')
        crop_width = pipeline_config.get_param('crop_width')
        if crop_height != crop_width:
            raise Error('Different "crop_height" and "crop_width" parameters from the pipeline config are not '
                        'supported: {} vs {}'.format(crop_height, crop_width))
        # floor division is used so that 'group_size' stays an integer; the true division always produces a float
        psroipooling_op = PSROIPoolingOp(graph, {'name': input_node.soft_get('name') + '/PSROIPooling',
                                                 'output_dim': psroipooling_output_dim,
                                                 'group_size': crop_width // num_spatial_bins_width,
                                                 'spatial_bins_x': num_spatial_bins_width,
                                                 'spatial_bins_y': num_spatial_bins_height,
                                                 'mode': 'bilinear',
                                                 'spatial_scale': 1,
                                                 })

        # re-use the node with swapped proposals if it has already been created in the graph
        if 'reshape_swap_proposals_2d' in graph.nodes():
            reshape_swap_proposals_node = Node(graph, 'reshape_swap_proposals_2d')
        else:
            swap_proposals_node = add_convolution_to_swap_xy_coordinates(graph, Node(graph, 'proposals'), 5)
            reshape_swap_proposals_node = Reshape(graph, {'dim': [-1, 5], 'nchw_layout': True,
                                                          'name': 'reshape_swap_proposals_2d'}).create_node(
                [swap_proposals_node])
        psroipooling_node = psroipooling_op.create_node([input_node, reshape_swap_proposals_node])

        reduce_op = Reduce(graph, {'name': 'mean',
                                   'reduce_type': 'mean',
                                   'axis': int64_array([1, 2]),
                                   'keep_dims': True
                                   })
        reduce_node = reduce_op.create_node([psroipooling_node])

        # remove the consumer of the matched sub-graph output node
        graph.erase_node(match.output_node(0)[0].out_node())

        return {'output_node': reduce_node}
1080
1081
class ObjectDetectionAPIConstValueOverride(FrontReplacementFromConfigFileGeneral):
    """
    The transformation allows overriding specific constant values in the topology. The replacement description
    configuration file contains a list of tuples describing the desired replacements specified in the "replacements"
    key of the "custom_attributes". The first element of the tuple is the name of the node in the initial graph with
    a constant value. The second element is the name of the parameter from the pipeline configuration file which
    stores the new value.

    Usage example. The Faster-RCNN topologies have a constant node with the number specifying the maximum number of
    generated proposals. This value is specified in the pipeline configuration file in the parameter
    'first_stage_max_proposals' and is saved as a constant node in the generated topology. If the parameter is
    modified from its original value then the topology will be incorrect, because the 'first_stage_max_proposals'
    value used in the transformations of this file is no longer equal to the 'first_stage_max_proposals' saved as
    a constant.
    """
    replacement_id = 'ObjectDetectionAPIConstValueOverride'

    def run_before(self):
        return [ObjectDetectionAPIPreprocessorReplacement]

    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
        """
        Overrides values of the nodes listed in the "replacements" attribute with the values of the corresponding
        parameters of the pipeline configuration file. Nodes which do not exist in the graph or do not have a valid
        'value' attribute are skipped.

        :param graph: the graph to operate on.
        :param replacement_descriptions: the dictionary with custom attributes; the "replacements" key is used.
        :return: None
        :raises Error: if the pipeline configuration file is not specified.
        """
        config_path = graph.graph['cmd_params'].tensorflow_object_detection_api_pipeline_config
        if config_path is None:
            raise Error(missing_param_error)
        pipeline_config = PipelineConfig(config_path)

        for (node_id, pipeline_config_name) in replacement_descriptions['replacements']:
            if node_id not in graph.nodes():
                log.debug('Node with id {} does not exist in the graph'.format(node_id))
                continue
            const_node = Node(graph, node_id)
            if not const_node.has_valid('value'):
                log.debug('Node with id {} does not have value'.format(node_id))
                continue
            new_value = pipeline_config.get_param(pipeline_config_name)
            const_node.value = np.array(new_value)
            # keep the shape of the node for the new value
            const_node.value = const_node.value.reshape(const_node.shape)