"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
"""
-
-import logging as log
-from copy import deepcopy
-
-import networkx as nx
-
+from mo.front.common.partial_infer.utils import int64_array
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.permute import Permute
from mo.ops.reshape import Reshape
enabled = True
+ def run_after(self):
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
+
+ def run_before(self):
+ from extensions.middle.pass_separator import MiddleFinish
+ return [MiddleFinish]
+
def pattern(self):
return dict(
nodes=[
('op', 'out_data')
])
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
node = match['op']
N, H, W, C = match['in_data'].shape
graph.remove_edge(match['in_data'].id, match['op'].id)
graph.remove_edge(match['op'].id, match['out_data'].id)
- dim_6D = [N, block_size, block_size, int(C / (block_size ** 2)), H, W]
- order_6D = [0, 3, 4, 1, 5, 2]
- dim_4D = [N, int(H * block_size), int(W * block_size), int(C / (block_size ** 2))]
-
- reshape_data_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_6D', 'dim': dim_6D}).create_node_with_data([match['in_data']])
- permute_data_node = Permute(graph=graph, attrs={'name': match['op'].id + '/Permute', 'order': order_6D}).create_node_with_data([reshape_data_node])
- reshape_node = Reshape(graph=graph, attrs={'infer': None, 'name': match['op'].id + '/Reshape_to_4D', 'dim': dim_4D}).create_node_with_data([permute_data_node], data_nodes=[match['out_data']])
+ dim_6D = int64_array([N, block_size, block_size, int(C / (block_size ** 2)), H, W])
+ order_6D = int64_array([0, 3, 4, 1, 5, 2])
+ dim_4D = int64_array([N, int(H * block_size), int(W * block_size), int(C / (block_size ** 2))])
+
+ reshape_data_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_6D',
+ 'dim': dim_6D}).create_node_with_data([match['in_data']])
+ permute_data_node = Permute(graph=graph, attrs={'name': match['op'].id + '/Permute',
+ 'order': order_6D}).create_node_with_data([reshape_data_node])
+ reshape_node = Reshape(graph=graph, attrs={'name': match['op'].id + '/Reshape_to_4D',
+ 'dim': dim_4D}).create_node_with_data([permute_data_node],
+ data_nodes=[match['out_data']])
reshape_data_node.in_node()['nchw_layout'] = True
reshape_data_node['nchw_layout'] = True