Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / fifo_replacer.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 import numpy as np
19
20 from mo.front.common.replacement import FrontReplacementSubgraph
21 from mo.graph.graph import Graph, Node
22 from mo.ops.input import Input
23
24
class FIFOQueue(FrontReplacementSubgraph):
    """
    Folds the TensorFlow input-queue pattern
        Placeholder(int32) + FIFOQueueV2 -> QueueDequeueUpToV2 -> Identity(float32)
    into a single Input (Placeholder) node connected directly to the Identity node.
    """
    enabled = True

    def run_before(self):
        # Scheduled before OverrideBatch; presumably so that the batch override (the -b option mentioned in the
        # log message of replace_sub_graph) applies to the Input node created here -- TODO confirm.
        from extensions.front.override_batch import OverrideBatch
        return [OverrideBatch]

    @staticmethod
    def pattern(**kwargs):
        """
        Match an int32 Placeholder and a FIFOQueueV2 that both feed (from output port 0) a QueueDequeueUpToV2,
        whose output port 0 feeds a float32 Identity node.
        """
        return dict(
            nodes=[
                ('placeholder', dict(op='Placeholder', data_type=np.int32)),
                ('fifo_queue', dict(op='FIFOQueueV2')),
                ('batch_join', dict(op='QueueDequeueUpToV2')),
                ('image_batch', dict(op='Identity', data_type=np.float32)),
            ],
            edges=[
                ('placeholder', 'batch_join', {'out': 0}),
                ('fifo_queue', 'batch_join', {'out': 0}),
                ('batch_join', 'image_batch', {'out': 0}),
            ],
        )

    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
        """
        Usually graph looks like:

          main_graph
            ...             OpOutput
             |                 |
        image_batch      label_batch
                \        /
                batch_join
                /        \
        placeholder      fifo_queue

        The replacer works in both of the following cases (hence the loop over the batch_join consumers):
            label_batch was marked as an output;
            there is no label_batch node.

        The placeholder, fifo_queue and batch_join nodes are removed. So is every batch_join consumer other than
        image_batch, together with the OpOutput node it feeds, if any. They are replaced by a single Input node,
        named after the FIFOQueueV2 node, which feeds image_batch.

        The Input shape is the first entry of the FIFOQueueV2 'shapes' attribute. There is one exception: if the
        matched placeholder has a known shape with more than one element, that shape is used instead and a
        warning is logged.
        """
        true_placeholder_shape = match['placeholder'].shape
        placeholder_shape = match['fifo_queue'].shapes[0]
        # The matched placeholder's shape may be unknown (None); in that case the shape recorded in the
        # FIFOQueueV2 'shapes' attribute is used as is.
        if true_placeholder_shape is not None:
            assert true_placeholder_shape.ndim <= 1
            if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
                log.warning(
                    'Placeholder \'{}\' got non 0-dimensional shape {} in FIFOQueue pattern. Placeholder will have the '
                    'same shape after folding the pattern instead of {} shape which is original for the network.'
                    ''.format(match['placeholder'].id, true_placeholder_shape, placeholder_shape))
                placeholder_shape = true_placeholder_shape

        placeholder_name = match['fifo_queue'].name
        graph.erase_node(match['fifo_queue'])
        graph.erase_node(match['placeholder'])

        image_batch_id = match['image_batch'].id
        # out_nodes() builds a fresh dict, so removing nodes inside the loop is safe.
        for out in match['batch_join'].out_nodes().values():
            if out.id == image_batch_id:
                continue
            # Drop the side branch (e.g. label_batch) together with the OpOutput that marks it as a network output.
            if out.out_node().op == 'OpOutput':
                graph.remove_node(out.out_node().id)
            graph.remove_node(out.id)
        graph.remove_node(match['batch_join'].id)

        # Copy the shape so the new Input owns its array instead of aliasing an attribute of a removed node
        # (same approach as QueueDequeueManyV2 below).
        placeholder = Input(graph, {'name': placeholder_name, 'shape': placeholder_shape.copy()}).create_node()
        graph.create_edge(placeholder, match['image_batch'])
        log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
                 "needed".format(placeholder.id, placeholder['shape']))
86         log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
87                  "needed".format(placeholder.id, placeholder['shape']))
88
89
class QueueDequeueManyV2(FrontReplacementSubgraph):
    """
    Replaces the combination of the FIFOQueueV2 + QueueDequeueManyV2 operations with a number of Placeholders.

    One Input (Placeholder) node is created for each QueueDequeueManyV2 output port that has consumers. Its shape
    is a copy of the entry with the same index in the FIFOQueueV2 'shapes' attribute. All consumers of the same
    output port share that one Input node. Both queue nodes are then removed.
    """
    enabled = True

    def run_before(self):
        from extensions.front.override_batch import OverrideBatch
        return [OverrideBatch]

    @staticmethod
    def pattern(**kwargs):
        """Match a FIFOQueueV2 whose output port 0 feeds a QueueDequeueManyV2."""
        return dict(
            nodes=[
                ('fifo_queue', dict(op='FIFOQueueV2')),
                ('queue_deque', dict(op='QueueDequeueManyV2')),
            ],
            edges=[
                ('fifo_queue', 'queue_deque', {'out': 0}),
            ],
        )

    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict, **kwargs):
        fifo_queue = match['fifo_queue']
        queue_deque = match['queue_deque']
        # Maps a QueueDequeueManyV2 output port to the Input node that replaces it.
        inputs_dict = {}
        # Snapshot the edges: new nodes and edges are added to the graph inside the loop.
        for _, dst_id, edge_attrs in list(graph.out_edges(queue_deque.id, data=True)):
            out_port = edge_attrs['out']
            if out_port not in inputs_dict:
                shape = fifo_queue.shapes[out_port]
                inputs_dict[out_port] = Input(graph, {'shape': shape.copy()}).create_node([])
            # Each replacement Placeholder produces a single tensor, so the edge must leave its port 0 (as in the
            # FIFOQueue replacer). Reusing the QueueDequeueManyV2 port index would wire queue components 1, 2, ...
            # from output ports the Placeholder does not have.
            graph.create_edge(inputs_dict[out_port], Node(graph, dst_id), 0, edge_attrs['in'], edge_attrs)

        graph.remove_node(queue_deque.id)
        graph.remove_node(fifo_queue.id)
125