"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import numpy as np

from mo.utils.error import Error
def tf_strided_slice_infer(node):
    """
    Infer the output shape (and, for a constant input, the output value) of a StridedSlice node.

    Input 1 (begin), input 2 (end) and the optional input 3 (strides) must be constant.
    The node's mask attributes (begin_mask, end_mask, ellipsis_mask, new_axis_mask,
    shrink_axis_mask) are padded to a common length and stored back on the node as int32
    arrays. The Python slice objects built from them are stored in node['slices'].

    In this code a non-zero begin_mask/end_mask entry means "use the explicit begin/end
    value" and zero means "use the default full range".
    NOTE(review): that is the opposite of TF's own bit meaning; presumably the extractor
    inverts the masks. Confirm against the front-end extractor.

    Side effects: negative begin/end values are converted in place (the input nodes' values
    are modified), and inputs 1..N are marked with force_precision = "I32".

    Parameters
    ----------
    node
        StridedSlice node. The function returns without setting anything when the shape of
        input 0 is unknown or contains negative dimensions.

    Raises
    ------
    Error
        If begin, end or (when present) strides are not constant.
    """
    if node.in_node(1).value is None or node.in_node(2).value is None:
        raise Error('Strided slice layer supports only constant begin and end inputs')
    begin_id = node.in_node(1).value
    end_id = node.in_node(2).value
    if len(node.in_nodes()) > 3:
        if node.in_node(3).value is None:
            raise Error('Strided slice layer supports only constant stride input')
        stride = node.in_node(3).value
    else:
        stride = []

    shape = node.in_node(0).shape

    if shape is None or any(x < 0 for x in shape):
        # Input shape is not fully known yet; nothing can be inferred.
        return

    convert_negative_indices(begin_id, shape)
    convert_negative_indices(end_id, shape)

    slice_idx = []
    dims = max(len(begin_id), len(end_id), len(stride),
               len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
               len(node.begin_mask), len(node.end_mask))

    def extend_mask(in_mask, fin_len, zeros=True):
        """Return in_mask as an int32 array padded to fin_len with zeros (ones when zeros=False)."""
        mask = list(in_mask)
        if len(mask) < fin_len:
            fill = np.zeros if zeros else np.ones
            mask.extend(fill(fin_len - len(mask), dtype=np.int32))
        return np.array(mask, dtype=np.int32)

    # Flag-style masks are padded with 0 ("not set"). begin/end masks are padded with 1, so the
    # explicit begin/end value is used wherever one exists.
    for mask in ('new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'):
        node[mask] = extend_mask(node[mask], dims)
    node.begin_mask = extend_mask(node.begin_mask, dims, False)
    node.end_mask = extend_mask(node.end_mask, dims, False)

    old_idx = 0  # input dimension addressed by the current slice-spec entry
    ellips_ext = 0  # number of input dimensions the ellipsis expands to
    id_em = 0  # position of the ellipsis in the slice spec
    for idx in range(dims):
        if node.new_axis_mask[idx]:
            # A new axis does not consume an input dimension.
            slice_idx.append(np.newaxis)
        elif node.ellipsis_mask[idx]:
            # The ellipsis covers every input dimension not addressed by the other entries.
            ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
            id_em = idx
            for _ in range(ellips_ext):
                slice_idx.append(slice(0, shape[old_idx], 1))
                old_idx += 1
        else:
            s = stride[idx] if len(stride) > idx else 1
            def_beg = 0 if s > 0 else -1
            def_end = shape[old_idx] if s > 0 else -shape[old_idx] - 1
            start = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
            stop = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end

            # Check shrink_axis_mask
            if node.shrink_axis_mask[idx] and idx < len(shape):
                # A shrunk axis takes the single element at `start` (TF ignores end/stride here).
                # slice(start, start + 1, s) is empty for s < 0, and start == -1 needs an open
                # stop, so use a forward step and stop accordingly.
                step = s if s > 0 else 1
                slice_idx.append(slice(start, start + 1 if start != -1 else None, step))
            else:
                slice_idx.append(slice(start, stop, s))
            old_idx += 1

    in_value = node.in_node(0).value
    if in_value is not None:
        value = in_value
    else:
        # Only the output shape is needed: slice a zero-stride view instead of allocating a
        # dense array of the whole input shape.
        value = np.broadcast_to(np.zeros(()), tuple(shape))
    # fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
    # `arr[tuple(seq)]` instead of `arr[seq]`"
    value = value[tuple(slice_idx)]

    # Squeeze shrunk axes from the back so that earlier axis positions stay valid.
    for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
        if flag:
            if ellips_ext > 0 and idx > id_em:
                # Spec entries after the ellipsis are shifted by the dimensions it expanded to.
                idx = idx + ellips_ext - 1
            try:
                value = np.squeeze(value, idx)
            except ValueError:
                # Best effort: the axis is not squeezable (size != 1 or out of range); keep it.
                continue

    node['slices'] = np.array(slice_idx)
    for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
        node[attr] = np.array(node[attr], dtype=np.int32)

    node.out_node().value = np.array(value) if in_value is not None else None
    node.out_node().shape = np.array(value.shape, dtype=np.int64)

    # change precision to I32 for begin, end, stride inputs
    for i in range(1, len(node.in_nodes())):
        node.in_node(i)["force_precision"] = "I32"
