Publishing 2019 R1 content
diff --git a/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py b/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
index c62f9f6..c729051 100644
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 import logging as log
 from math import sqrt
 
-import networkx as nx
 import numpy as np
 
+from extensions.front.Pack import Pack
+from extensions.front.div import Div
 from extensions.front.standalone_const_eraser import StandaloneConstEraser
 from extensions.front.sub import Sub
 from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
-from extensions.front.Pack import Pack
 from extensions.front.tf.Unpack import Unpack
 from extensions.ops.DetectionOutput import DetectionOutput
 from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
 from extensions.ops.proposal import ProposalOp
+from extensions.ops.psroipooling import PSROIPoolingOp
 from mo.front.common.layout import get_batch_dim, get_height_dim, get_width_dim
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.weights import swap_weights_xy
-from mo.front.extractor import output_user_data_repack, add_output_ops
+from mo.front.extractor import output_user_data_repack, add_output_ops, update_attrs
 from mo.front.subgraph_matcher import SubgraphMatch
 from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
-    squeeze_reshape_and_concat
+    squeeze_reshape_and_concat, add_fake_background_loc
 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
-from mo.graph.graph import create_edge, insert_node_after, Node, replace_node
+from mo.graph.graph import Graph, Node
 from mo.ops.activation import Activation
 from mo.ops.concat import Concat
 from mo.ops.const import Const
 from mo.ops.crop import Crop
-from mo.ops.div import Div
 from mo.ops.eltwise import Eltwise
+from mo.ops.input import Input
 from mo.ops.op import PermuteAttrs
 from mo.ops.output import Output
 from mo.ops.permute import Permute
+from mo.ops.reduce import Reduce
 from mo.ops.reshape import Reshape
 from mo.ops.roipooling import ROIPooling
+from mo.ops.shape import Shape
 from mo.ops.softmax import Softmax
 from mo.utils.error import Error
-from mo.utils.graph import backward_bfs_for_operation
+from mo.utils.graph import backward_bfs_for_operation, bfs_search
 from mo.utils.pipeline_config import PipelineConfig
 
 missing_param_error = 'To convert the model specify path to the pipeline configuration file which was used to ' \
@@ -82,7 +86,7 @@ def _value_or_raise(match: SubgraphMatch, pipeline_config: PipelineConfig, key:
     return value
 
 
-def _find_ssd_head_node(graph: nx.MultiDiGraph, ssd_head_index: int, head_type: str):
+def _find_ssd_head_node(graph: Graph, ssd_head_index: int, head_type: str):
     """
     Finds the SSD head node with index 'ssd_head_index' in the topology. The parameter 'head_type' specifies what type
     of the head is requested: with box predictions or class predictions.
@@ -135,7 +139,7 @@ def _skip_node_of_type(node: Node, node_ops_to_skip: list):
     return node
 
 
-def _relax_reshape_nodes(graph: nx.MultiDiGraph, pipeline_config: PipelineConfig):
+def _relax_reshape_nodes(graph: Graph, pipeline_config: PipelineConfig):
     """
     Finds the 'Reshape' operations following the SSD head nodes which have hard-coded output dimensions and replaces
     them with new ones with one of the dimensions sizes equal to -1. This function is used to make TF OD API SSD models
@@ -155,23 +159,23 @@ def _relax_reshape_nodes(graph: nx.MultiDiGraph, pipeline_config: PipelineConfig
         assert (input_node is not None)
         old_reshape_node = _skip_node_of_type(input_node.out_node(), ['Identity'])
         assert (old_reshape_node.op == 'Reshape')
-        reshape_size_node = Const(graph, {'value': np.array([0, -1, 1, 4])}).create_node([])
+        reshape_size_node = Const(graph, {'value': int64_array([0, -1, 1, 4])}).create_node([])
         new_reshape_op = Reshape(graph, {'name': input_node.id + '/Reshape', 'correct_data_layout': True})
         new_reshape_node = new_reshape_op.create_node([input_node, reshape_size_node])
-        replace_node(old_reshape_node, new_reshape_node)
+        old_reshape_node.replace_node(new_reshape_node)
 
         # fix hard-coded value for the number of items in tensor produced by the convolution to make topology reshapable
         input_node = _find_ssd_head_node(graph, ssd_head_ind, 'class')
         assert (input_node is not None)
         old_reshape_node = _skip_node_of_type(input_node.out_node(), ['Identity'])
         assert (old_reshape_node.op == 'Reshape')
-        reshape_size_node_2 = Const(graph, {'value': np.array([0, -1, num_classes + 1])}).create_node([])
+        reshape_size_node_2 = Const(graph, {'value': int64_array([0, -1, num_classes + 1])}).create_node([])
         new_reshape_op_2 = Reshape(graph, {'name': input_node.id + '/Reshape', 'correct_data_layout': True})
         new_reshape_node_2 = new_reshape_op_2.create_node([input_node, reshape_size_node_2])
-        replace_node(old_reshape_node, new_reshape_node_2)
+        old_reshape_node.replace_node(new_reshape_node_2)
 
 
-def _create_prior_boxes_node(graph: nx.MultiDiGraph, pipeline_config: PipelineConfig):
+def _create_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
     """
     The function creates one or several PriorBoxClustered nodes based on information from the pipeline configuration
     files. The PriorBoxClustered nodes get input data from SSD 'heads' and from the placeholder node (just to get
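Throughout this hunk the shape constants passed to Reshape switch from plain np.array to int64_array from mo.front.common.partial_infer.utils. As an assumption about that helper rather than a quote of it, the change most likely just pins the dtype that shape inference expects; a minimal sketch:

import numpy as np

# Assumed behaviour of mo.front.common.partial_infer.utils.int64_array: shape-like
# constants are forced to int64 so partial shape inference does not operate on floats.
def int64_array(value) -> np.ndarray:
    return np.array(value, dtype=np.int64)

# The relaxed Reshape dimensions from the hunk above; the -1 entry is what makes the
# topology reshapable, since that dimension is deduced at inference time.
print(int64_array([0, -1, 1, 4]).dtype)  # int64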
@@ -227,11 +231,11 @@ def _create_prior_boxes_node(graph: nx.MultiDiGraph, pipeline_config: PipelineCo
     if len(prior_box_nodes) == 1:
         return prior_box_nodes[0]
     else:
-        concat_prior_boxes_op = Concat(graph, {'axis': -1})
+        concat_prior_boxes_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_box_nodes)})
         return concat_prior_boxes_op.create_node(prior_box_nodes, {'name': 'ConcatPriorBoxesClustered'})
 
 
-def _create_multiscale_prior_boxes_node(graph: nx.MultiDiGraph, pipeline_config: PipelineConfig):
+def _create_multiscale_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
     """
     The function creates one or several PriorBoxClustered nodes based on information from the pipeline configuration
     files. The PriorBoxClustered nodes get input data from SSD 'heads' and from the placeholder node (just to get
@@ -272,7 +276,7 @@ def _create_multiscale_prior_boxes_node(graph: nx.MultiDiGraph, pipeline_config:
     if len(prior_box_nodes) == 1:
         return prior_box_nodes[0]
     else:
-        concat_prior_boxes_op = Concat(graph, {'axis': -1})
+        concat_prior_boxes_op = Concat(graph, {'axis': -1, 'in_ports_count': len(prior_box_nodes)})
         return concat_prior_boxes_op.create_node(prior_box_nodes, {'name': 'ConcatPriorBoxesClustered'})
 
 
@@ -293,7 +297,7 @@ def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int,
     return int(round(height * ratio)), int(round(width * ratio))
 
 
-def calculate_placeholder_spatial_shape(graph: nx.MultiDiGraph, match: SubgraphMatch, pipeline_config: PipelineConfig):
+def calculate_placeholder_spatial_shape(graph: Graph, match: SubgraphMatch, pipeline_config: PipelineConfig):
     """
     The function calculates the preprocessed shape of the input image for a TensorFlow Object Detection API model.
     It uses various sources to calculate it:
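The return statement at the top of this hunk applies a single scale factor to both dimensions. As a reading aid only, here is a sketch of what the full helper presumably computes, assuming the standard TF Object Detection API keep-aspect-ratio rule (grow the shorter side to min_size unless the longer side would exceed max_size):

def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int, max_size: int):
    # Pick the largest scale that satisfies both the min_size and the max_size constraints.
    ratio_min = min_size / min(height, width)
    ratio_max = max_size / max(height, width)
    ratio = min(ratio_min, ratio_max)
    return int(round(height * ratio)), int(round(width * ratio))

# A 600x800 image already satisfies both constraints; a 400x1200 one is capped by max_size.
print(calculate_shape_keeping_aspect_ratio(600, 800, 600, 1024))   # (600, 800)
print(calculate_shape_keeping_aspect_ratio(400, 1200, 600, 1024))  # (341, 1024)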
@@ -388,7 +392,7 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
     def run_before(self):
         return [Pack, Sub]
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         new_nodes_to_remove = match.matched_nodes_names()
         # do not remove nodes that perform input image scaling and mean value subtraction
         for node_to_keep in ('Preprocessor/sub', 'Preprocessor/sub/y', 'Preprocessor/mul', 'Preprocessor/mul/x'):
@@ -396,7 +400,7 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
                 new_nodes_to_remove.remove(node_to_keep)
         return new_nodes_to_remove
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         layout = graph.graph['layout']
         if argv.tensorflow_object_detection_api_pipeline_config is None:
@@ -423,8 +427,6 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
         batch_dim = get_batch_dim(layout, 4)
         if argv.batch is None and placeholder_node.shape[batch_dim] == -1:
             placeholder_node.shape[batch_dim] = 1
-        if placeholder_node.shape[batch_dim] > 1:
-            print("[ WARNING ] The batch size more than 1 is supported for SSD topologies only.")
         height, width = calculate_placeholder_spatial_shape(graph, match, pipeline_config)
         placeholder_node.shape[get_height_dim(layout, 4)] = height
         placeholder_node.shape[get_width_dim(layout, 4)] = width
@@ -440,9 +442,9 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
 
         # connect to_float_node directly with node performing scale on mean value subtraction
         if mul_node is None:
-            create_edge(to_float_node, sub_node, 0, 0)
+            graph.create_edge(to_float_node, sub_node, 0, 0)
         else:
-            create_edge(to_float_node, mul_node, 0, 1)
+            graph.create_edge(to_float_node, mul_node, 0, 1)
 
         print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
               ' applicable) are kept.')
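For reference, the only preprocessing that survives this replacement is the scaling and mean-value subtraction done by the kept Preprocessor/mul and Preprocessor/sub nodes, fed directly from the to_float cast. A small numpy sketch of that arithmetic; the scale and shift values below are placeholders, the real constants stay in the frozen graph:

import numpy as np

def kept_preprocessing(image_uint8: np.ndarray, scale: float, shift: float) -> np.ndarray:
    image = image_uint8.astype(np.float32)   # the to_float node
    return image * scale - shift             # Preprocessor/mul followed by Preprocessor/sub

# Hypothetical values; many SSD feature extractors normalize pixels to [-1, 1] this way.
normalized = kept_preprocessing(np.zeros((1, 300, 300, 3), dtype=np.uint8), 2.0 / 255.0, 1.0)
print(normalized.min(), normalized.max())  # -1.0 -1.0 for an all-zero image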
@@ -465,12 +467,22 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
     def run_after(self):
         return [ObjectDetectionAPIProposalReplacement, CropAndResizeReplacement]
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         new_nodes_to_remove = match.matched_nodes_names().copy()
-        new_nodes_to_remove.extend(['detection_boxes', 'detection_scores', 'num_detections'])
+        outputs = ['detection_boxes', 'detection_scores', 'num_detections']
+        for output in outputs:
+            children = Node(graph, output).out_nodes()
+            if len(children) != 1:
+                log.warning('Output {} has {} children. It should have exactly one child with op==`OpOutput`'
+                            ''.format(output, len(children)))
+            elif children[list(children.keys())[0]].op == 'OpOutput':
+                new_nodes_to_remove.append(children[list(children.keys())[0]].id)
+            else:
+                continue
+        new_nodes_to_remove.extend(outputs)
         return new_nodes_to_remove
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         # the DetectionOutput in IE produces a single tensor, but in TF it produces four tensors, so we need to create
         # only one output edge match
         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
@@ -481,62 +493,60 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
             current_node = current_node.in_node()
         return current_node
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         if argv.tensorflow_object_detection_api_pipeline_config is None:
             raise Error(missing_param_error)
         pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
 
         num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
-        first_stage_max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
+        max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
         activation_function = _value_or_raise(match, pipeline_config, 'postprocessing_score_converter')
 
         activation_conf_node = add_activation_function_after_node(graph, match.single_input_node(1)[0].in_node(0),
                                                                   activation_function)
 
-        # IE DetectionOutput layer consumes flattened tensors
-        # reshape operation to flatten confidence tensor
-        reshape_conf_op = Reshape(graph, dict(dim=np.array([1, -1])))
+        # The IE DetectionOutput layer consumes flattened tensors, so we need to add a Reshape layer.
+        # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
+        # the "0" value in the Reshape layer attribute to refer to the batch size. However, we know how to calculate
+        # the second dimension, so the batch value is deduced from it with the help of "-1".
+        reshape_conf_op = Reshape(graph, dict(dim=int64_array([-1, (num_classes + 1) * max_proposals])))
         reshape_conf_node = reshape_conf_op.create_node([activation_conf_node], dict(name='do_reshape_conf'))
 
-        # TF produces locations tensor without boxes for background.
-        # Inference Engine DetectionOutput layer requires background boxes so we generate them with some values
-        # and concatenate with locations tensor
-        fake_background_locs_blob = np.tile([[[1, 1, 2, 2]]], [first_stage_max_proposals, 1, 1])
-        fake_background_locs_const_op = Const(graph, dict(value=fake_background_locs_blob))
-        fake_background_locs_const_node = fake_background_locs_const_op.create_node([])
-
         # Workaround for the PermuteForReshape pass.
         # We look for the first non-Reshape-typed node before match.single_input_node(0)[0].in_node(0)
         # and add the reshape_loc node after it.
         current_node = self.skip_nodes_by_condition(match.single_input_node(0)[0].in_node(0),
                                                     lambda x: x['kind'] == 'op' and x.soft_get('type') == 'Reshape')
 
-        reshape_loc_op = Reshape(graph, dict(dim=np.array([first_stage_max_proposals, num_classes, 4])))
-        reshape_loc_node = reshape_loc_op.create_node([current_node], dict(name='reshape_loc'))
-
-        concat_loc_op = Concat(graph, dict(axis=1))
-        concat_loc_node = concat_loc_op.create_node([fake_background_locs_const_node, reshape_loc_node],
-                                                    dict(name='concat_fake_loc'))
-        PermuteAttrs.set_permutation(reshape_loc_node, concat_loc_node, None)
-        PermuteAttrs.set_permutation(fake_background_locs_const_node, concat_loc_node, None)
+        reshape_loc_op = Reshape(graph, dict(dim=int64_array([-1, num_classes, 1, 4])))
+        reshape_loc_node = reshape_loc_op.create_node([current_node], dict(name='reshape_loc', nchw_layout=True))
+        update_attrs(reshape_loc_node, 'shape_attrs', 'dim')
 
         # constant node with variances
         variances_const_op = Const(graph, dict(value=_variance_from_pipeline_config(pipeline_config)))
         variances_const_node = variances_const_op.create_node([])
 
+        # TF produces locations tensor without boxes for background.
+        # Inference Engine DetectionOutput layer requires background boxes so we generate them
+        loc_node = add_fake_background_loc(graph, reshape_loc_node)
+        PermuteAttrs.set_permutation(reshape_loc_node, loc_node, None)
+
         # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
-        reshape_loc_2d_op = Reshape(graph, dict(dim=np.array([-1, 4])))
-        reshape_loc_2d_node = reshape_loc_2d_op.create_node([concat_loc_node], dict(name='reshape_locs_2'))
-        PermuteAttrs.set_permutation(concat_loc_node, reshape_loc_2d_node, None)
+        reshape_loc_2d_op = Reshape(graph, dict(dim=int64_array([-1, 4])))
+        reshape_loc_2d_node = reshape_loc_2d_op.create_node([loc_node], dict(name='reshape_locs_2d', nchw_layout=True))
+        PermuteAttrs.set_permutation(loc_node, reshape_loc_2d_node, None)
 
         # element-wise multiply locations with variances
         eltwise_locs_op = Eltwise(graph, dict(operation='mul'))
         eltwise_locs_node = eltwise_locs_op.create_node([reshape_loc_2d_node, variances_const_node],
                                                         dict(name='scale_locs'))
 
-        # IE DetectionOutput layer consumes flattened tensors
-        reshape_loc_do_op = Reshape(graph, dict(dim=np.array([1, -1])))
+        # The IE DetectionOutput layer consumes flattened tensors, so we need to add a Reshape layer.
+        # The batch value of the input tensor is not equal to the batch of the topology, so it is not possible to use
+        # the "0" value in the Reshape layer attribute to refer to the batch size. However, we know how to calculate
+        # the second dimension, so the batch value is deduced from it with the help of "-1".
+        reshape_loc_do_op = Reshape(graph, dict(dim=int64_array([-1, (num_classes + 1) * max_proposals * 4])))
 
         custom_attributes = match.custom_replacement_desc.custom_attributes
         coordinates_swap_method = 'add_convolution'
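A numpy sketch of the shape bookkeeping the two new Reshape dims rely on; the sizes below are hypothetical, in the real flow num_classes and max_proposals come from the pipeline configuration file:

import numpy as np

num_classes = 90       # hypothetical 'num_classes'
max_proposals = 100    # hypothetical 'first_stage_max_proposals'
batch = 1

# confidences: one score per proposal per class (including background), flattened per image
conf = np.zeros((batch * max_proposals, num_classes + 1), dtype=np.float32)
conf_flat = conf.reshape([-1, (num_classes + 1) * max_proposals])
print(conf_flat.shape)  # (1, 9100), the batch is deduced by the -1

# locations: 4 box coordinates per proposal per non-background class
loc = np.zeros((batch * max_proposals * num_classes, 4), dtype=np.float32)
loc_4d = loc.reshape([-1, num_classes, 1, 4])
print(loc_4d.shape)     # (100, 90, 1, 4)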
@@ -564,18 +574,21 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
 
         # find Proposal output which has the data layout as in TF: YXYX coordinates without batch indices.
         proposal_nodes_ids = [node_id for node_id, attrs in graph.nodes(data=True)
-                              if 'name' in attrs and attrs['name'] == 'proposals']
+                              if 'name' in attrs and attrs['name'] == 'crop_proposals']
         if len(proposal_nodes_ids) != 1:
-            raise Error("Found the following nodes '{}' with name 'proposals' but there should be exactly 1. "
+            raise Error("Found the following nodes '{}' with name 'crop_proposals' but there should be exactly 1. "
                         "Looks like ObjectDetectionAPIProposalReplacement replacement didn't work.".
                         format(proposal_nodes_ids))
         proposal_node = Node(graph, proposal_nodes_ids[0])
 
-        swapped_proposals_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)
+        # check whether it is necessary to permute proposals coordinates before passing them to the DetectionOutput
+        # currently this parameter is set for the RFCN topologies
+        if 'swap_proposals' in custom_attributes and custom_attributes['swap_proposals']:
+            proposal_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 4)
 
         # reshape prior boxes as Detection Output expects
-        reshape_priors_op = Reshape(graph, dict(dim=np.array([1, 1, -1])))
-        reshape_priors_node = reshape_priors_op.create_node([swapped_proposals_node],
+        reshape_priors_op = Reshape(graph, dict(dim=int64_array([-1, 1, max_proposals * 4])))
+        reshape_priors_node = reshape_priors_op.create_node([proposal_node],
                                                             dict(name='DetectionOutput_reshape_priors_'))
 
         detection_output_op = DetectionOutput(graph, {})
@@ -583,14 +596,16 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
             # update infer function to re-pack weights
             detection_output_op.attrs['old_infer'] = detection_output_op.attrs['infer']
             detection_output_op.attrs['infer'] = __class__.do_infer
+        for key in ('clip_before_nms', 'clip_after_nms'):
+            if key in match.custom_replacement_desc.custom_attributes:
+                detection_output_op.attrs[key] = int(match.custom_replacement_desc.custom_attributes[key])
+
         detection_output_node = detection_output_op.create_node(
             [reshape_loc_do_node, reshape_conf_node, reshape_priors_node],
-            dict(name=detection_output_op.attrs['type'], share_location=0, normalized=0, variance_encoded_in_target=1,
-                 clip=1, code_type='caffe.PriorBoxParameter.CENTER_SIZE', pad_mode='caffe.ResizeParameter.CONSTANT',
+            dict(name=detection_output_op.attrs['type'], share_location=0, variance_encoded_in_target=1,
+                 code_type='caffe.PriorBoxParameter.CENTER_SIZE', pad_mode='caffe.ResizeParameter.CONSTANT',
                  resize_mode='caffe.ResizeParameter.WARP',
                  num_classes=num_classes,
-                 input_height=graph.graph['preprocessed_image_height'],
-                 input_width=graph.graph['preprocessed_image_width'],
                  confidence_threshold=_value_or_raise(match, pipeline_config, 'postprocessing_score_threshold'),
                  top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_detections_per_class'),
                  keep_top_k=_value_or_raise(match, pipeline_config, 'postprocessing_max_total_detections'),
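The clip_before_nms and clip_after_nms flags are read from the custom_attributes section of the sub-graph replacement configuration file. A hypothetical fragment, shown as the Python dict the transform would see; the keys are the ones checked above, the values are examples only:

# Hypothetical match.custom_replacement_desc.custom_attributes content for this replacement.
custom_attributes = {
    "clip_before_nms": False,
    "clip_after_nms": True,
}

detection_output_attrs = {}
for key in ('clip_before_nms', 'clip_after_nms'):
    if key in custom_attributes:
        detection_output_attrs[key] = int(custom_attributes[key])
print(detection_output_attrs)  # {'clip_before_nms': 0, 'clip_after_nms': 1}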
@@ -618,10 +633,13 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
 class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFromConfigFileSubGraph):
     replacement_id = 'ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement'
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def run_after(self):
+        return [ObjectDetectionAPIProposalReplacement]
+
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         return {match.output_node(0)[0].id: new_sub_graph['roi_pooling_node'].id}
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         if argv.tensorflow_object_detection_api_pipeline_config is None:
             raise Error(missing_param_error)
@@ -636,7 +654,7 @@ class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFrom
         detection_output_node = Node(graph, detection_output_nodes_ids[0])
 
         # add reshape of Detection Output so it can be an output of the topology
-        reshape_detection_output_2d_op = Reshape(graph, dict(dim=np.array([-1, 7])))
+        reshape_detection_output_2d_op = Reshape(graph, dict(dim=int64_array([-1, 7])))
         reshape_detection_output_2d_node = reshape_detection_output_2d_op.create_node(
             [detection_output_node], dict(name='reshape_do_2d'))
 
@@ -648,15 +666,24 @@ class ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement(FrontReplacementFrom
         output_node.in_edge()['data_attrs'].append('output_sort_order')
         output_node.in_edge()['output_sort_order'] = [('detection_boxes', 0)]
 
-        # creates the Crop operation that gets input from the DetectionOutput layer, cuts of slices of data with batch
-        # indices and class labels producing a tensor with classes probabilities and bounding boxes only as it is
-        # expected by the ROIPooling layer
-        crop_op = Crop(graph, dict(axis=np.array([3]), offset=np.array([2]), dim=np.array([5]), nchw_layout=True))
-        crop_node = crop_op.create_node([detection_output_node], dict(name='crop_do'))
+        # creates two Crop operations which get input from the DetectionOutput layer, cut off slices of data with class
+        # ids and probabilities and produce a tensor with batch ids and bounding boxes only (as it is expected by the
+        # ROIPooling layer)
+        crop_batch_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([0]), dim=int64_array([1]),
+                                         nchw_layout=True))
+        crop_batch_node = crop_batch_op.create_node([detection_output_node], dict(name='crop_do_batch_ids'))
+
+        crop_coordinates_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([3]), dim=int64_array([4]),
+                                               nchw_layout=True))
+        crop_coordinates_node = crop_coordinates_op.create_node([detection_output_node], dict(name='crop_do_coords'))
+
+        concat_op = Concat(graph, dict(axis=3))
+        concat_node = concat_op.create_node([crop_batch_node, crop_coordinates_node], dict(name='batch_and_coords',
+                                                                                           nchw_layout=True))
 
         # reshape bounding boxes as required by ROIPooling
-        reshape_do_op = Reshape(graph, dict(dim=np.array([-1, 5])))
-        reshape_do_node = reshape_do_op.create_node([crop_node], dict(name='reshape_do'))
+        reshape_do_op = Reshape(graph, dict(dim=int64_array([-1, 5])))
+        reshape_do_node = reshape_do_op.create_node([concat_node], dict(name='reshape_do'))
 
         roi_pooling_op = ROIPooling(graph, dict(method="bilinear", spatial_scale=1,
                                                 pooled_h=roi_pool_size, pooled_w=roi_pool_size))
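A numpy sketch of what the two Crop nodes and the Concat above do to the DetectionOutput tensor, assuming the usual 7-element rows [image_id, class_id, confidence, x_min, y_min, x_max, y_max]; the detections below are made up:

import numpy as np

# hypothetical DetectionOutput result of shape [1, 1, num_detections, 7]
do = np.array([[[[0, 15, 0.9, 0.1, 0.2, 0.5, 0.6],
                 [0,  3, 0.8, 0.3, 0.3, 0.7, 0.9]]]], dtype=np.float32)

crop_batch = do[..., 0:1]    # Crop axis=3, offset=0, dim=1: image ids
crop_coords = do[..., 3:7]   # Crop axis=3, offset=3, dim=4: box coordinates
batch_and_coords = np.concatenate([crop_batch, crop_coords], axis=3)

rois = batch_and_coords.reshape([-1, 5])  # the [batch_id, x1, y1, x2, y2] rows ROIPooling consumes
print(rois.shape)  # (2, 5)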
@@ -675,7 +702,7 @@ class ObjectDetectionAPIMaskRCNNSigmoidReplacement(FrontReplacementFromConfigFil
     def run_after(self):
         return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]
 
-    def transform_graph(self, graph: nx.MultiDiGraph, replacement_descriptions):
+    def transform_graph(self, graph: Graph, replacement_descriptions):
         output_node = None
         op_outputs = [n for n, d in graph.nodes(data=True) if 'op' in d and d['op'] == 'OpOutput']
         for op_output in op_outputs:
@@ -711,24 +738,22 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
     def run_before(self):
         return [Sub, CropAndResizeReplacement]
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         return {match.output_node(0)[0].id: new_sub_graph['proposal_node'].id}
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         new_list = match.matched_nodes_names().copy()
         # do not remove nodes that produce box predictions and class predictions
         new_list.remove(match.single_input_node(0)[0].id)
         new_list.remove(match.single_input_node(1)[0].id)
         return new_list
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         if argv.tensorflow_object_detection_api_pipeline_config is None:
             raise Error(missing_param_error)
         pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
 
-        input_height = graph.graph['preprocessed_image_height']
-        input_width = graph.graph['preprocessed_image_width']
         max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
         proposal_ratios = _value_or_raise(match, pipeline_config, 'anchor_generator_aspect_ratios')
         proposal_scales = _value_or_raise(match, pipeline_config, 'anchor_generator_scales')
@@ -737,39 +762,24 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
         # Convolution/matmul node that produces class predictions
         # Permute the result of the tensor with class predictions so it is in the correct layout for Softmax
         predictions_node = backward_bfs_for_operation(match.single_input_node(1)[0], ['Add'])[0]
-        permute_predictions_op = Permute(graph, dict(order=np.array([0, 2, 3, 1])))
-        permute_predictions_node = permute_predictions_op.create_node([], dict(name=predictions_node.name + '/Permute'))
-        insert_node_after(predictions_node, permute_predictions_node, 0)
-
-        # creates constant input with the image height, width and scale H and scale W (if present) required for Proposal
-        const_op = Const(graph, dict(value=np.array([[input_height, input_width, 1]], dtype=np.float32)))
-        const_node = const_op.create_node([], dict(name='proposal_const_image_size'))
-
-        reshape_classes_op = Reshape(graph, dict(dim=np.array([0, -1, 2])))
-        reshape_classes_node = reshape_classes_op.create_node([permute_predictions_node],
-                                                              dict(name='reshape_FirstStageBoxPredictor_class',
-                                                                   nchw_layout=True))
 
-        softmax_conf_op = Softmax(graph, dict(axis=2))
-        softmax_conf_node = softmax_conf_op.create_node([reshape_classes_node],
-                                                        dict(name='FirstStageBoxPredictor_softMax_class'))
-        PermuteAttrs.set_permutation(reshape_classes_node, softmax_conf_node, None)
+        reshape_classes_op = Reshape(graph, dict(dim=int64_array([0, anchors_count, 2, -1])))
+        reshape_classes_node = reshape_classes_op.create_node([], dict(name='predictions/Reshape', nchw_layout=True))
+        predictions_node.insert_node_after(reshape_classes_node, 0)
 
-        reshape_softmax_op = Reshape(graph, dict(dim=np.array([1, anchors_count, 2, -1])))
-        reshape_softmax_node = reshape_softmax_op.create_node([softmax_conf_node], dict(name='reshape_softmax_class'))
-        PermuteAttrs.set_permutation(softmax_conf_node, reshape_softmax_node, None)
+        softmax_conf_op = Softmax(graph, dict(axis=2, nchw_layout=True, name=reshape_classes_node.id + '/Softmax'))
+        softmax_conf_node = softmax_conf_op.create_node([reshape_classes_node])
+        permute_reshape_softmax_op = Permute(graph, dict(order=int64_array([0, 2, 1, 3]), nchw_layout=True))
+        permute_reshape_softmax_node = permute_reshape_softmax_op.create_node([softmax_conf_node], dict(
+            name=softmax_conf_node.name + '/Permute'))
 
-        permute_reshape_softmax_op = Permute(graph, dict(order=np.array([0, 1, 3, 2])))
-        permute_reshape_softmax_node = permute_reshape_softmax_op.create_node([reshape_softmax_node], dict(
-            name=reshape_softmax_node.name + '/Permute'))
+        initial_shape_op = Shape(graph, dict(name=predictions_node.id + '/Shape'))
+        initial_shape_node = initial_shape_op.create_node([predictions_node])
 
         # implement custom reshape infer function because we need to know the input convolution node output dimension
         # sizes but we can know it only after partial infer
-        reshape_permute_op = Reshape(graph,
-                                     dict(dim=np.ones([4]), anchors_count=anchors_count, conv_node=predictions_node))
-        reshape_permute_op.attrs['old_infer'] = reshape_permute_op.attrs['infer']
-        reshape_permute_op.attrs['infer'] = __class__.classes_probabilities_reshape_shape_infer
-        reshape_permute_node = reshape_permute_op.create_node([permute_reshape_softmax_node],
+        reshape_permute_op = Reshape(graph, dict())
+        reshape_permute_node = reshape_permute_op.create_node([permute_reshape_softmax_node, initial_shape_node],
                                                               dict(name='Reshape_Permute_Class'))
 
         variance_height = pipeline_config.get_param('frcnn_variance_height')
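A numpy sketch of the score reshuffling done by the Reshape, Softmax and Permute nodes created above, with toy sizes and NCHW layout assumed; the 0 in the Reshape dims keeps the batch dimension and the -1 absorbs the spatial positions:

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, anchors_count, h, w = 1, 3, 2, 2
conv_out = np.random.rand(batch, 2 * anchors_count, h, w).astype(np.float32)  # NCHW class predictions

scores = conv_out.reshape(batch, anchors_count, 2, -1)  # Reshape(dim=[0, anchors_count, 2, -1])
probs = softmax(scores, axis=2)                         # Softmax(axis=2) over each bg/fg pair
probs = probs.transpose(0, 2, 1, 3)                     # Permute(order=[0, 2, 1, 3])
probs = probs.reshape(conv_out.shape)                   # final Reshape driven by the Shape node
print(probs.shape)  # (1, 6, 2, 2)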
@@ -805,46 +815,61 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
                                              feat_stride=anchor_generator_height_stride,
                                              ratio=proposal_ratios,
                                              scale=proposal_scales,
+                                             normalize=1,
                                              base_size=anchor_generator_height,
                                              nms_thresh=_value_or_raise(match, pipeline_config,
                                                                         'first_stage_nms_iou_threshold')))
+        for key in ('clip_before_nms', 'clip_after_nms'):
+            if key in match.custom_replacement_desc.custom_attributes:
+                proposal_op.attrs[key] = int(match.custom_replacement_desc.custom_attributes[key])
 
         anchors_node = backward_bfs_for_operation(match.single_input_node(0)[0], ['Add'])[0]
-        proposal_node = proposal_op.create_node([reshape_permute_node, anchors_node, const_node],
-                                                dict(name='proposals'))
 
-        # the TF implementation of ROIPooling with bi-linear filtration need proposals scaled by image size
-        proposal_scale_const = np.array([1.0, 1 / input_height, 1 / input_width, 1 / input_height, 1 / input_width],
-                                        dtype=np.float32)
-        proposal_scale_const_op = Const(graph, dict(value=proposal_scale_const))
-        proposal_scale_const_node = proposal_scale_const_op.create_node([], dict(name='Proposal_scale_const'))
+        # creates an input to store the input image height, width and scales (usually 1.0s)
+        # the batch size for this input is fixed because only images of the same size can be passed as input
+        input_op_with_image_size = Input(graph, dict(shape=int64_array([1, 3]), fixed_batch=True))
+        input_with_image_size_node = input_op_with_image_size.create_node([], dict(name='image_info'))
 
-        scale_proposals_op = Eltwise(graph, dict(operation='mul'))
-        scale_proposals_node = scale_proposals_op.create_node([proposal_node, proposal_scale_const_node],
-                                                              dict(name='scaled_proposals'))
+        proposal_node = proposal_op.create_node([reshape_permute_node, anchors_node, input_with_image_size_node],
+                                                dict(name='proposals'))
 
-        proposal_reshape_4d_op = Reshape(graph, dict(dim=np.array([1, 1, max_proposals, 5]), nchw_layout=True))
-        proposal_reshape_4d_node = proposal_reshape_4d_op.create_node([scale_proposals_node],
-                                                                      dict(name="reshape_proposals_4d"))
+        if 'do_not_swap_proposals' in match.custom_replacement_desc.custom_attributes and \
+                match.custom_replacement_desc.custom_attributes['do_not_swap_proposals']:
+            swapped_proposals_node = proposal_node
+        else:
+            swapped_proposals_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)
 
-        # creates the Crop operation that gets input from the Proposal layer and gets tensor with bounding boxes only
-        crop_op = Crop(graph, dict(axis=np.array([3]), offset=np.array([1]), dim=np.array([4]), nchw_layout=True))
-        crop_node = crop_op.create_node([proposal_reshape_4d_node], dict(name='crop_proposals'))
+        proposal_reshape_2d_op = Reshape(graph, dict(dim=int64_array([-1, 5]), nchw_layout=True))
+        proposal_reshape_2d_node = proposal_reshape_2d_op.create_node([swapped_proposals_node],
+                                                                      dict(name="reshape_swap_proposals_2d"))
 
-        proposal_reshape_3d_op = Reshape(graph, dict(dim=np.array([0, -1, 4]), nchw_layout=True))
-        proposal_reshape_3d_node = proposal_reshape_3d_op.create_node([crop_node], dict(name="tf_proposals"))
+        # feed the CropAndResize node with correct box information produced by the Proposal layer
+        # find the first CropAndResize node in the BFS order
+        crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
+                                     graph.node[node_id]['op'] == 'CropAndResize']
+        assert len(crop_and_resize_nodes_ids) != 0, "Didn't find any CropAndResize nodes in the graph."
+        if 'do_not_swap_proposals' not in match.custom_replacement_desc.custom_attributes or not \
+                match.custom_replacement_desc.custom_attributes['do_not_swap_proposals']:
+            crop_and_resize_node = Node(graph, crop_and_resize_nodes_ids[0])
+            # set a marker that the input with box coordinates has been pre-processed so the CropAndResizeReplacement
+            # transform doesn't try to merge the second and the third inputs
+            crop_and_resize_node['inputs_preprocessed'] = True
+            graph.remove_edge(crop_and_resize_node.in_node(1).id, crop_and_resize_node.id)
+            graph.create_edge(proposal_reshape_2d_node, crop_and_resize_node, out_port=0, in_port=1)
 
-        return {'proposal_node': proposal_reshape_3d_node}
+        tf_proposal_reshape_4d_op = Reshape(graph, dict(dim=int64_array([-1, 1, max_proposals, 5]), nchw_layout=True))
+        tf_proposal_reshape_4d_node = tf_proposal_reshape_4d_op.create_node([swapped_proposals_node],
+                                                                            dict(name="reshape_proposal_4d"))
 
-    @staticmethod
-    def classes_probabilities_reshape_shape_infer(node: Node):
-        # now we can determine the reshape dimensions from Convolution node
-        conv_node = node.conv_node
-        conv_output_shape = conv_node.out_node().shape
+        crop_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([1]), dim=int64_array([4]),
+                                   nchw_layout=True))
+        crop_node = crop_op.create_node([tf_proposal_reshape_4d_node], dict(name='crop_proposals'))
 
-        # update desired shape of the Reshape node
-        node.dim = np.array([0, conv_output_shape[1], conv_output_shape[2], node.anchors_count * 2])
-        node.old_infer(node)
+        tf_proposals_crop_reshape_3d_op = Reshape(graph, dict(dim=int64_array([0, -1, 4]), nchw_layout=True))
+        tf_proposals_crop_reshape_3d_node = tf_proposals_crop_reshape_3d_op.create_node([crop_node],
+                                                                                        dict(name="reshape_crop_3d"))
+
+        return {'proposal_node': tf_proposals_crop_reshape_3d_node}
 
 
 class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
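The 'image_info' Input created in the hunk above replaces the former constant with the image size: the converted IR now expects an extra [1, 3] input carrying the preprocessed image height, width and a scale factor (usually 1.0). A sketch of the value one would feed, with hypothetical dimensions:

import numpy as np

height, width = 600, 1024  # hypothetical preprocessed input resolution of the converted model
image_info = np.array([[height, width, 1.0]], dtype=np.float32)  # shape [1, 3], batch fixed to 1
print(image_info.shape)  # (1, 3)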
@@ -859,12 +884,12 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
         # nodes
         return [Div, StandaloneConstEraser]
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         # the DetectionOutput in IE produces a single tensor, but in TF it produces two tensors, so create only one output
         # edge match
         return {match.output_node(0)[0].id: new_sub_graph['detection_output_node'].id}
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         if argv.tensorflow_object_detection_api_pipeline_config is None:
             raise Error(missing_param_error)
@@ -872,7 +897,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
         num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
 
         # reshapes confidences to 4D before applying activation function
-        expand_dims_op = Reshape(graph, {'dim': np.array([0, 1, -1, num_classes + 1])})
+        expand_dims_op = Reshape(graph, {'dim': int64_array([0, 1, -1, num_classes + 1])})
         # do not convert from NHWC to NCHW this node shape
         expand_dims_node = expand_dims_op.create_node([match.input_nodes(1)[0][0].in_node(0)],
                                                       dict(name='do_ExpandDims_conf'))
@@ -883,13 +908,13 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
 
         # IE DetectionOutput layer consumes flattened tensors
         # reshape operation to flatten locations tensor
-        reshape_loc_op = Reshape(graph, {'dim': np.array([0, -1])})
+        reshape_loc_op = Reshape(graph, {'dim': int64_array([0, -1])})
         reshape_loc_node = reshape_loc_op.create_node([match.input_nodes(0)[0][0].in_node(0)],
                                                       dict(name='do_reshape_loc'))
 
         # IE DetectionOutput layer consumes flattened tensors
         # reshape operation to flatten confidence tensor
-        reshape_conf_op = Reshape(graph, {'dim': np.array([0, -1])})
+        reshape_conf_op = Reshape(graph, {'dim': int64_array([0, -1])})
         reshape_conf_node = reshape_conf_op.create_node([activation_conf_node], dict(name='do_reshape_conf'))
 
         if pipeline_config.get_param('ssd_anchor_generator_num_layers') is not None or \
@@ -933,7 +958,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
             variance = _variance_from_pipeline_config(pipeline_config)
             # replicating the variance values for all prior-boxes
             variances = np.tile(variance, [prior_boxes.shape[-2], 1])
-            # DetectionOutput in the Inference Engine expects the prior-boxes in the following layout: (values, variances)
+            # the Inference Engine DetectionOutput expects the prior-boxes in the following layout: (values, variances)
             prior_boxes = prior_boxes.reshape([-1, 4])
             prior_boxes = np.concatenate((prior_boxes, variances), 0)
             # compared to the IE's DetectionOutput, the TF keeps the prior-boxes in YXYX, need to get back to the XYXY
@@ -941,7 +966,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
                                           prior_boxes[:, 3:4], prior_boxes[:, 2:3]), 1)
             #  add another dimension, as the prior-boxes are expected as 3d tensors
             prior_boxes = prior_boxes.reshape((1, 2, -1))
-            node.in_node(2).shape = np.array(prior_boxes.shape, dtype=np.int64)
+            node.in_node(2).shape = int64_array(prior_boxes.shape)
             node.in_node(2).value = prior_boxes
 
         node.old_infer(node)
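A numpy sketch of the prior-box repacking performed in do_infer above: the variances are tiled per box, appended as a second 'channel', the columns are swapped from the TF YXYX order to the XYXY order the Inference Engine expects, and the result becomes a 3D tensor; the numbers are toy values:

import numpy as np

prior_boxes = np.array([[[0.1, 0.2, 0.3, 0.4],
                         [0.5, 0.6, 0.7, 0.8]]], dtype=np.float32)  # toy [1, N, 4] boxes in YXYX
variance = np.array([[0.1, 0.1, 0.2, 0.2]], dtype=np.float32)       # example variance values

variances = np.tile(variance, [prior_boxes.shape[-2], 1])           # one variance row per box
prior_boxes = prior_boxes.reshape([-1, 4])
prior_boxes = np.concatenate((prior_boxes, variances), 0)           # (values, variances) layout
# YXYX -> XYXY column swap; it also touches the variance rows, which is harmless here because
# the y and x variances come in equal pairs
prior_boxes = np.concatenate((prior_boxes[:, 1:2], prior_boxes[:, 0:1],
                              prior_boxes[:, 3:4], prior_boxes[:, 2:3]), 1)
prior_boxes = prior_boxes.reshape((1, 2, -1))
print(prior_boxes.shape)  # (1, 2, 8)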
@@ -977,7 +1002,7 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
     def run_before(self):
         return [ObjectDetectionAPIPreprocessorReplacement]
 
-    def transform_graph(self, graph: nx.MultiDiGraph, replacement_descriptions: dict):
+    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
         if graph.graph['cmd_params'].output is not None:
             log.warning('User defined output nodes are specified. Skip the graph cut-off by the '
                         'ObjectDetectionAPIOutputReplacement.')
@@ -993,3 +1018,97 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
                     log.debug('A node "{}" does not exist in the graph. Do not add it as output'.format(out_node_name))
         _outputs = output_user_data_repack(graph, outputs)
         add_output_ops(graph, _outputs, graph.graph['inputs'])
+
+
+class ObjectDetectionAPIPSROIPoolingReplacement(FrontReplacementFromConfigFileSubGraph):
+    replacement_id = 'ObjectDetectionAPIPSROIPoolingReplacement'
+
+    def run_after(self):
+        return [ObjectDetectionAPIProposalReplacement]
+
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
+        return {match.output_node(0)[0].id: new_sub_graph['output_node'].id}
+
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
+        argv = graph.graph['cmd_params']
+        if argv.tensorflow_object_detection_api_pipeline_config is None:
+            raise Error(missing_param_error)
+        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
+        num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
+
+        input_node = match.input_nodes(0)[0][0].in_node(0)
+        if 'class_predictions' in input_node.id:
+            psroipooling_output_dim = num_classes + 1
+        else:
+            psroipooling_output_dim = num_classes * 4
+
+        num_spatial_bins_height = pipeline_config.get_param('num_spatial_bins_height')
+        num_spatial_bins_width = pipeline_config.get_param('num_spatial_bins_width')
+        crop_height = pipeline_config.get_param('crop_height')
+        crop_width = pipeline_config.get_param('crop_width')
+        if crop_height != crop_width:
+            raise Error('Different "crop_height" and "crop_width" parameters from the pipeline config are not '
+                        'supported: {} vs {}'.format(crop_height, crop_width))
+        psroipooling_op = PSROIPoolingOp(graph, {'name': input_node.soft_get('name') + '/PSROIPooling',
+                                                 'output_dim': psroipooling_output_dim,
+                                                 'group_size': crop_width / num_spatial_bins_width,
+                                                 'spatial_bins_x': num_spatial_bins_width,
+                                                 'spatial_bins_y': num_spatial_bins_height,
+                                                 'mode': 'bilinear',
+                                                 'spatial_scale': 1,
+                                                 })
+
+        if 'reshape_swap_proposals_2d' in graph.nodes():
+            reshape_swap_proposals_node = Node(graph, 'reshape_swap_proposals_2d')
+        else:
+            swap_proposals_node = add_convolution_to_swap_xy_coordinates(graph, Node(graph, 'proposals'), 5)
+            reshape_swap_proposals_node = Reshape(graph, {'dim': [-1, 5], 'nchw_layout': True,
+                                                          'name': 'reshape_swap_proposals_2d'}).create_node(
+                [swap_proposals_node])
+        psroipooling_node = psroipooling_op.create_node([input_node, reshape_swap_proposals_node])
+
+        reduce_op = Reduce(graph, {'name': 'mean',
+                                   'reduce_type': 'mean',
+                                   'axis': int64_array([1, 2]),
+                                   'keep_dims': True
+                                   })
+        reduce_node = reduce_op.create_node([psroipooling_node])
+
+        graph.erase_node(match.output_node(0)[0].out_node())
+
+        return {'output_node': reduce_node}
+
+
+class ObjectDetectionAPIConstValueOverride(FrontReplacementFromConfigFileGeneral):
+    """
+    The transform allows overriding specific constant values in the topology. The replacement description configuration
+    file contains a list of tuples describing the desired replacements in the "replacements" key of the
+    "custom_attributes". The first element of each tuple is the name of the graph node with the constant value. The
+    second element is the name of the pipeline configuration file parameter which stores the new value.
+
+    Usage example: the Faster R-CNN topologies have a constant node with the number specifying the maximum number of
+    generated proposals. This value is specified in the pipeline configuration file in the 'first_stage_max_proposals'
+    parameter and is saved as a constant node in the generated topology. If the parameter is modified from its original
+    value, the generated topology becomes incorrect because the 'first_stage_max_proposals' value used in the
+    transforms of this file is no longer equal to the 'first_stage_max_proposals' value saved as a constant.
+    """
+    replacement_id = 'ObjectDetectionAPIConstValueOverride'
+
+    def run_before(self):
+        return [ObjectDetectionAPIPreprocessorReplacement]
+
+    def transform_graph(self, graph: Graph, replacement_descriptions: dict):
+        argv = graph.graph['cmd_params']
+        if argv.tensorflow_object_detection_api_pipeline_config is None:
+            raise Error(missing_param_error)
+        pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
+        for (node_id, pipeline_config_name) in replacement_descriptions['replacements']:
+            if node_id not in graph.nodes():
+                log.debug('Node with id {} does not exist in the graph'.format(node_id))
+                continue
+            node = Node(graph, node_id)
+            if not node.has_valid('value'):
+                log.debug('Node with id {} does not have value'.format(node_id))
+                continue
+            node.value = np.array(pipeline_config.get_param(pipeline_config_name))
+            node.value = node.value.reshape(node.shape)
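A hypothetical replacement description for the ObjectDetectionAPIConstValueOverride transform; the node id and the parameter name below are illustrative only, the real ones come from the *.json replacement configuration file shipped for a particular topology:

# Hypothetical 'replacements' list as this transform would receive it.
replacement_descriptions = {
    'replacements': [
        # (graph node id holding the constant, pipeline config parameter providing the new value)
        ('SomeScope/max_proposals_const', 'first_stage_max_proposals'),
    ]
}

# The loop above then rewrites each matched constant in place, preserving its original shape:
#     node.value = np.array(pipeline_config.get_param(pipeline_config_name)).reshape(node.shape)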