"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import insert_node_after
+from mo.graph.graph import Graph
from mo.ops.permute import Permute
class ConvFlattenReplacement(FrontReplacementFromConfigFileSubGraph):
replacement_id = 'ConvFlatten'
- def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+ def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
return {}
- def input_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+ def input_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
return {}
- def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
# no need to remove any of matched nodes. We just insert 'Permute' node before the matched sub-graph.
return []
- def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+ def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])})
permute_node = permute_op.add_node({'name': match.scope + '_permute_'})
# reshape_in_node is the node after which we should insert Permute
reshape_in_node = reshape_node.in_nodes()[0]
- insert_node_after(reshape_in_node, permute_node, 0)
+ reshape_in_node.insert_node_after(permute_node, 0)
return {}