Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / AddReshapeAfterStridedSlice.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 import networkx as nx
19 import numpy as np
20
21 from copy import deepcopy
22 from extensions.middle.UselessStridedSlice import UselessStridedSliceEraser
23
24 from mo.middle.replacement import MiddleReplacementPattern
25 from mo.ops.op import Op
26 from mo.ops.reshape import Reshape
27
28
class AddReshapeAfterStridedSlice(MiddleReplacementPattern):
    """
    Insert a Reshape after StridedSlice nodes whose new_axis_mask and/or
    shrink_axis_mask contains True.

    After this transform the StridedSlice node itself no longer adds or removes
    dimensions:
    - every True entry of those masks is reset to False;
    - the StridedSlice writes into a new intermediate data node, where shrunk
      axes are kept with size 1 and new axes are omitted;
    - a Reshape to the original output shape feeds the original output data
      node, so consumers still see the same shape.
    """
    enabled = True

    # Run before passes that will convert/remove StridedSlice
    def run_before(self):
        return [UselessStridedSliceEraser]

    def pattern(self):
        return dict(nodes=[('strided_slice', dict(kind='op', op='StridedSlice'))],
                    edges=[])

    @staticmethod
    def _has_supported_ports(node) -> bool:
        """Return True if `node` has exactly 4 input and 1 output nodes; other nodes are left untouched."""
        return len(node.in_nodes()) == 4 and len(node.out_nodes()) == 1

    @staticmethod
    def _insert_intermediate_data_node(graph: nx.MultiDiGraph, node, out_node, name: str, shape):
        """
        Create a data node `name` with the given `shape` (stored as an int64 array)
        and re-route the node -> out_node edge, keeping its attributes, so that it
        ends at the new data node instead. Returns the new data node.
        """
        data_node = Op._create_data_node(graph, name, {'shape': np.array(shape, dtype=np.int64)})
        # Edge key 0 of the multigraph: the edge whose attributes are carried over.
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)
        return data_node

    def _add_reshape_for_shrink_axis(self, graph: nx.MultiDiGraph, node):
        """Clear shrink_axis_mask on `node` and restore the shrunk output shape with a Reshape."""
        shape_out = node.out_node().shape
        shrink_axis_mask = node['shrink_axis_mask']

        # Don't permute reshape if channels were squeezed
        # (must be evaluated before the mask is cleared below).
        dont_permute = graph.graph['layout'] == 'NHWC' and bool(shrink_axis_mask[-1])

        # Shape the StridedSlice produces once it stops shrinking: a 1 for every
        # shrunk axis, the next output dimension otherwise.
        ss_shape = []
        out_dim = 0
        for i, shrink in enumerate(shrink_axis_mask):
            if shrink:
                shrink_axis_mask[i] = False
                ss_shape.append(1)
            else:
                ss_shape.append(shape_out[out_dim])
                out_dim += 1

        out_node = node.out_node(0)
        data_node = self._insert_intermediate_data_node(graph, node, out_node,
                                                        node.name + "/Reshape_shrink_data", ss_shape)

        reshape_attrs = dict(name=node.name + "/Reshape_shrink",
                             dim=np.array(shape_out, dtype=np.int64))
        if dont_permute:
            reshape_attrs['nchw_layout'] = True
        reshape = Reshape(graph, reshape_attrs)
        reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
                                                          data_nodes=[out_node])
        if dont_permute:
            reshape_data_node['nchw_layout'] = True

    def _add_reshape_for_new_axis(self, graph: nx.MultiDiGraph, node):
        """Clear new_axis_mask on `node` and restore the inserted axes with a Reshape."""
        shape_out = node.out_node().shape
        new_axis_mask = node['new_axis_mask']

        # Shape the StridedSlice produces once it stops inserting axes:
        # the output shape with every new-axis position dropped.
        ss_shape = []
        for i, is_new_axis in enumerate(new_axis_mask):
            if is_new_axis:
                new_axis_mask[i] = False
            else:
                ss_shape.append(shape_out[i])

        out_node = node.out_node(0)
        data_node = self._insert_intermediate_data_node(graph, node, out_node,
                                                        node.name + "/Reshape_new_data", ss_shape)

        reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
                                      dim=np.array(shape_out, dtype=np.int64)))
        reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        node = match['strided_slice']

        # add Reshape for shrink_axis_mask
        if True in node['shrink_axis_mask']:
            log.info("StridedSlice op '%s' with shrink mask has been detected", node.id)
            if not self._has_supported_ports(node):
                return
            self._add_reshape_for_shrink_axis(graph, node)

        # add Reshape for new_axis_mask
        if True in node['new_axis_mask']:
            log.info("StridedSlice op '%s' with new axis mask has been detected", node.id)
            if not self._has_supported_ports(node):
                return
            self._add_reshape_for_new_axis(graph, node)
123             reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
124                                                               data_nodes=[out_node])