2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from networkx.algorithms import isomorphism as ism
22 from mo.graph.graph import Node, dict_includes, Graph
def inverse_dict(d: dict):
    """Return a new dict that maps each value of `d` back to its key.

    Values of `d` must be hashable. When several keys share one value, the key
    encountered last while iterating `d` wins.
    """
    return dict(zip(d.values(), d.keys()))
def for_each_sub_graph(graph: Graph, func: callable):
    """ Call `func` on every sub-graph attached to a node of `graph`, one level deep only.

        A node owns sub-graphs when it has a valid `sub_graphs` attribute; each entry of that
        attribute is the name of the node attribute holding the sub-graph. Sub-graphs nested
        inside the found sub-graphs are NOT visited here: if full recursion is required, `func`
        has to implement it itself (as `for_each_sub_graph_recursively` does).
    """
    for node_id in graph.nodes():
        owner = Node(graph, node_id)
        if not owner.has_valid('sub_graphs'):
            continue
        for attr_name in owner.sub_graphs:
            func(owner[attr_name])
def for_each_sub_graph_recursively(graph: Graph, func: callable):
    """ Run a given function `func` for each sub-graph in a given graph `graph` recursively.

        Traversal is depth-first and pre-order: `func` is applied to a sub-graph before the
        traversal descends into that sub-graph's own sub-graphs. `graph` itself is NOT passed
        to `func` (use `for_graph_and_each_sub_graph_recursively` for that).

        A given function `func` shouldn't contain a recursion for sub-graphs of the second
        level, because this function already performs the recursion.
    """
    def recursive_helper(sub_graph):
        # User action on the current sub-graph.
        func(sub_graph)
        # Recurse into the sub-graphs nested inside the current one.
        for_each_sub_graph(sub_graph, recursive_helper)

    for_each_sub_graph(graph, recursive_helper)
def for_graph_and_each_sub_graph_recursively(graph: Graph, func: callable):
    """ Run a given function `func` for a given graph `graph` and each sub-graph recursively.

        `func` is applied to `graph` itself first, then to every (nested) sub-graph.
    """
    func(graph)
    for_each_sub_graph_recursively(graph, func)
def all_edges_in_nodes(nodes: list, edges: list):
    """Tell whether both endpoints of every edge in `edges` are present in `nodes`.

    Only the first two items of each edge (source and destination) are inspected, so any
    trailing attribute dict is ignored. Every edge is checked before answering. An empty
    `edges` gives True.
    """
    results = []
    for edge in edges:
        head_known = edge[0] in nodes
        results.append(head_known and edge[1] in nodes)
    return all(results)
def apply_pattern(graph: Graph, nodes: list, edges: list, action: callable, node_attrs: list = None,
                  edge_attrs: list = None):
    """
    Search for all matches of a given subgraph defined by [nodes, edges] in graph,
    then apply action for each such match.

    All matches are collected before any action runs. `action` may modify `graph`, while the
    match search is a lazy iterator walking that same graph. A match whose nodes no longer
    exist when it is processed (e.g. removed by an earlier action) is skipped with a warning.
    After `graph` has been processed, the same pattern is applied to every sub-graph, which
    recurses through this function.

    :param graph: graph to search in; it is also passed to `action`
    :param nodes: pattern nodes; the first element of each item is the pattern node name
    :param edges: pattern edges between pattern node names
    :param action: called as action(graph, match), where match maps pattern node names to Node objects
    :param node_attrs: deprecated, only forwarded to the matcher builder
    :param edge_attrs: deprecated, only forwarded to the matcher builder
    """
    if not all_edges_in_nodes([node[0] for node in nodes], edges):
        log.warning("Incorrect pattern attributes: not all nodes from edges are in nodes. "
                    "Please, mention all nodes you need in pattern in nodes attribute. ")

    # Materialize the matches first: actions may modify the graph, which would break the
    # lazy isomorphism iterator if it were still walking the graph.
    matches = list(find_pattern_matches(graph, nodes, edges, node_attrs, edge_attrs))

    for match in matches:
        # The matcher maps graph node ids to pattern names; flip it so that the action can
        # look nodes up by pattern name.
        match = inverse_dict(match)
        still_valid = True
        for k in match:
            if not graph.has_node(match[k]):
                # Graph changed significantly
                still_valid = False
                log.warning("The graph has changed significantly during applying pattern:\n"
                            "nodes: {}\n"
                            "edges: {}\n"
                            "node_attrs: {}\n"
                            "edge_attrs: {}".format(nodes, edges, node_attrs, edge_attrs))
                break
            match[k] = Node(graph, match[k])
        if still_valid:
            action(graph, match)

    # Find all sub-graphs and apply_pattern recursively
    for_each_sub_graph(graph, lambda sub_graph: apply_pattern(sub_graph, nodes, edges, action, node_attrs,
                                                              edge_attrs))
