Added support for a new version of the TF OD API pre-processing part (#3063)
authorEvgeny Lazarev <evgeny.lazarev@intel.com>
Wed, 11 Nov 2020 08:53:10 +0000 (11:53 +0300)
committerGitHub <noreply@github.com>
Wed, 11 Nov 2020 08:53:10 +0000 (11:53 +0300)
* Added support for a new version of the TF OD API pre-processing part of the mode

* Get rid of legacy API usage

* Fix comment and added assert

* Wording

model-optimizer/extensions/front/tf/ObjectDetectionAPI.py

index f90a754..125aa8f 100644 (file)
@@ -486,6 +486,23 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
                 new_nodes_to_remove.remove(node_to_keep)
         return new_nodes_to_remove
 
+    def is_preprocessing_applied_before_resize(self, to_float: Node, mul: Node, sub: Node):
+        """
+        The function checks if the output of 'to_float' operation is consumed by 'mul' or 'sub'. If this is true then
+        the pre-processing (mean/scale) is applied before the image resize. The image resize was applied first in the
+        original version of the TF OD API models, but in the recent versions it is applied after.
+
+        :param to_float: the Cast node which converts the input tensor to Float
+        :param mul: the Mul node (can be None)
+        :param sub: the Sub node
+        :return: the result of the check
+        """
+        assert sub is not None, 'The Sub node should not be None. Check the caller function.'
+        if mul is not None:
+            return any([port.node.id == mul.id for port in to_float.out_port(0).get_destinations()])
+        else:
+            return any([port.node.id == sub.id for port in to_float.out_port(0).get_destinations()])
+
     def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         argv = graph.graph['cmd_params']
         layout = graph.graph['layout']
@@ -494,14 +511,14 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
         pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
 
         sub_node = match.output_node(0)[0]
-        if not sub_node.has('op') or sub_node.op != 'Sub':
+        if sub_node.soft_get('op') != 'Sub':
             raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
                         'not created with TensorFlow Object Detection API.')
 
         mul_node = None
-        if sub_node.in_node(0).has('op') and sub_node.in_node(0).op == 'Mul':
+        if sub_node.in_port(0).get_source().node.soft_get('op') == 'Mul':
             log.info('There is image scaling node in the Preprocessor block.')
-            mul_node = sub_node.in_node(0)
+            mul_node = sub_node.in_port(0).get_source().node
 
         initial_input_node_name = 'image_tensor'
         if initial_input_node_name not in graph.nodes():
@@ -521,16 +538,25 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
         graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
         graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]
 
-        to_float_node = placeholder_node.out_node(0)
-        if not to_float_node.has('op') or to_float_node.op != 'Cast':
-            raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
+        to_float_node = placeholder_node.out_port(0).get_destination().node
+        if to_float_node.soft_get('op') != 'Cast':
+            raise Error('The output of the node "{}" is not Cast operation. Cannot apply transformation.'.format(
                 initial_input_node_name))
 
-        # connect to_float_node directly with node performing scale on mean value subtraction
-        if mul_node is None:
-            graph.create_edge(to_float_node, sub_node, 0, 0)
+        if self.is_preprocessing_applied_before_resize(to_float_node, mul_node, sub_node):
+            # connect sub node directly to nodes which consume resized image
+            resize_output_node_id = 'Preprocessor/map/TensorArrayStack/TensorArrayGatherV3'
+            if resize_output_node_id not in graph.nodes:
+                raise Error('There is no expected node "{}" in the graph.'.format(resize_output_node_id))
+            resize_output = Node(graph, resize_output_node_id)
+            for dst_port in resize_output.out_port(0).get_destinations():
+                dst_port.get_connection().set_source(sub_node.out_port(0))
         else:
-            graph.create_edge(to_float_node, mul_node, 0, 1)
+            # connect to_float_node directly with node performing scale on mean value subtraction
+            if mul_node is None:
+                to_float_node.out_port(0).connect(sub_node.in_port(0))
+            else:
+                to_float_node.out_port(0).connect(mul_node.in_port(1))
 
         print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
               ' applicable) are kept.')