2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from copy import deepcopy
22 from extensions.middle.UselessStridedSlice import UselessStridedSliceEraser
24 from mo.middle.replacement import MiddleReplacementPattern
25 from mo.ops.op import Op
26 from mo.ops.reshape import Reshape
# NOTE(review): this listing appears to be a lossy extract of the original
# module. Every line carries its original line number, indentation is gone,
# and several lines are absent, for example:
#   - the `def run_before` / `def pattern` headers;
#   - the docstring delimiters;
#   - the statements guarded by the `if` checks;
#   - the `ss_shape` / `k` initializations;
#   - the `else:` branches.
# The comments below describe only what the visible lines show.
#
# AddReshapeAfterStridedSlice: a MiddleReplacementPattern that matches
# StridedSlice ops. When shrink_axis_mask and/or new_axis_mask contains True,
# it does three things:
#   - sets those mask entries to False on the node;
#   - gives the StridedSlice a new output data node;
#   - inserts a Reshape from that data node into the op's original output
#     data node, with the original output shape as the target.
29 class AddReshapeAfterStridedSlice(MiddleReplacementPattern):
31 Transform adds Reshape after StridedSlice layers if new_axis_mask or/and
32 shrink_axis_mask contains True. After this transform StridedSlice layer
33 does not change shape dims and new_axis_mask/shrink_axis_mask fulfilled by
38 # Run before passes that will convert/remove StridedSlice
# Presumably the body of `run_before` (its `def` line is not visible here).
# It orders this pattern ahead of UselessStridedSliceEraser.
40 return [UselessStridedSliceEraser]
# Presumably the body of `pattern` (its `def` line is not visible here).
# It matches a single op node with op == 'StridedSlice', which
# replace_pattern receives as match['strided_slice']. The remainder of this
# dict(...) call is not visible.
43 return dict(nodes=[('strided_slice', dict(kind='op', op='StridedSlice'))],
# replace_pattern runs two independent passes over the matched node:
#   1. shrink_axis_mask pass, which inserts "<node name>/Reshape_shrink";
#   2. new_axis_mask pass, which inserts "<node name>/Reshape_new".
# Both passes mutate the node's mask entries in place and rewire graph edges.
# NOTE(review): the second pass re-reads node.out_node(). If both masks
# contain True, that node is the "/Reshape_shrink_data" node created by the
# first pass, so the result is StridedSlice -> Reshape_new -> Reshape_shrink.
# Confirm this chaining is intended.
46 def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
47 # add Reshape for shrink_axis_mask
48 if True in match['strided_slice']['shrink_axis_mask']:
49 log.info("StridedSlice op with shrink mask '{}' has been detected".format(match['strided_slice'].id))
50 node = match['strided_slice']
# Only StridedSlice ops with exactly 4 input nodes and 1 output node are
# handled. The statement this condition guards is not visible here
# (presumably an early `return` -- confirm).
52 if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
# NOTE(review): shape_in is assigned but not used in the visible code.
55 shape_in = node.in_node().shape
56 shape_out = node.out_node().shape
# Reshape target: a copy of the StridedSlice's current output shape. The
# Reshape inserted below therefore reproduces the shape the original output
# data node had.
57 dim = shape_out.copy()
61 # Don't permute reshape if channels were squeezed
# Condition: NHWC graph layout and the last shrink_axis_mask entry is set.
# The statement it guards is not visible here. It presumably selects the
# nchw_layout=True Reshape variant below -- confirm.
63 if graph.graph['layout'] == 'NHWC' and node['shrink_axis_mask'][-1] == True:
# Build ss_shape, the shape given to the new StridedSlice output data node:
#   - positions whose shrink_axis_mask entry is False take shape_out[k]
#     (k indexes shape_out separately from the mask position i);
#   - positions whose entry is True are reset to False in the mask.
# Not visible here: the `k` increment, and whatever is appended for shrunk
# positions (presumably a size-1 dim -- confirm).
66 for i in range(0, len(node['shrink_axis_mask'])):
67 if not node['shrink_axis_mask'][i]:
68 ss_shape.append(shape_out[k])
71 node['shrink_axis_mask'][i] = False
# The original output data node; it becomes the output of the Reshape.
74 out_node = node.out_node(0)
76 # insert data node for StridedSlice
# Move the StridedSlice -> out_node edge so it targets the new data node.
# The edge attributes (key 0 of the MultiDiGraph edge dict) are deep-copied
# first.
77 data_node = Op._create_data_node(graph, node.name + "/Reshape_shrink_data", {'shape': ss_shape})
78 attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
79 graph.remove_edge(node.id, out_node.id)
80 graph.add_edge(node.id, data_node.id, **attrs)
# Reshape data_node -> out_node with target shape `dim` (int64). This
# variant marks both the Reshape op and its output data node with
# nchw_layout=True. The condition selecting it is not visible here.
84 reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
85 dim=np.array(dim, dtype=np.int64), nchw_layout=True))
86 reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
87 data_nodes=[out_node])
88 reshape_data_node['nchw_layout'] = True
# Same Reshape without the nchw_layout marks. This is the alternative
# branch; its `else:` is not visible here.
90 reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
91 dim=np.array(dim, dtype=np.int64)))
92 reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
93 data_nodes=[out_node])
95 # add Reshape for new_axis_mask
96 if True in match['strided_slice']['new_axis_mask']:
97 log.info("StridedSlice op with new axis mask '{}' has been detected".format(match['strided_slice'].id))
98 node = match['strided_slice']
# Same 4-inputs / 1-output guard as the shrink pass. The guarded statement
# is not visible here (presumably an early `return` -- confirm).
100 if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
# NOTE(review): shape_in is assigned but not used in the visible code.
103 shape_in = node.in_node().shape
104 shape_out = node.out_node().shape
# Reshape target: a copy of the StridedSlice's current output shape.
105 dim = shape_out.copy()
# Build ss_shape for the new data node:
#   - positions whose new_axis_mask entry is False keep shape_out[i];
#     unlike the shrink pass, shape_out is indexed directly by mask
#     position i;
#   - positions whose entry is True are reset to False in the mask.
# The ss_shape initialization and the `else:` are not visible here.
# NOTE(review): this assumes every mask position i is a valid index into
# shape_out -- confirm.
107 for i in range(0, len(node['new_axis_mask'])):
108 if not node['new_axis_mask'][i]:
109 ss_shape.append(shape_out[i])
111 node['new_axis_mask'][i] = False
# The current output data node; it becomes the output of the Reshape.
113 out_node = node.out_node(0)
114 # insert data node for StridedSlice
# Move the StridedSlice -> out_node edge so it targets the new data node.
# The edge attributes (key 0 of the MultiDiGraph edge dict) are deep-copied
# first.
115 data_node = Op._create_data_node(graph, node.name + "/Reshape_new_data", {'shape': ss_shape})
116 attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
117 graph.remove_edge(node.id, out_node.id)
118 graph.add_edge(node.id, data_node.id, **attrs)
# Reshape data_node -> out_node with target shape `dim` (int64).
121 reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
122 dim=np.array(dim, dtype=np.int64)))
123 reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
124 data_nodes=[out_node])