Publishing R5 content (#72)
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / slice.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.graph.graph import erase_node
20 from mo.utils.error import Error
21
22 def tf_strided_slice_infer(node):
23     begin_id = node.in_node(1).value
24     end_id = node.in_node(2).value
25     stride = node.in_node(3).value
26
27     shape = node.in_node(0).shape
28
29     if shape is None or any([x < 0 for x in shape]):
30         return
31
32     convert_negative_indices(begin_id, shape)
33     convert_negative_indices(end_id, shape)
34
35     test_bit = lambda val, offset: ((1 << offset) & val != 0)
36
37     slice_idx = []
38     shrink_axis_mask = []
39     ellipsis_mask = []
40     new_axis_mask = []
41     dims = len(begin_id)
42
43     for idx in range(dims):
44         def_beg = 0 if stride[idx] > 0 else -1
45         def_end = shape[idx] if stride[idx] > 0 else -shape[idx]-1
46         l = begin_id[idx] if not test_bit(node.begin_mask, idx) else def_beg
47         r = end_id[idx] if not test_bit(node.end_mask, idx) else def_end
48
49         # Check shrink_axis_mask
50         shrink_axis_mask.append(test_bit(node.shrink_axis_mask, idx))
51         if shrink_axis_mask[idx]:
52             l, r = l, l + 1
53
54         # Check new_axis_mask
55         new_axis_mask.append(test_bit(node.new_axis_mask, idx))
56         if new_axis_mask[idx]:
57             slice_idx.append(np.newaxis)
58
59         # Check ellipsis_mask
60         ellipsis_mask.append(test_bit(node.ellipsis_mask, idx))
61         if ellipsis_mask[idx]:
62             shrink_axis_mask[idx] = False
63             l, r = 0, shape[idx]
64
65         slice_idx.append(slice(l, r, stride[idx]))
66     
67     # if masks length are less than input dims length than add slices and masks for such dims
68     for idx in range(dims, len(shape)):
69         slice_idx.append(slice(0, shape[idx], 1))
70         shrink_axis_mask.append(False)
71         new_axis_mask.append(False)
72
73     value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
74     # fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
75     # `arr[tuple(seq)]` instead of `arr[seq]`"
76     value = value[tuple(slice_idx)]
77
78     for idx, flag in reversed(list(enumerate(shrink_axis_mask))):
79         if flag:
80             value = np.squeeze(value, idx)
81
82     node['slices'] = np.array(slice_idx)
83     node['shrink_axis_mask'] = np.array(shrink_axis_mask)
84     node['new_axis_mask'] = np.array(new_axis_mask)
85
86     node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
87     node.out_node().shape = np.array(value.shape)
88
89     #remove inputs converted in attributes
90     #for i in range(1,4):
91     #    node.graph.remove_edge(node.in_node(i).id, node.id)
92
93 def convert_negative_indices(indices: np.array, shape: np.array):
94     for ind, value in enumerate(indices):
95         if value < 0:
96             indices[ind] += shape[ind]
97
98
99 def caffe_slice_infer(node):
100     """
101     Slices an input layer to multiple output layers along a given dimension
102     with given slice indices
103     Parameters
104     ----------
105     node
106
107     """
108     top_shape = node.in_node(0).shape
109     slice_axis = node.axis
110     bottom_slice_axis = node.in_node(0).shape[node.axis]
111     if len(node.slice_point) == 0:
112         new_shape = np.array(top_shape, dtype=np.int64)
113         new_shape[slice_axis] = bottom_slice_axis / len(node.out_nodes())
114         for i in range(0, len(node.out_nodes())):
115             node.out_node(i).shape = new_shape
116         return
117
118     assert (len(node.slice_point) == len(node.out_nodes()) - 1)
119     prev = 0
120     slices = []
121     for slice_point in node.slice_point:
122         if slice_point <= prev:
123             raise Error(
124                 'Check failed for the layer {}. Slice points should be ordered in increasing manner. '.format(node.id) +
125                 'Current slice point {} is not greater than the previous slice point {}. '.format(slice_point, prev) +
126                 'Please verify your model correctness')
127         slices.append(slice_point - prev)
128         prev = slice_point
129
130     slices.append(bottom_slice_axis - prev)
131     if sum(slices) != bottom_slice_axis:
132         raise Error(
133             'Check failed for the layer {}. Sum of slices points {} does not equal '.format(node.id, sum(slices)) +
134             'to the value of input blob shape by the given slice axis {}'.format(bottom_slice_axis))
135     for i in range(len(node.out_nodes())):
136         new_shape = np.array(top_shape, dtype=np.int64)
137         new_shape[slice_axis] = slices[i]
138         node.out_node(i).shape = new_shape
139
140
141 def mxnet_slice_axis_infer(node):
142     in_shape = node.in_node(0).shape
143     slice_axis = node.axis
144
145     new_shape = np.array(in_shape, dtype=np.int64)
146     new_shape[slice_axis] = new_shape[slice_axis] / len(node.out_nodes())
147
148     axis_size = in_shape[slice_axis]
149     if node.offset < 0:
150         node.offset += axis_size
151
152     if not node.dim:
153         node.dim = axis_size
154     elif node.dim < 0:
155         node.dim += axis_size
156
157     input_dim = in_shape.size
158     node.dim = (node.dim - node.offset)
159     if node.dim > in_shape[slice_axis]:
160         raise Error(
161             '{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
162             '\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
163             '\nTo overcome, try to edit the original model "end" property of the {0} layer.',
164             node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
165         )
166
167     for i in range(0, input_dim):
168         if i == slice_axis:
169             new_shape[i] = node.dim
170         else:
171             new_shape[i] = in_shape[i]
172
173     for i in range(0, len(node.out_nodes())):
174         node.out_node(i)['shape'] = new_shape