Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / CropAndResizeReplacement.py
index d02f109..15c1103 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
+import logging as log
+
 import numpy as np
 
-from mo.front.tf.graph_utils import add_convolution_to_swap_xy_coordinates
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, create_edge
+from mo.front.tf.graph_utils import add_convolution_to_swap_xy_coordinates
+from mo.graph.graph import Node, Graph
 from mo.ops.concat import Concat
 from mo.ops.reshape import Reshape
 from mo.ops.unsqueeze import Unsqueeze
@@ -34,16 +35,19 @@ class CropAndResizeReplacement(FrontReplacementOp):
     op = "CropAndResize"
     enabled = True
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+    def nodes_to_remove(self, graph: Graph, match: dict):
         # do not remove matched node
         return []
 
-    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+    def replace_op(self, graph: Graph, node: Node):
+        if node.has_and_set('inputs_preprocessed'):
+            log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
+            return []
         # reshape tensor with batch indices to 2d
         unsqueeze_op = Unsqueeze(graph, {'unsqueeze_dims': np.array([1], dtype=np.int64)})
         unsqueeze_node = unsqueeze_op.create_node([node.in_node(2)])
 
-        concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes'})
+        concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes', 'in_ports_count': 2})
         concat_node = concat_op.create_node([unsqueeze_node, node.in_node(1)])
 
         # do not remove edge with crop_size because it is needed in the partial infer
@@ -55,9 +59,11 @@ class CropAndResizeReplacement(FrontReplacementOp):
 
         # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
         reshape_2d_op = Reshape(graph, dict(dim=np.array([-1, 5])))
-        reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node], dict(name='reshape_2d_'))
-        create_edge(reshape_2d_node, node, 0, 1)
+
+        reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node],
+                                                    dict(name=swapped_box_coordinates_node.id + '/reshape_2d_',
+                                                         nchw_layout=True))
+        graph.create_edge(reshape_2d_node, node, 0, 1)
 
         # do not replace any output edge
         return []
-