"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
+import logging as log
+
import numpy as np
-from mo.front.tf.graph_utils import add_convolution_to_swap_xy_coordinates
from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, create_edge
+from mo.front.tf.graph_utils import add_convolution_to_swap_xy_coordinates
+from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.reshape import Reshape
from mo.ops.unsqueeze import Unsqueeze
op = "CropAndResize"
enabled = True
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: dict):
+ def nodes_to_remove(self, graph: Graph, match: dict):
# do not remove matched node
return []
- def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+ def replace_op(self, graph: Graph, node: Node):
+ if node.has_and_set('inputs_preprocessed'):
+ log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
+ return []
# reshape tensor with batch indices to 2d
unsqueeze_op = Unsqueeze(graph, {'unsqueeze_dims': np.array([1], dtype=np.int64)})
unsqueeze_node = unsqueeze_op.create_node([node.in_node(2)])
- concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes'})
+ concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes', 'in_ports_count': 2})
concat_node = concat_op.create_node([unsqueeze_node, node.in_node(1)])
# do not remove edge with crop_size because it is needed in the partial infer
# reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
reshape_2d_op = Reshape(graph, dict(dim=np.array([-1, 5])))
- reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node], dict(name='reshape_2d_'))
- create_edge(reshape_2d_node, node, 0, 1)
+
+ reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node],
+ dict(name=swapped_box_coordinates_node.id + '/reshape_2d_',
+ nchw_layout=True))
+ graph.create_edge(reshape_2d_node, node, 0, 1)
# do not replace any output edge
return []
-