new_nodes_to_remove.remove(node_to_keep)
return new_nodes_to_remove
+ def is_preprocessing_applied_before_resize(self, to_float: Node, mul: Node, sub: Node):
+ """
+ The function checks if the output of 'to_float' operation is consumed by 'mul' or 'sub'. If this is true then
+ the pre-processing (mean/scale) is applied before the image resize. The image resize was applied first in the
+ original version of the TF OD API models, but in the recent versions it is applied after.
+
+ :param to_float: the Cast node which converts the input tensor to Float
+ :param mul: the Mul node (can be None)
+ :param sub: the Sub node
+ :return: the result of the check
+ """
+ assert sub is not None, 'The Sub node should not be None. Check the caller function.'
+ if mul is not None:
+ return any([port.node.id == mul.id for port in to_float.out_port(0).get_destinations()])
+ else:
+ return any([port.node.id == sub.id for port in to_float.out_port(0).get_destinations()])
+
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
argv = graph.graph['cmd_params']
layout = graph.graph['layout']
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
sub_node = match.output_node(0)[0]
- if not sub_node.has('op') or sub_node.op != 'Sub':
+ if sub_node.soft_get('op') != 'Sub':
raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
'not created with TensorFlow Object Detection API.')
mul_node = None
- if sub_node.in_node(0).has('op') and sub_node.in_node(0).op == 'Mul':
+ if sub_node.in_port(0).get_source().node.soft_get('op') == 'Mul':
log.info('There is image scaling node in the Preprocessor block.')
- mul_node = sub_node.in_node(0)
+ mul_node = sub_node.in_port(0).get_source().node
initial_input_node_name = 'image_tensor'
if initial_input_node_name not in graph.nodes():
graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]
- to_float_node = placeholder_node.out_node(0)
- if not to_float_node.has('op') or to_float_node.op != 'Cast':
- raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
+ to_float_node = placeholder_node.out_port(0).get_destination().node
+ if to_float_node.soft_get('op') != 'Cast':
+ raise Error('The output of the node "{}" is not Cast operation. Cannot apply transformation.'.format(
initial_input_node_name))
- # connect to_float_node directly with node performing scale on mean value subtraction
- if mul_node is None:
- graph.create_edge(to_float_node, sub_node, 0, 0)
+ if self.is_preprocessing_applied_before_resize(to_float_node, mul_node, sub_node):
+ # connect sub node directly to nodes which consume resized image
+ resize_output_node_id = 'Preprocessor/map/TensorArrayStack/TensorArrayGatherV3'
+ if resize_output_node_id not in graph.nodes:
+ raise Error('There is no expected node "{}" in the graph.'.format(resize_output_node_id))
+ resize_output = Node(graph, resize_output_node_id)
+ for dst_port in resize_output.out_port(0).get_destinations():
+ dst_port.get_connection().set_source(sub_node.out_port(0))
else:
- graph.create_edge(to_float_node, mul_node, 0, 1)
+ # connect to_float_node directly with node performing scale on mean value subtraction
+ if mul_node is None:
+ to_float_node.out_port(0).connect(sub_node.in_port(0))
+ else:
+ to_float_node.out_port(0).connect(mul_node.in_port(1))
print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
' applicable) are kept.')