Publishing 2019 R3.1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / fifo_replacer.py
index 9063cf5..8374a58 100644 (file)
@@ -17,9 +17,9 @@ import logging as log
 
 import numpy as np
 
+from extensions.ops.parameter import Parameter
 from mo.front.common.replacement import FrontReplacementSubgraph
 from mo.graph.graph import Graph, Node
-from mo.ops.input import Input
 
 
 class FIFOQueue(FrontReplacementSubgraph):
@@ -33,7 +33,7 @@ class FIFOQueue(FrontReplacementSubgraph):
     def pattern(**kwargs):
         return dict(
             nodes=[
-                ('placeholder', dict(op='Placeholder', data_type=np.int32)),
+                ('placeholder', dict(op='Parameter', data_type=np.int32)),
                 ('fifo_queue', dict(op='FIFOQueueV2')),
                 ('batch_join', dict(op='QueueDequeueUpToV2')),
                 ('image_batch', dict(op='Identity', data_type=np.float32))
@@ -51,7 +51,7 @@ class FIFOQueue(FrontReplacementSubgraph):
         Usually graph looks like:
 
           main_graph
-            ...             OpOutput
+            ...             Result
              |                 |
         image_batch      label_batch
                 \        /
@@ -65,6 +65,7 @@ class FIFOQueue(FrontReplacementSubgraph):
         """
         true_placeholder_shape = match['placeholder'].shape
         placeholder_shape = match['fifo_queue'].shapes[0]
+        placeholder_data_type = match['fifo_queue'].types[0]
         assert true_placeholder_shape.ndim <= 1
         if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
             log.warning(
@@ -77,11 +78,12 @@ class FIFOQueue(FrontReplacementSubgraph):
         graph.erase_node(match['placeholder'])
         for _, out in match['batch_join'].out_nodes().items():
             if out.id != match['image_batch'].id:
-                if out.out_node().op == 'OpOutput':
+                if out.out_node().op == 'Result':
                     graph.remove_node(out.out_node().id)
                 graph.remove_node(out.id)
         graph.remove_node(match['batch_join'].id)
-        placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
+        placeholder = Parameter(graph, {'name': placeholder_name, 'shape': placeholder_shape,
+                                        'data_type': placeholder_data_type}).create_node()
         graph.create_edge(placeholder, match['image_batch'])
         log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
                  "needed".format(placeholder.id, placeholder['shape']))
@@ -116,7 +118,7 @@ class QueueDequeueManyV2(FrontReplacementSubgraph):
             out_port = edge_attrs['out']
             shape = match['fifo_queue'].shapes[out_port]
             if out_port not in inputs_dict:
-                input_op = Input(graph, {'shape': shape.copy()})
+                input_op = Parameter(graph, {'shape': shape.copy()})
                 inputs_dict[out_port] = input_op.create_node([])
             graph.create_edge(inputs_dict[out_port], Node(graph, v), edge_attrs['out'], edge_attrs['in'], edge_attrs)