"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import logging as log
-import networkx as nx
import numpy as np
from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import create_edge, erase_node, Node
+from mo.graph.graph import Graph, Node
from mo.ops.input import Input
class FIFOQueue(FrontReplacementSubgraph):
enabled = True
+ def run_before(self):
+ from extensions.front.override_batch import OverrideBatch
+ return [OverrideBatch]
+
@staticmethod
def pattern(**kwargs):
return dict(
)
@staticmethod
- def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
+ def replace_sub_graph(graph: Graph, match: dict, **kwargs):
"""
Usually graph looks like:
''.format(match['placeholder'].id, true_placeholder_shape, placeholder_shape))
placeholder_shape = true_placeholder_shape
placeholder_name = match['fifo_queue'].name
- erase_node(match['fifo_queue'])
- erase_node(match['placeholder'])
+ graph.erase_node(match['fifo_queue'])
+ graph.erase_node(match['placeholder'])
for _, out in match['batch_join'].out_nodes().items():
if out.id != match['image_batch'].id:
if out.out_node().op == 'OpOutput':
- erase_node(out.out_node())
- erase_node(out)
- erase_node(match['batch_join'])
+ graph.remove_node(out.out_node().id)
+ graph.remove_node(out.id)
+ graph.remove_node(match['batch_join'].id)
placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
- create_edge(placeholder, match['image_batch'])
+ graph.create_edge(placeholder, match['image_batch'])
log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
"needed".format(placeholder.id, placeholder['shape']))
"""
enabled = True
+ def run_before(self):
+ from extensions.front.override_batch import OverrideBatch
+ return [OverrideBatch]
+
@staticmethod
def pattern(**kwargs):
return dict(
)
@staticmethod
- def replace_sub_graph(graph: nx.MultiDiGraph, match: dict, **kwargs):
+ def replace_sub_graph(graph: Graph, match: dict, **kwargs):
inputs_dict = {}
for u, v, edge_attrs in graph.out_edges(match['queue_deque'].id, data=True):
out_port = edge_attrs['out']
if out_port not in inputs_dict:
input_op = Input(graph, {'shape': shape.copy()})
inputs_dict[out_port] = input_op.create_node([])
- create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)
+ graph.create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)
graph.remove_node(match['queue_deque'].id)
graph.remove_node(match['fifo_queue'].id)