def convert_negative_indices(indices: np.array, shape: np.array):
    """
    Convert negative entries of ``indices`` to non-negative ones, in place.

    Entry ``i`` is treated as an index into dimension ``i`` of ``shape``, so a negative value
    ``v`` becomes ``v + shape[i]``. Entries with no corresponding dimension (``i >= len(shape)``,
    e.g. extra spec entries introduced by new axes) are left unchanged instead of raising
    IndexError. Negative values remain valid for Python slicing.

    Parameters
    ----------
    indices
        Mutable sequence of indices (such as slice begin or end values); modified in place.
    shape
        Shape the indices refer to, one entry per dimension.
    """
    for ind, value in enumerate(indices):
        if value < 0 and ind < len(shape):
            indices[ind] += shape[ind]
def caffe_slice_infer(node):
    """
    Slices an input layer to multiple output layers along a given dimension
    with given slice indices.

    Without slice points, ``node.axis`` is split evenly between all outputs. Otherwise the
    ``len(outputs) - 1`` strictly increasing slice points delimit the extent of each output.

    Parameters
    ----------
    node
        Slice node with ``axis`` and ``slice_point`` attributes. The shape of every output
        node is set; each output receives its own shape array.

    Raises
    ------
    Error
        If the number of slice points does not match the number of outputs, the slice points
        are not strictly increasing, or the last slice point is not inside the sliced axis.
    """
    top_shape = node.in_node(0).shape
    slice_axis = node.axis
    bottom_slice_axis = top_shape[slice_axis]
    num_outputs = len(node.out_nodes())

    if len(node.slice_point) == 0:
        new_shape = np.array(top_shape, dtype=np.int64)
        new_shape[slice_axis] = bottom_slice_axis // num_outputs
        for i in range(num_outputs):
            # Copy per output so that the output shapes do not alias one array.
            node.out_node(i).shape = new_shape.copy()
        return

    if len(node.slice_point) != num_outputs - 1:
        raise Error(
            'Check failed for the layer {}. Number of slice points {} must be one less than the number of '
            'outputs {}. Please verify your model correctness'.format(node.id, len(node.slice_point), num_outputs))

    prev = 0
    slices = []
    for slice_point in node.slice_point:
        if slice_point <= prev:
            raise Error(
                'Check failed for the layer {}. Slice points should be ordered in increasing manner. '.format(node.id) +
                'Current slice point {} is not greater than the previous slice point {}. '.format(slice_point, prev) +
                'Please verify your model correctness')
        slices.append(slice_point - prev)
        prev = slice_point

    # The slice sizes always sum to the axis size (they telescope), so the check that can
    # actually fail is that the last slice is non-empty.
    if prev >= bottom_slice_axis:
        raise Error(
            'Check failed for the layer {}. The last slice point {} must be less than the size {} of the input blob '
            'along the slice axis. Please verify your model correctness'.format(node.id, prev, bottom_slice_axis))
    slices.append(bottom_slice_axis - prev)

    for i, size in enumerate(slices):
        new_shape = np.array(top_shape, dtype=np.int64)
        new_shape[slice_axis] = size
        node.out_node(i).shape = new_shape
def mxnet_slice_axis_infer(node):
    """
    Infer output shapes of a slice-along-one-axis node.

    ``node.offset`` (slice start) and ``node.dim`` (slice end, the model's "end" property)
    are normalized against the size of ``node.axis``. Negative values count from the end of
    the axis, and a falsy ``dim`` (e.g. None) means "up to the end of the axis". Afterwards
    ``node.dim`` holds the slice length (end - start). Every output node gets the input shape
    with the sliced axis replaced by that length; a negative ``node.axis`` is supported.

    Raises
    ------
    Error
        If the slice length exceeds the size of the input along ``node.axis``.
    """
    in_shape = node.in_node(0).shape
    slice_axis = node.axis

    axis_size = in_shape[slice_axis]
    if node.offset < 0:
        node.offset += axis_size

    if not node.dim:
        node.dim = axis_size
    elif node.dim < 0:
        node.dim += axis_size

    node.dim = (node.dim - node.offset)
    if node.dim > in_shape[slice_axis]:
        raise Error(
            '{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
            '\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
            '\nTo overcome, try to edit the original model "end" property of the {0} layer.',
            node.name, ','.join(str(i) for i in in_shape), str(node.dim), str(in_shape[slice_axis])
        )

    for i in range(len(node.out_nodes())):
        # Index with slice_axis directly so a negative axis works too, and build a fresh
        # array per output so the output shapes do not alias each other.
        new_shape = np.array(in_shape, dtype=np.int64)
        new_shape[slice_axis] = node.dim
        node.out_node(i)['shape'] = new_shape