2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from collections import deque
20 from mo.graph.graph import Node
# get_value_id: scan the inputs of operation node `node` for inputs that carry
# a valid 'value' attribute.
# NOTE(review): this listing is incomplete -- the embedded numbering skips
# 25 and 28-33, so the handling of a matching port and the return statement
# are not shown here. The name suggests it reports the port of the valued
# input; confirm against the full source before relying on that.
23 def get_value_id(node: Node):
24 assert node.has_valid('op')  # only nodes with a valid 'op' attribute are accepted
26 for port, in_node in node.in_nodes().items():  # in_nodes() is iterated as a port -> input-node mapping here
27 if in_node.has_valid('value'):  # input carrying a valid 'value' attribute
# get_tensor_id: scan the inputs of operation node `node` for inputs that do
# NOT carry a valid 'value' attribute (the complement of get_value_id's test).
# NOTE(review): this listing is incomplete -- the embedded numbering skips
# 36 and 39-44, so the handling of a matching port and the return statement
# are not shown here. The name suggests it reports the port of the non-valued
# (tensor) input; confirm against the full source.
34 def get_tensor_id(node: Node):
35 assert node.has_valid('op')  # only nodes with a valid 'op' attribute are accepted
37 for port, in_node in node.in_nodes().items():  # in_nodes() is iterated as a port -> input-node mapping here
38 if not in_node.has_valid('value'):  # input without a valid 'value' attribute
# common_bfs: breadth-first search starting at start_node for nodes whose
# 'type' is listed in op_name. It walks each visited node's inputs
# (in_node) when is_backward is True, or its outputs (out_node) otherwise.
# The visible code returns the collected node ids wrapped as Node objects.
# NOTE(review): this listing drops several lines (embedded numbers 46, 49-50,
# 56-57, 59-62, 64-65, 72, 74-76, 80). The missing lines include the
# docstring delimiters, the initialisation of `ret`, the loop that binds
# `node`, and the bodies of several branches. The comments below describe
# only what is visible.
45 def common_bfs(start_node: Node, allowed_ops: list, op_name: list, is_backward: bool = True, allowed_all: bool = False):
47 The purpose of this algorithm is to find layers with 'op_name' located in given direction.
48 In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
51 :param start_node: Start node for BFS algorithm
52 :param allowed_ops: List of operations that we can jump over
53 :param op_name: The list with names of operations for searching
54 :param is_backward: The direction of BFS algorithm
55 :param allowed_all: Bool flag meaning we can jump over all operations
58 q = deque([start_node])  # queue of nodes to visit, seeded with the start node
63 log.debug("[BFS:ERROR] Graph contains cycle! BFS starts from {} node".format(start_node.id))  # cycle report; the detecting condition is on lines not shown (log presumably the logging module -- its import is not visible)
66 in_nodes_size = len(node.in_nodes()) if is_backward else len(node.out_nodes())  # neighbour count in the search direction: inputs backward, outputs forward
67 for id in range(in_nodes_size): # in_nodes() can return either list or dict; NOTE(review): `id` shadows the builtin
68 pnode = node.in_node(id) if is_backward else node.out_node(id)  # neighbour on port `id` in the search direction
69 if pnode.has_valid('type'):  # neighbour with a valid 'type' attribute
70 if pnode.type in op_name:  # found one of the searched-for operation types
71 if pnode.id not in ret:  # de-duplicate results; the statement that records pnode.id is on a line not shown
73 elif allowed_all or pnode.op in allowed_ops:  # operation the search may pass through; branch body not shown
77 elif pnode.kind == 'data' and pnode.value is None:  # data node without a computed value
78 # If we go backward we don't use data node that have more than one consumer
79 if not is_backward or (is_backward and len(pnode.out_nodes()) == 1):  # NOTE(review): the inner `is_backward and` is redundant -- equivalent to `not is_backward or len(...) == 1`
81 return [Node(start_node.graph, x) for x in ret]  # turn the collected ids back into Node objects of the same graph
def forward_bfs(start_node: Node, allowed_ops: list, op_name: list, allowed_all: bool = False):
    """
    Search forward from start_node by delegating to common_bfs with
    is_backward=False, i.e. following the outputs (out_node) of each node.

    :param start_node: Node the search starts from
    :param allowed_ops: List of operations the search may pass through
    :param op_name: The list with names of operations to search for
    :param allowed_all: If True, the search may pass through any operation
    :return: Whatever common_bfs returns for this forward search
    """
    return common_bfs(
        start_node=start_node,
        allowed_ops=allowed_ops,
        op_name=op_name,
        is_backward=False,
        allowed_all=allowed_all,
    )
def backward_bfs(start_node: Node, allowed_ops: list, op_name: list, allowed_all: bool = False):
    """
    Search backward from start_node by delegating to common_bfs with
    is_backward=True, i.e. following the inputs (in_node) of each node.

    :param start_node: Node the search starts from
    :param allowed_ops: List of operations the search may pass through
    :param op_name: The list with names of operations to search for
    :param allowed_all: If True, the search may pass through any operation
    :return: Whatever common_bfs returns for this backward search
    """
    # is_backward=True is also common_bfs's default; spelled out for clarity.
    return common_bfs(
        start_node=start_node,
        allowed_ops=allowed_ops,
        op_name=op_name,
        is_backward=True,
        allowed_all=allowed_all,
    )
92 def get_next_operation(node: Node):
94 This function returns next op node, so node should be an operation
96 assert node.kind == 'op'
98 out_nodes = node.out_nodes()
100 for port, out_node in out_nodes.items():
101 op_nodes = out_node.out_nodes()
102 for op_node in op_nodes:
103 if op_node.id not in [n.id for n in res]: