import numpy as np
+from extensions.ops.parameter import Parameter
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph, Node
-from mo.ops.input import Input
class FIFOQueue(FrontReplacementSubgraph):
def pattern(**kwargs):
return dict(
nodes=[
- ('placeholder', dict(op='Placeholder', data_type=np.int32)),
+ ('placeholder', dict(op='Parameter', data_type=np.int32)),
('fifo_queue', dict(op='FIFOQueueV2')),
('batch_join', dict(op='QueueDequeueUpToV2')),
('image_batch', dict(op='Identity', data_type=np.float32))
Usually graph looks like:
main_graph
- ... OpOutput
+ ... Result
| |
image_batch label_batch
\ /
"""
true_placeholder_shape = match['placeholder'].shape
placeholder_shape = match['fifo_queue'].shapes[0]
+ placeholder_data_type = match['fifo_queue'].types[0]
assert true_placeholder_shape.ndim <= 1
if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
log.warning(
graph.erase_node(match['placeholder'])
for _, out in match['batch_join'].out_nodes().items():
if out.id != match['image_batch'].id:
- if out.out_node().op == 'OpOutput':
+ if out.out_node().op == 'Result':
graph.remove_node(out.out_node().id)
graph.remove_node(out.id)
graph.remove_node(match['batch_join'].id)
- placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
+ placeholder = Parameter(graph, {'name': placeholder_name, 'shape': placeholder_shape,
+ 'data_type': placeholder_data_type}).create_node()
graph.create_edge(placeholder, match['image_batch'])
log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
"needed".format(placeholder.id, placeholder['shape']))
out_port = edge_attrs['out']
shape = match['fifo_queue'].shapes[out_port]
if out_port not in inputs_dict:
- input_op = Input(graph, {'shape': shape.copy()})
+ input_op = Parameter(graph, {'shape': shape.copy()})
inputs_dict[out_port] = input_op.create_node([])
graph.create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)