"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
-import numpy as np
from copy import deepcopy
+
+import numpy as np
+
from mo.front.common.layout import get_features_dim, shape_for_layout
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.middle.passes.fusing.helpers import get_value_id
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
def run_after(self):
return [EltwiseInputReshape]
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def find_and_replace_pattern(self, graph: Graph):
layout = graph.graph['layout']
for n in list(graph.nodes()):
if 'type' in graph.node[n] and graph.node[n]['type'] == 'Eltwise' and get_value_id(Node(graph, n)) is None:
class EltwiseInputReshape(MiddleReplacementPattern):
enabled = True
- def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+ def run_after(self):
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
+
+ def find_and_replace_pattern(self, graph: Graph):
data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
for node in data_nodes:
# Get all requested shapes for current node
# Reconnect edge from original data node to Reshape output datanode
graph.remove_edge(node.id, consumer.id)
- graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)
\ No newline at end of file
+ graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)