Publishing 2019 R3.1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / fifo_replacer.py
index 6eebe8c..8374a58 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import logging as log
 
-import networkx as nx
 import numpy as np
 
+from extensions.ops.parameter import Parameter
 from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import create_edge, erase_node
-from mo.ops.input import Input
+from mo.graph.graph import Graph, Node
 
 
 class FIFOQueue(FrontReplacementSubgraph):
     enabled = True
 
+    def run_before(self):
+        from extensions.front.override_batch import OverrideBatch
+        return [OverrideBatch]
+
     @staticmethod
     def pattern(**kwargs):
         return dict(
             nodes=[
-                ('placeholder', dict(op='Placeholder', data_type=np.int32)),
+                ('placeholder', dict(op='Parameter', data_type=np.int32)),
                 ('fifo_queue', dict(op='FIFOQueueV2')),
                 ('batch_join', dict(op='QueueDequeueUpToV2')),
                 ('image_batch', dict(op='Identity', data_type=np.float32))
@@ -43,12 +46,12 @@ class FIFOQueue(FrontReplacementSubgraph):
         )
 
     @staticmethod
-    def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
+    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
         """
         Usually graph looks like:
 
           main_graph
-            ...             OpOutput
+            ...             Result
              |                 |
         image_batch      label_batch
                 \        /
@@ -61,7 +64,8 @@ class FIFOQueue(FrontReplacementSubgraph):
             there is no label_batch node
         """
         true_placeholder_shape = match['placeholder'].shape
-        placeholder_shape = match['fifo_queue'].shape
+        placeholder_shape = match['fifo_queue'].shapes[0]
+        placeholder_data_type = match['fifo_queue'].types[0]
         assert true_placeholder_shape.ndim <= 1
         if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
             log.warning(
@@ -70,15 +74,54 @@ class FIFOQueue(FrontReplacementSubgraph):
                 ''.format(match['placeholder'].id, true_placeholder_shape, placeholder_shape))
             placeholder_shape = true_placeholder_shape
         placeholder_name = match['fifo_queue'].name
-        erase_node(match['fifo_queue'])
-        erase_node(match['placeholder'])
+        graph.erase_node(match['fifo_queue'])
+        graph.erase_node(match['placeholder'])
         for _, out in match['batch_join'].out_nodes().items():
             if out.id != match['image_batch'].id:
-                if out.out_node().op == 'OpOutput':
-                    erase_node(out.out_node())
-                erase_node(out)
-        erase_node(match['batch_join'])
-        placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
-        create_edge(placeholder, match['image_batch'])
+                if out.out_node().op == 'Result':
+                    graph.remove_node(out.out_node().id)
+                graph.remove_node(out.id)
+        graph.remove_node(match['batch_join'].id)
+        placeholder = Parameter(graph, {'name': placeholder_name, 'shape': placeholder_shape,
+                                        'data_type': placeholder_data_type}).create_node()
+        graph.create_edge(placeholder, match['image_batch'])
         log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
                  "needed".format(placeholder.id, placeholder['shape']))
+
+
+class QueueDequeueManyV2(FrontReplacementSubgraph):
+    """
+    Replaces the combination of the FIFOQueueV2 + QueueDequeueManyV2 operations with a number of Placeholders.
+    """
+    enabled = True
+
+    def run_before(self):
+        from extensions.front.override_batch import OverrideBatch
+        return [OverrideBatch]
+
+    @staticmethod
+    def pattern(**kwargs):
+        return dict(
+            nodes=[
+                ('fifo_queue', dict(op='FIFOQueueV2')),
+                ('queue_deque', dict(op='QueueDequeueManyV2')),
+            ],
+            edges=[
+                ('fifo_queue', 'queue_deque', {'out': 0}),
+            ]
+        )
+
+    @staticmethod
+    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
+        inputs_dict = {}
+        for u, v, edge_attrs in graph.out_edges(match['queue_deque'].id, data=True):
+            out_port = edge_attrs['out']
+            shape = match['fifo_queue'].shapes[out_port]
+            if out_port not in inputs_dict:
+                input_op = Parameter(graph, {'shape': shape.copy()})
+                inputs_dict[out_port] = input_op.create_node([])
+            graph.create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)
+
+        graph.remove_node(match['queue_deque'].id)
+        graph.remove_node(match['fifo_queue'].id)
+