Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / fifo_replacer.py
index 576dcf1..9063cf5 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import logging as log
 
-import networkx as nx
 import numpy as np
 
 from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import create_edge, erase_node, Node
+from mo.graph.graph import Graph, Node
 from mo.ops.input import Input
 
 
 class FIFOQueue(FrontReplacementSubgraph):
     enabled = True
 
+    def run_before(self):
+        from extensions.front.override_batch import OverrideBatch
+        return [OverrideBatch]
+
     @staticmethod
     def pattern(**kwargs):
         return dict(
@@ -43,7 +46,7 @@ class FIFOQueue(FrontReplacementSubgraph):
         )
 
     @staticmethod
-    def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
+    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
         """
         Usually graph looks like:
 
@@ -70,16 +73,16 @@ class FIFOQueue(FrontReplacementSubgraph):
                 ''.format(match['placeholder'].id, true_placeholder_shape, placeholder_shape))
             placeholder_shape = true_placeholder_shape
         placeholder_name = match['fifo_queue'].name
-        erase_node(match['fifo_queue'])
-        erase_node(match['placeholder'])
+        graph.erase_node(match['fifo_queue'])
+        graph.erase_node(match['placeholder'])
         for _, out in match['batch_join'].out_nodes().items():
             if out.id != match['image_batch'].id:
                 if out.out_node().op == 'OpOutput':
-                    erase_node(out.out_node())
-                erase_node(out)
-        erase_node(match['batch_join'])
+                    graph.remove_node(out.out_node().id)
+                graph.remove_node(out.id)
+        graph.remove_node(match['batch_join'].id)
         placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
-        create_edge(placeholder, match['image_batch'])
+        graph.create_edge(placeholder, match['image_batch'])
         log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
                  "needed".format(placeholder.id, placeholder['shape']))
 
@@ -90,6 +93,10 @@ class QueueDequeueManyV2(FrontReplacementSubgraph):
     """
     enabled = True
 
+    def run_before(self):
+        from extensions.front.override_batch import OverrideBatch
+        return [OverrideBatch]
+
     @staticmethod
     def pattern(**kwargs):
         return dict(
@@ -103,7 +110,7 @@ class QueueDequeueManyV2(FrontReplacementSubgraph):
         )
 
     @staticmethod
-    def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
+    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
         inputs_dict = {}
         for u, v, edge_attrs in graph.out_edges(match['queue_deque'].id, data=True):
             out_port = edge_attrs['out']
@@ -111,7 +118,7 @@ class QueueDequeueManyV2(FrontReplacementSubgraph):
             if out_port not in inputs_dict:
                 input_op = Input(graph, {'shape': shape.copy()})
                 inputs_dict[out_port] = input_op.create_node([])
-            create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)
+            graph.create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)
 
         graph.remove_node(match['queue_deque'].id)
         graph.remove_node(match['fifo_queue'].id)