Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / pattern_match.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import networkx as nx
20 from networkx.algorithms import isomorphism as ism
21
22 from mo.graph.graph import Node, dict_includes, Graph
23
24
25 def inverse_dict(d: dict):
26     return {v: k for k, v in d.items()}
27
28
29 def for_each_sub_graph(graph: Graph, func: callable):
30     """ Run a given function `func` for each sub-graph in a given graph not recursively.
31
32         It doesn't search for sub-graphs in found sub-graphs recursively. If the recursion is required,
33         a given function `func` should be implemented in a special way to enable fully recursive traversal.
34     """
35     for node in graph.nodes():
36         node = Node(graph, node)
37         if node.has_valid('sub_graphs'):
38             for sub_graph_name in node.sub_graphs:
39                 func(node[sub_graph_name])
40
41
42 def for_each_sub_graph_recursively(graph: Graph, func: callable):
43     """ Run a given function `func` for each sub-graph in a given graph `graph` recursively.
44
45         A given function `func` shouldn't contain a recursion for sub-graphs of the second level.
46     """
47     def recursive_helper(sub_graph):
48         # user action
49         func(sub_graph)
50         # recursion
51         for_each_sub_graph(sub_graph, recursive_helper)
52
53     for_each_sub_graph(graph, recursive_helper)
54
55
56 def for_graph_and_each_sub_graph_recursively(graph: Graph, func: callable):
57     """ Run a given function `func` for a given graph `graph` and each sub-graph recursively. """
58     func(graph)
59     for_each_sub_graph_recursively(graph, func)
60
61
62 def all_edges_in_nodes(nodes: list, edges: list):
63     return all([edge[0] in nodes and edge[1] in nodes for edge in edges])
64
65
66 def apply_pattern(graph: Graph, nodes: list, edges: list, action: callable, node_attrs: list = None,
67                   edge_attrs: list = None):
68     """
69     Search for all matches of a given subgraph defined by [nodes, edges] in graph,
70     then apply action for each such match.
71     """
72     if not all_edges_in_nodes([node[0] for node in nodes], edges):
73         log.warning("Incorrect pattern attributes: not all nodes from edges are in nodes. "
74                     "Please, mention all nodes you need in pattern in nodes attribute. ")
75
76     matches = []
77     for match in find_pattern_matches(graph, nodes, edges, node_attrs, edge_attrs):
78         matches.append(match)
79
80     for match in matches:
81         match = inverse_dict(match)
82         still_valid = True
83         for k in match:
84             if not graph.has_node(match[k]):
85                 # Graph changed significantly
86                 still_valid = False
87                 log.warning("The graph has changed significantly during applying pattern:\n"
88                             "nodes: {}\n"
89                             "edges: {}\n"
90                             "node_attrs: {}\n"
91                             "edge_attrs: {}".format(nodes, edges, node_attrs, edge_attrs))
92                 break
93             match[k] = Node(graph, match[k])
94         if still_valid:
95             action(graph, match)
96
97     # Find all sub-graphs and apply_pattern recursively
98     for_each_sub_graph(graph, lambda graph: apply_pattern(graph, nodes, edges, action, node_attrs, edge_attrs))
99
100
101 def check_node_usages_out_of_match(match: dict, node_name_in_match_group: str):
102     """
103     Checks if node is consumed by nodes out of match
104     :param match: dictionary with pattern match
105     :param node_name_in_match_group: string
106     :return:
107     """
108     assert node_name_in_match_group in match
109     graph = match[node_name_in_match_group].graph
110     all_node_ids = [match[name].id for name in match]
111     in_out_node_ids = [u for u, _ in graph.in_edges(match[node_name_in_match_group].id)]
112     in_out_node_ids.extend([v for _, v in graph.out_edges(match[node_name_in_match_group].id)])
113     return all([n in all_node_ids for n in in_out_node_ids])
114
115
116 def node_match(data1: dict, data2: dict):
117     # We have to skip _in_ports/_out_ports attributes for comparision as they are not comparable
118     return dict_includes(data1, data2, skip_attr_names=['_in_ports', '_out_ports'])
119
120
121 def edge_match(datasets1, datasets2):
122     attrs = list(datasets2[0].keys())
123     values1 = set([])
124     for data1 in datasets1.values():
125         x = tuple(data1.get(attr, None) for attr in attrs)
126         values1.add(x)
127     values2 = set([])
128     for data2 in datasets2.values():
129         x = tuple(data2.get(attr, None) for attr in attrs)
130         values2.add(x)
131     return values1 == values2
132
133
134 def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
135                          edge_attrs: list = None):
136     if node_attrs is not None or edge_attrs is not None:
137         log.warning('\'edge_attrs\' or `\'node_attrs\'` parameter was passed to function \'find_pattern_matches\', '
138                     'but they are not used anymore. Pattern matching proceeds according to \'nodes\' and \'edges\' '
139                     'parameters. Please avoid passing \'edge_attrs\' and \'node_attrs\' parameters to any pattern '
140                     'matching function like \'find_pattern_matches\', \'apply_pattern\' and \'pattern\' because it '
141                     'will be deprecated in the next release.')
142
143     subgraph = Graph(name='pattern')
144     subgraph.add_nodes_from(nodes)
145     subgraph.add_edges_from(edges)
146     return ism.MultiDiGraphMatcher(graph, subgraph, node_match, edge_match)
147
148
149 def find_pattern_matches(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
150                          edge_attrs: list = None):
151     """
152     Find all matches of a given sub-graph defined by [nodes, edges] in graph.
153     """
154     matcher = build_matcher(graph, nodes, edges, node_attrs, edge_attrs)
155     return matcher.subgraph_isomorphisms_iter()
156
157
158 def find_isomorphisms(graph: Graph, nodes: list, edges: list):
159     ''' Find for isomorphism between a given graph and a pattern specified by a given nodes and edges.
160         Applies the same rules as apply_pattern.
161     '''
162     matcher = build_matcher(graph, nodes, edges)
163     result = []
164     for match in matcher.isomorphisms_iter():
165         match = inverse_dict(match)
166         match = {k: Node(graph, match[k]) for k in match.keys()}
167         result.append(match)
168     return result