Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / pattern_match.py
index f1ea8cf..0e260f4 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,14 +19,14 @@ import logging as log
 import networkx as nx
 from networkx.algorithms import isomorphism as ism
 
-from mo.graph.graph import Node, dict_includes
+from mo.graph.graph import Node, dict_includes, Graph
 
 
 def inverse_dict(d: dict):
     return {v: k for k, v in d.items()}
 
 
-def for_each_sub_graph(graph: nx.MultiDiGraph, func: callable):
+def for_each_sub_graph(graph: Graph, func: callable):
     """ Run a given function `func` for each sub-graph in a given graph not recursively.
 
         It doesn't search for sub-graphs in found sub-graphs recursively. If the recursion is required,
@@ -39,7 +39,7 @@ def for_each_sub_graph(graph: nx.MultiDiGraph, func: callable):
                 func(node[sub_graph_name])
 
 
-def for_each_sub_graph_recursively(graph: nx.MultiDiGraph, func: callable):
+def for_each_sub_graph_recursively(graph: Graph, func: callable):
     """ Run a given function `func` for each sub-graph in a given graph `graph` recursively.
 
         A given function `func` shouldn't contain a recursion for sub-graphs of the second level.
@@ -53,7 +53,7 @@ def for_each_sub_graph_recursively(graph: nx.MultiDiGraph, func: callable):
     for_each_sub_graph(graph, recursive_helper)
 
 
-def for_graph_and_each_sub_graph_recursively(graph: nx.MultiDiGraph, func: callable):
+def for_graph_and_each_sub_graph_recursively(graph: Graph, func: callable):
     """ Run a given function `func` for a given graph `graph` and each sub-graph recursively. """
     func(graph)
     for_each_sub_graph_recursively(graph, func)
@@ -63,7 +63,7 @@ def all_edges_in_nodes(nodes: list, edges: list):
     return all([edge[0] in nodes and edge[1] in nodes for edge in edges])
 
 
-def apply_pattern(graph: nx.MultiDiGraph, nodes: list, edges: list, action: callable, node_attrs: list = None,
+def apply_pattern(graph: Graph, nodes: list, edges: list, action: callable, node_attrs: list = None,
                   edge_attrs: list = None):
     """
     Search for all matches of a given subgraph defined by [nodes, edges] in graph,
@@ -114,7 +114,8 @@ def check_node_usages_out_of_match(match: dict, node_name_in_match_group: str):
 
 
 def node_match(data1: dict, data2: dict):
-    return dict_includes(data1, data2)
+    # We have to skip _in_ports/_out_ports attributes for comparision as they are not comparable
+    return dict_includes(data1, data2, skip_attr_names=['_in_ports', '_out_ports'])
 
 
 def edge_match(datasets1, datasets2):
@@ -130,7 +131,7 @@ def edge_match(datasets1, datasets2):
     return values1 == values2
 
 
-def build_matcher(graph: nx.MultiDiGraph, nodes: list, edges: list, node_attrs: list = None,
+def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
                          edge_attrs: list = None):
     if node_attrs is not None or edge_attrs is not None:
         log.warning('\'edge_attrs\' or `\'node_attrs\'` parameter was passed to function \'find_pattern_matches\', '
@@ -139,13 +140,13 @@ def build_matcher(graph: nx.MultiDiGraph, nodes: list, edges: list, node_attrs:
                     'matching function like \'find_pattern_matches\', \'apply_pattern\' and \'pattern\' because it '
                     'will be deprecated in the next release.')
 
-    subgraph = nx.MultiDiGraph(name='pattern')
+    subgraph = Graph(name='pattern')
     subgraph.add_nodes_from(nodes)
     subgraph.add_edges_from(edges)
     return ism.MultiDiGraphMatcher(graph, subgraph, node_match, edge_match)
 
 
-def find_pattern_matches(graph: nx.MultiDiGraph, nodes: list, edges: list, node_attrs: list = None,
+def find_pattern_matches(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
                          edge_attrs: list = None):
     """
     Find all matches of a given sub-graph defined by [nodes, edges] in graph.
@@ -154,7 +155,7 @@ def find_pattern_matches(graph: nx.MultiDiGraph, nodes: list, edges: list, node_
     return matcher.subgraph_isomorphisms_iter()
 
 
-def find_isomorphisms(graph: nx.MultiDiGraph, nodes: list, edges: list):
+def find_isomorphisms(graph: Graph, nodes: list, edges: list):
     ''' Find for isomorphism between a given graph and a pattern specified by a given nodes and edges.
         Applies the same rules as apply_pattern.
     '''