"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import add_convolution_to_swap_xy_coordinates
from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.reshape import Reshape
from mo.ops.unsqueeze import Unsqueeze
class CropAndResizeReplacement(FrontReplacementOp):
    """
    The CropAndResize operation from TF gets separate inputs with box coordinates and image batch indices, but the
    ROIPooling operation in the Inference Engine receives them as a single concatenated input. This replacer
    concatenates the two inputs into a new one and feeds it back into the CropAndResize node.
    """
    # type of the TF operation this replacer matches
    op = "CropAndResize"
    # NOTE(review): assumed to mean "replacer is active by default" — confirm against FrontReplacementOp
    enabled = True

    def nodes_to_remove(self, graph: Graph, match: dict):
        # do not remove matched node: the CropAndResize node stays in the graph, only its inputs are rewired
        return []

    def replace_op(self, graph: Graph, node: Node):
        """
        Merge the batch-indices and boxes inputs of ``node`` into one 2D tensor connected to its input port 1.

        :param graph: graph being transformed.
        :param node: matched CropAndResize node; input 1 holds the box coordinates, input 2 the batch indices.
        :return: empty list, so no output edges of ``node`` are replaced.
        """
        if node.has_and_set('inputs_preprocessed'):
            log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
            # the inputs are already merged; running again would stack a second concat/swap chain on top
            return []

        # each resulting row is: batch index followed by the 4 box coordinates
        values_per_box = 5

        # reshape tensor with batch indices to 2d (new axis inserted at position 1) so it can be concatenated
        # column-wise with the boxes
        unsqueeze_op = Unsqueeze(graph, {'unsqueeze_dims': np.array([1], dtype=np.int64)})
        unsqueeze_node = unsqueeze_op.create_node([node.in_node(2)])

        concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes',
                                   'in_ports_count': 2})
        concat_node = concat_op.create_node([unsqueeze_node, node.in_node(1)])

        # detach only the original boxes input; the edge with crop_size is not removed because it is needed in
        # the partial infer
        graph.remove_edge(node.in_node(1).id, node.id)

        # input to the CropAndResize contains boxes coordinates in YXYX layout. But IE layer ROIPooling expects
        # coordinates in the XYXY layout, so convolution is added here to swap coordinates
        swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, values_per_box)

        # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
        reshape_2d_op = Reshape(graph, dict(dim=np.array([-1, values_per_box])))
        # NOTE(review): the attribute dict of this call was truncated; `nchw_layout=True` restored — confirm
        # against the upstream version of this replacer
        reshape_2d_node = reshape_2d_op.create_node(
            [swapped_box_coordinates_node],
            dict(name=swapped_box_coordinates_node.id + '/reshape_2d_', nchw_layout=True))
        # connect the merged tensor to the port the original boxes input was detached from
        graph.create_edge(reshape_2d_node, node, 0, 1)

        # do not replace any output edge
        return []