Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / helpers.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from collections import deque
19
20 from mo.graph.graph import Node
21
22
23 def get_value_id(node: Node):
24     assert node.has_valid('op')
25     value_id = None
26     for port, in_node in node.in_nodes().items():
27         if in_node.has_valid('value'):
28             if value_id:
29                 return None
30             value_id = port
31     return value_id
32
33
34 def get_tensor_id(node: Node):
35     assert node.has_valid('op')
36     tensor_id = None
37     for port, in_node in node.in_nodes().items():
38         if not in_node.has_valid('value'):
39             if tensor_id:
40                 return None
41             tensor_id = port
42     return tensor_id
43
44
45 def common_bfs(start_node: Node, allowed_ops: list, op_name: list, is_backward: bool = True, allowed_all: bool = False):
46     """
47     The purpose of this algorithm is to find layers with 'op_name' located in given direction.
48     In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
49     empty list.
50
51     :param start_node: Start node for BFS algorithm
52     :param allowed_ops: List of operations that we can jump over
53     :param op_name: The list with names of operations for searching
54     :param is_backward: The direction of BFS algorithm
55     :param allowed_all: Bool flag meaning we can jump over all operations
56     """
57     ret = []
58     q = deque([start_node])
59     used = []
60     while len(q) != 0:
61         node = q.popleft()
62         if node.id in used:
63             log.debug("[BFS:ERROR] Graph contains cycle! BFS starts from {} node".format(start_node.id))
64             return []
65         used.append(node.id)
66         in_nodes_size = len(node.in_nodes()) if is_backward else len(node.out_nodes())
67         for id in range(in_nodes_size):  # in_nodes() can return either list or dict
68             pnode = node.in_node(id) if is_backward else node.out_node(id)
69             if pnode.has_valid('type'):
70                 if pnode.type in op_name:
71                     if pnode.id not in ret:
72                         ret.append(pnode.id)
73                 elif allowed_all or pnode.op in allowed_ops:
74                     q.append(pnode)
75                 else:
76                     return []
77             elif pnode.kind == 'data' and pnode.value is None:
78                 # If we go backward we don't use data node that have more than one consumer
79                 if not is_backward or (is_backward and len(pnode.out_nodes()) == 1):
80                     q.append(pnode)
81     return [Node(start_node.graph, x) for x in ret]
82
83
84 def forward_bfs(start_node: Node, allowed_ops: list, op_name: list, allowed_all: bool = False):
85     return common_bfs(start_node, allowed_ops, op_name, False, allowed_all=allowed_all)
86
87
88 def backward_bfs(start_node: Node, allowed_ops: list, op_name: list, allowed_all: bool = False):
89     return common_bfs(start_node, allowed_ops, op_name, allowed_all=allowed_all)
90
91
92 def get_next_operation(node: Node):
93     """
94     This function returns next op node, so node should be an operation
95     """
96     assert node.kind == 'op'
97
98     out_nodes = node.out_nodes()
99     res = []
100     for port, out_node in out_nodes.items():
101         op_nodes = out_node.out_nodes()
102         for op_node in op_nodes:
103             if op_node.id not in [n.id for n in res]:
104                 res.append(op_node)
105     return res