"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
"""
import logging as log
-import networkx as nx
import numpy as np
from copy import deepcopy
-from extensions.middle.AddReshapeAfterStridedSlice import AddReshapeAfterStridedSlice
+from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
from extensions.middle.FusePermutesSequence import FusePermutesSequence
from extensions.middle.ShufflenetReshape import ReshapeSoftmaxReshape
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.permute import Permute
class PixelLinkReshape(MiddleReplacementPattern):
"""
- Transform adds Permutes around Reshapes that pack 4 dimensions in 2, than
- do Softmax and then unpack it back to 5 dims.
+ Transform adds Permutes around Reshapes that pack 4 dimensions in 2, than
+ do Softmax and then unpack it back to 5 dims.
"""
enabled = True
def run_before(self):
- return [FusePermutesSequence, ReshapeSoftmaxReshape, AddReshapeAfterStridedSlice]
+ return [FusePermutesSequence, ReshapeSoftmaxReshape, ConvertGroupedStridedSlice]
def run_after(self):
- return []
+ from extensions.middle.pass_separator import MiddleStart
+ return [MiddleStart]
def pattern(self):
return dict(nodes=[('reshape_split', dict(kind='op', type='Reshape')),
('reshape_unpack', dict(kind='op', type='Reshape')),
('reshape_unpack_data', dict(kind='data')),
('strided_slice', dict(kind='op', op='StridedSlice')),
- ],
+ ],
edges=[('reshape_split', 'reshape_split_data'),
('reshape_split_data', 'reshape_pack'),
('reshape_pack', 'reshape_data'),
else:
return False
- def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(self, graph: Graph, match: dict):
if graph.graph['layout'] != 'NHWC':
return
attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
graph.remove_edge(node.id, out_node.id)
- permute_after_node = permute_after.create_node_with_data([data_node], permute_after.attrs,
- data_nodes=[out_node])
+ permute_after.create_node_with_data([data_node], permute_after.attrs,
+ data_nodes=[out_node])
graph.add_edge(node.id, data_node.id, **attrs)
# update softmax shape
node_softmax = match['softmax']
node_softmax.out_node(0).shape = out_node.shape
- # revert strided slice and reshape
- node_ss = match['strided_slice']
- node_unpack = match['reshape_unpack']
-
- unpack_out = node_unpack.out_node(0).id
- ss_out = node_ss.out_node(0).id
-
- #gather edge attributes
- soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
- reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
- reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
- ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])
-
- #remove all edges in Softmax->Reshape->StridedSlice chain
- graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
- graph.remove_edge(node_unpack.id, unpack_out)
- graph.remove_edge(unpack_out, node_ss.id)
- graph.remove_edge(node_ss.id, ss_out)
-
- #add new edges to get chain Softmax->StridedSlice->Reshape
- graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
- graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
- graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
- graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)
-
- #update output shape and parameters for StridedSlice
- node_ss.out_node(0).shape = np.zeros(3)
- node_ss.out_node(0).shape[0] = out_node.shape[0]
- node_ss.out_node(0).shape[1] = 1
- node_ss.out_node(0).shape[2] = out_node.shape[2]
-
- old_slices = node_ss.slices.copy()
- node_ss.slices = []
- node_ss.slices.append(old_slices[0])
- node_ss.slices.append(old_slices[-1])
- node_ss.slices.append(slice(0, out_node.shape[2], 1))
- node_ss.shrink_axis_mask = [False, False, False]
- node_ss.new_axis_mask = [False, False, False]
-
- #update Reshape attribute
- node_unpack.dim = np.delete(node_unpack.dim, 4)
- #prevent permute for reshape because it gives wrong result
- node_unpack['nchw_layout'] = True
- node_unpack.out_node(0)['nchw_layout'] = True
+ if ConvertGroupedStridedSlice.enabled is True:
+ # revert strided slice and reshape
+ node_ss = match['strided_slice']
+ node_unpack = match['reshape_unpack']
+
+ unpack_out = node_unpack.out_node(0).id
+ ss_out = node_ss.out_node(0).id
+
+ # gather edge attributes
+ soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
+ reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
+ reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
+ ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])
+
+ # remove all edges in Softmax->Reshape->StridedSlice chain
+ graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
+ graph.remove_edge(node_unpack.id, unpack_out)
+ graph.remove_edge(unpack_out, node_ss.id)
+ graph.remove_edge(node_ss.id, ss_out)
+
+ # add new edges to get chain Softmax->StridedSlice->Reshape
+ graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
+ graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
+ graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
+ graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)
+
+ # update output shape and parameters for StridedSlice
+ node_ss.out_node(0).shape = np.zeros(3)
+ node_ss.out_node(0).shape[0] = out_node.shape[0]
+ node_ss.out_node(0).shape[1] = 1
+ node_ss.out_node(0).shape[2] = out_node.shape[2]
+
+ old_slices = node_ss.slices.copy()
+ node_ss.slices = []
+ node_ss.slices.append(old_slices[0])
+ node_ss.slices.append(old_slices[-1])
+ node_ss.slices.append(slice(0, out_node.shape[2], 1))
+ node_ss.shrink_axis_mask = np.array([0, 0, 0], dtype=np.int64)
+ node_ss.new_axis_mask = np.array([0, 0, 0], dtype=np.int64)
+ node_ss.ellipsis_mask = np.array([0, 0, 0], dtype=np.int64)
+ node_ss.begin_mask = np.array([0, 1, 0], dtype=np.int64)
+ node_ss.end_mask = np.array([0, 1, 0], dtype=np.int64)
+
+ # update Reshape attribute
+ node_unpack.dim = np.delete(node_unpack.dim, 4)
+ # prevent permute for reshape because it gives wrong result
+ node_unpack['nchw_layout'] = True
+ node_unpack.out_node(0)['nchw_layout'] = True
+ else:
+ # reshape unpack: permute correctly
+ node_unpack = match['reshape_unpack']
+ data_node = Op._create_data_node(graph, node.name + "/Permute_after_unpack_data", {'shape': node_unpack.out_node().shape})
+ permute_after_unpack = Permute(graph, dict(name=node.name + "/Permute_after_unpack",
+ order=np.array([0, 3, 1, 2, 4])))
+ out_node = node_unpack.out_node(0)
+ out_node.shape = out_node.shape[np.array([0, 3, 1, 2, 4], dtype=np.int)]
+ attrs = deepcopy(graph.get_edge_data(node_unpack.id, out_node.id)[0])
+ graph.remove_edge(node_unpack.id, out_node.id)
+ permute_after.create_node_with_data([data_node], permute_after_unpack.attrs,
+ data_nodes=[out_node])
+ graph.add_edge(node_unpack.id, data_node.id, **attrs)