def check_node_usages_out_of_match(match: dict, node_name_in_match_group: str):
    """
    Check that a matched node is connected only to nodes belonging to the same match.

    Despite the function name, the result is True when the node is NOT used outside of the
    match. That is the case when every producer (source of an incoming edge) and every
    consumer (destination of an outgoing edge) of the node is one of the matched nodes.
    The result is False as soon as one neighbour lies outside the match.

    :param match: dictionary with pattern match, mapping pattern names to Node objects
    :param node_name_in_match_group: pattern name of the node to check; must be a key of `match`
    :return: True if all neighbours of the node are part of the match, False otherwise
    """
    assert node_name_in_match_group in match
    node = match[node_name_in_match_group]
    graph = node.graph
    # Node ids are graph node keys and therefore hashable; a set makes each lookup O(1).
    match_node_ids = {match[name].id for name in match}
    neighbour_ids = [u for u, _ in graph.in_edges(node.id)]
    neighbour_ids.extend(v for _, v in graph.out_edges(node.id))
    return all(n in match_node_ids for n in neighbour_ids)
def node_match(data1: dict, data2: dict):
    """Node comparison callback used by the pattern isomorphism matcher.

    Delegates to `dict_includes(data1, data2, ...)`. The `_in_ports` and `_out_ports`
    attributes are left out of the comparison because they are not comparable.
    """
    ignored_attrs = ['_in_ports', '_out_ports']
    return dict_includes(data1, data2, skip_attr_names=ignored_attrs)
def edge_match(datasets1, datasets2):
    """Edge comparison callback for the multigraph isomorphism matcher.

    Each argument maps an edge key to that edge's attribute dict, covering all parallel edges
    between one pair of nodes. `datasets2` is the pattern side, and the attribute names to
    compare are taken from its edge with key 0. For every edge on each side, the values of
    those attributes are gathered into a tuple, with None for a missing attribute. The two
    sides match when they produce the same set of tuples, so edge order and duplicates are
    ignored.
    """
    attrs = list(datasets2[0].keys())

    def _signatures(datasets):
        # One tuple of the compared attribute values per parallel edge.
        return {tuple(data.get(attr, None) for attr in attrs) for data in datasets.values()}

    return _signatures(datasets1) == _signatures(datasets2)
def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
                  edge_attrs: list = None):
    """Create a multigraph isomorphism matcher between `graph` and the pattern [nodes, edges].

    The pattern is assembled into a fresh Graph named 'pattern'. Nodes are compared with
    `node_match` and edges with `edge_match`. `node_attrs` and `edge_attrs` are obsolete:
    passing either of them only emits a warning and does not affect matching.
    """
    deprecated_args_given = not (node_attrs is None and edge_attrs is None)
    if deprecated_args_given:
        log.warning(
            "'edge_attrs' or `'node_attrs'` parameter was passed to function 'find_pattern_matches', "
            "but they are not used anymore. Pattern matching proceeds according to 'nodes' and 'edges' "
            "parameters. Please avoid passing 'edge_attrs' and 'node_attrs' parameters to any pattern "
            "matching function like 'find_pattern_matches', 'apply_pattern' and 'pattern' because it "
            "will be deprecated in the next release."
        )

    pattern = Graph(name='pattern')
    pattern.add_nodes_from(nodes)
    pattern.add_edges_from(edges)
    return ism.MultiDiGraphMatcher(graph, pattern, node_match, edge_match)
def find_pattern_matches(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
                         edge_attrs: list = None):
    """
    Find all matches of a given sub-graph defined by [nodes, edges] in graph.

    The matcher is built immediately, so any deprecation warning appears at call time.
    Returns the matcher's lazy iterator of subgraph isomorphisms; each yielded dict maps
    node ids of `graph` to pattern node names. `node_attrs` and `edge_attrs` are only
    forwarded to `build_matcher`.
    """
    return build_matcher(graph, nodes, edges, node_attrs, edge_attrs).subgraph_isomorphisms_iter()
158 def find_isomorphisms(graph: Graph, nodes: list, edges: list):
159 ''' Find for isomorphism between a given graph and a pattern specified by a given nodes and edges.
160 Applies the same rules as apply_pattern.
162 matcher = build_matcher(graph, nodes, edges)
164 for match in matcher.isomorphisms_iter():
165 match = inverse_dict(match)
166 match = {k: Node(graph, match[k]) for k in match.keys()}