"""
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import logging as log
-import networkx as nx
-
from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
from mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
from mo.front.subgraph_matcher import SubgraphMatcher, SubgraphMatch
from mo.front.tf.custom_subgraph_call import merge_nodes
-from mo.graph.graph import dump_graph_for_graphviz, unique_id
+from mo.graph.graph import Graph
from mo.ops.op import Op
from mo.utils import class_registration
from mo.utils.graph import is_connected_component
def transform_graph(self, graph, replacement_descriptions):
raise Exception('Function "transform_graph" must be overridden in the sub-class')
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def find_and_replace_pattern(self, graph: Graph):
replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
if replacement_descriptions is None or len(replacement_descriptions) < 1:
log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
def __init__(self):
super().__init__()
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
return match.matched_nodes_names()
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def find_and_replace_pattern(self, graph: Graph):
replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
if replacement_descriptions is None:
log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
if not is_connected_component(graph, match.matched_nodes_names()):
log.warning("The following nodes don't form connected sub-graph: {}".format(
match.matched_nodes_names()))
- dump_graph_for_graphviz(graph, match.matched_nodes_names())
+ graph.dump_graph_for_graphviz(match.matched_nodes_names())
self.replace_sub_graph(graph, match)
registered_ops = {}
super().__init__()
def input_edges_match(self, # pylint: disable=method-hidden
- graph: nx.DiGraph,
+ graph: Graph,
match: SubgraphMatch,
new_sub_graph: dict):
"""
return input_edges_match
def output_edges_match(self, # pylint: disable=method-hidden
- graph: nx.DiGraph,
+ graph: Graph,
match: SubgraphMatch,
new_sub_graph: dict):
"""
output_edges_match[(output_node.id, output_port)] = (new_sub_graph['new_node'].id, sub_graph_output_port)
return output_edges_match
- def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
replacement_desc = match.custom_replacement_desc
op = Op.get_op_class_by_name(replacement_desc.op)(graph, match.custom_replacement_desc.custom_attributes)
op.default_backend_attrs = list(match.custom_replacement_desc.custom_attributes.keys())
op.substitute_ie_attrs(op.attrs)
node = merge_nodes(graph, match.matched_nodes_names(), replacement_desc.get_inputs_description(),
replacement_desc.get_outputs_description())
- node.name = unique_id(graph, op.attrs['type'])
+ node.name = graph.unique_id(op.attrs['type'])
node_attrs = graph.node[node.id]
# copy attributes which are defined in the custom operation
for key in op.attrs.keys():