Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / slice.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.utils.error import Error
20
21
22 def tf_strided_slice_infer(node):
23     if node.in_node(1).value is None or node.in_node(2).value is None:
24         raise Error('Strided slice layer supports only constant begin and end inputs')
25     begin_id = node.in_node(1).value
26     end_id = node.in_node(2).value
27     if len(node.in_nodes()) > 3:
28         if node.in_node(3).value is None:
29             raise Error('Strided slice layer supports only constant stride input')
30         stride = node.in_node(3).value
31     else:
32         stride = []
33
34     shape = node.in_node(0).shape
35
36     if shape is None or any([x < 0 for x in shape]):
37         return
38
39     convert_negative_indices(begin_id, shape)
40     convert_negative_indices(end_id, shape)
41
42     slice_idx = []
43     dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
44                              len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
45                              len(node.begin_mask), len(node.end_mask)]))
46
47     # make mask correct length
48     def extend_mask(in_mask, fin_len, zeros=True):
49         mask = list(in_mask)
50         if len(mask) < fin_len:
51             if zeros:
52                 mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
53             else:
54                 mask.extend(np.ones(dims-len(mask), dtype=np.int32))
55         return np.array(mask, dtype=np.int32)
56
57     for mask in {'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'}:
58         node[mask] = extend_mask(node[mask], dims)
59     node.begin_mask = extend_mask(node.begin_mask, dims, False)
60     node.end_mask = extend_mask(node.end_mask, dims, False)
61
62     old_idx = 0
63     ellips_ext = 0
64     id_em = 0
65     for idx in range(dims):
66         if node.new_axis_mask[idx]:
67             slice_idx.append(np.newaxis)
68         elif node.ellipsis_mask[idx]:
69             ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
70             id_em = idx
71             for i in range(0, ellips_ext):
72                 slice_idx.append(slice(0, shape[old_idx], 1))
73                 old_idx = old_idx + 1
74         else:
75             s = stride[idx] if len(stride) > idx else 1
76             def_beg = 0 if s > 0 else -1
77             def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
78             l = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
79             r = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end
80
81             # Check shrink_axis_mask
82             if node.shrink_axis_mask[idx] and idx < len(shape):
83                 slice_idx.append(slice(l, l+1, s))
84             else:
85                 slice_idx.append(slice(l, r, s))
86             old_idx = old_idx + 1
87
88     value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
89     # fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
90     # `arr[tuple(seq)]` instead of `arr[seq]`"
91     value = value[tuple(slice_idx)]
92
93     for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
94         if flag:
95             if ellips_ext > 0 and idx > id_em:
96                 idx = idx + ellips_ext - 1
97             try:
98                 value = np.squeeze(value, idx)
99             except ValueError:
100                 # ignore this error
101                 continue
102
103     node['slices'] = np.array(slice_idx)
104     for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
105         node[attr] = np.array(node[attr], dtype=np.int32)
106
107     node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
108     node.out_node().shape = np.array(value.shape, dtype=np.int64)
109
110     # change precision to I32 for begin, end, stride inputs
111     for i in range(1, len(node.in_nodes())):
112         inp = node.in_node(i)
113         inp["force_precision"] = "I32"
114
115
116 def convert_negative_indices(indices: np.array, shape: np.array):
117     for ind, value in enumerate(indices):
118         if value < 0:
119             indices[ind] += shape[ind]
120
121
122 def caffe_slice_infer(node):
123     """
124     Slices an input layer to multiple output layers along a given dimension
125     with given slice indices
126     Parameters
127     ----------
128     node
129
130     """
131     top_shape = node.in_node(0).shape
132     slice_axis = node.axis
133     bottom_slice_axis = node.in_node(0).shape[node.axis]
134     if len(node.slice_point) == 0:
135         new_shape = np.array(top_shape, dtype=np.int64)
136         new_shape[slice_axis] = bottom_slice_axis / len(node.out_nodes())
137         for i in range(0, len(node.out_nodes())):
138             node.out_node(i).shape = new_shape
139         return
140
141     assert (len(node.slice_point) == len(node.out_nodes()) - 1)
142     prev = 0
143     slices = []
144     for slice_point in node.slice_point:
145         if slice_point <= prev:
146             raise Error(
147                 'Check failed for the layer {}. Slice points should be ordered in increasing manner. '.format(node.id) +
148                 'Current slice point {} is not greater than the previous slice point {}. '.format(slice_point, prev) +
149                 'Please verify your model correctness')
150         slices.append(slice_point - prev)
151         prev = slice_point
152
153     slices.append(bottom_slice_axis - prev)
154     if sum(slices) != bottom_slice_axis:
155         raise Error(
156             'Check failed for the layer {}. Sum of slices points {} does not equal '.format(node.id, sum(slices)) +
157             'to the value of input blob shape by the given slice axis {}'.format(bottom_slice_axis))
158     for i in range(len(node.out_nodes())):
159         new_shape = np.array(top_shape, dtype=np.int64)
160         new_shape[slice_axis] = slices[i]
161         node.out_node(i).shape = new_shape
162
163
164 def mxnet_slice_axis_infer(node):
165     in_shape = node.in_node(0).shape
166     slice_axis = node.axis
167
168     new_shape = np.array(in_shape, dtype=np.int64)
169     new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
170
171     axis_size = in_shape[slice_axis]
172     if node.offset < 0:
173         node.offset += axis_size
174
175     if not node.dim:
176         node.dim = axis_size
177     elif node.dim < 0:
178         node.dim += axis_size
179
180     input_dim = in_shape.size
181     node.dim = (node.dim - node.offset)
182     if node.dim > in_shape[slice_axis]:
183         raise Error(
184             '{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
185             '\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
186             '\nTo overcome, try to edit the original model "end" property of the {0} layer.',
187             node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
188         )
189
190     for i in range(0, input_dim):
191         if i == slice_axis:
192             new_shape[i] = node.dim
193         else:
194             new_shape[i] = in_shape[i]
195
196     for i in range(0, len(node.out_nodes())):
197         node.out_node(i)['shape'] = new_shape