2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import erase_node
20 from mo.utils.error import Error
def tf_strided_slice_infer(node):
    """
    Infer the output shape (and the output value for a constant input) of a TensorFlow
    StridedSlice node.

    Inputs: 0 - sliced tensor, 1 - begin, 2 - end, 3 - strides. Bit ``i`` of the node
    attributes ``begin_mask``, ``end_mask``, ``shrink_axis_mask``, ``new_axis_mask`` and
    ``ellipsis_mask`` refers to entry ``i`` of begin/end/strides.

    Side effects:
      * negative entries of the begin and end input values are rewritten in place by
        ``convert_negative_indices``;
      * the node attributes ``slices``, ``shrink_axis_mask`` and ``new_axis_mask`` are
        replaced with per-position numpy arrays;
      * ``shape`` of the output node is set, and so is its ``value`` (``None`` when
        input 0 has no constant value).

    Nothing is written when the input shape is unknown or contains negative entries, or
    when begin/end/strides have no constant value.
    """
    begin_id = node.in_node(1).value
    end_id = node.in_node(2).value
    stride = node.in_node(3).value

    shape = node.in_node(0).shape

    if shape is None or any(x < 0 for x in shape):
        return
    # Without constant begin/end/strides the slice cannot be resolved statically.
    if begin_id is None or end_id is None or stride is None:
        return

    convert_negative_indices(begin_id, shape)
    convert_negative_indices(end_id, shape)

    def test_bit(val, offset):
        return ((1 << offset) & val) != 0

    slice_idx = []
    shrink_axis_mask = []
    ellipsis_mask = []
    new_axis_mask = []
    dims = len(begin_id)

    for idx in range(dims):
        # Whole-axis bounds in the direction of the stride, used when the begin/end mask bit
        # is set. For a negative stride, a stop of -shape - 1 makes the Python slice run down
        # to and including index 0.
        def_beg = 0 if stride[idx] > 0 else -1
        def_end = shape[idx] if stride[idx] > 0 else -shape[idx] - 1
        start = begin_id[idx] if not test_bit(node.begin_mask, idx) else def_beg
        stop = end_id[idx] if not test_bit(node.end_mask, idx) else def_end

        # Check shrink_axis_mask: keep the single element at `start`; the axis is squeezed
        # out after slicing
        shrink_axis_mask.append(test_bit(node.shrink_axis_mask, idx))
        if shrink_axis_mask[idx]:
            stop = start + 1

        # Check new_axis_mask
        # NOTE(review): np.newaxis is inserted *and* the slice below is still appended for
        # this position, so the position also consumes an input dimension; TF's new_axis does
        # not consume one - confirm against TF semantics before relying on this path.
        new_axis_mask.append(test_bit(node.new_axis_mask, idx))
        if new_axis_mask[idx]:
            slice_idx.append(np.newaxis)

        # Check ellipsis_mask: the ellipsis position keeps its whole axis and is never shrunk
        ellipsis_mask.append(test_bit(node.ellipsis_mask, idx))
        if ellipsis_mask[idx]:
            shrink_axis_mask[idx] = False
            start, stop = 0, shape[idx]

        slice_idx.append(slice(start, stop, stride[idx]))

    # If the masks cover fewer positions than the input rank, keep the remaining dimensions whole
    for idx in range(dims, len(shape)):
        slice_idx.append(slice(0, shape[idx], 1))
        shrink_axis_mask.append(False)
        new_axis_mask.append(False)

    in_value = node.in_node(0).value
    if in_value is not None:
        value = in_value
    else:
        # Only the resulting shape is needed: a zero-stride broadcast view has the input shape
        # without allocating a full tensor.
        value = np.broadcast_to(0.0, tuple(shape))
    # Index with a tuple: NumPy deprecated multidimensional indexing with a non-tuple sequence
    # ("FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated").
    value = value[tuple(slice_idx)]

    # Squeeze from the last position backwards so lower positions keep their axis numbers
    for idx, flag in reversed(list(enumerate(shrink_axis_mask))):
        if flag:
            value = np.squeeze(value, idx)

    node['slices'] = np.array(slice_idx)
    node['shrink_axis_mask'] = np.array(shrink_axis_mask)
    node['new_axis_mask'] = np.array(new_axis_mask)

    node.out_node().value = np.array(value) if in_value is not None else None
    node.out_node().shape = np.array(value.shape)


def convert_negative_indices(indices: np.ndarray, shape: np.ndarray):
    """
    Convert negative (counted-from-the-end) indices into non-negative ones, in place.

    A negative ``indices[i]`` is shifted by ``shape[i]``; non-negative entries are left
    untouched. Nothing is returned - the caller's ``indices`` object is modified.

    :param indices: mutable sequence of per-dimension indices; modified in place.
    :param shape: per-dimension sizes; needs an entry for every position holding a
        negative index (otherwise ``IndexError`` is raised).
    """
    for ind, value in enumerate(indices):
        if value < 0:
            indices[ind] += shape[ind]


def caffe_slice_infer(node):
    """
    Slices an input layer to multiple output layers along a given dimension
    with given slice indices.

    Follows Caffe's Slice layer: with an empty ``node.slice_point`` the input is split
    evenly along ``node.axis`` between all outputs; otherwise the strictly increasing
    slice points are the boundaries between consecutive outputs. Nothing is done when
    the node has no outputs.

    :param node: Slice node. Reads ``axis``, ``slice_point``, ``id`` and the shape of
        input 0; sets ``shape`` of every output node (each output gets its own array).
    :raises Error: if the axis size is not divisible by the number of outputs (even split),
        if the number of slice points is not the number of outputs minus one, if the slice
        points are not strictly increasing, or if the last slice point is not less than
        the axis size.
    """
    top_shape = node.in_node(0).shape
    slice_axis = node.axis
    bottom_slice_axis = top_shape[slice_axis]
    num_outputs = len(node.out_nodes())
    if num_outputs == 0:
        return

    if len(node.slice_point) == 0:
        # Assigning a fractional size into an int64 shape would silently truncate, so an
        # uneven split is rejected (Caffe rejects it as well).
        if bottom_slice_axis % num_outputs != 0:
            raise Error(
                'Check failed for the layer {}. Number of outputs {} should evenly divide '.format(node.id, num_outputs) +
                'the value of input blob shape by the given slice axis {}'.format(bottom_slice_axis))
        slices = [bottom_slice_axis // num_outputs] * num_outputs
    else:
        if len(node.slice_point) != num_outputs - 1:
            raise Error(
                'Check failed for the layer {}. Number of slice points {} should be one less '.format(
                    node.id, len(node.slice_point)) +
                'than the number of outputs {}'.format(num_outputs))
        slices = []
        prev = 0
        for slice_point in node.slice_point:
            if slice_point <= prev:
                raise Error(
                    'Check failed for the layer {}. Slice points should be ordered in increasing manner. '.format(node.id) +
                    'Current slice point {} is not greater than the previous slice point {}. '.format(slice_point, prev) +
                    'Please verify your model correctness')
            slices.append(slice_point - prev)
            prev = slice_point
        # The slice sizes always add up to the axis size (they telescope), so the failure mode
        # left is a last slice point at or past the end of the axis: an empty/negative last slice.
        if prev >= bottom_slice_axis:
            raise Error(
                'Check failed for the layer {}. The last slice point {} should be less than '.format(node.id, prev) +
                'the value of input blob shape by the given slice axis {}'.format(bottom_slice_axis))
        slices.append(bottom_slice_axis - prev)

    for i, size in enumerate(slices):
        new_shape = np.array(top_shape, dtype=np.int64)
        new_shape[slice_axis] = size
        node.out_node(i).shape = new_shape


def mxnet_slice_axis_infer(node):
    """
    Infer output shapes of an MXNet ``slice_axis`` node.

    On entry ``node.offset`` is the begin index and ``node.dim`` the end index along
    ``node.axis``. Negative values count from the end of the axis, and a falsy end (e.g.
    ``None``) means the end of the axis. The node is updated in place: ``node.offset``
    becomes the non-negative begin index and ``node.dim`` the length of the slice. Every
    output gets the input shape with the sliced axis replaced by that length.

    NOTE(review): because ``node.dim`` is rewritten from an end index into a length, running
    this inference twice on the same node would subtract the offset twice - confirm it runs
    only once per node.

    :raises Error: if the end index or the slice length is bigger than the sliced axis.
    """
    in_shape = node.in_node(0).shape
    slice_axis = node.axis
    axis_size = in_shape[slice_axis]

    if node.offset < 0:
        node.offset += axis_size

    if not node.dim:
        node.dim = axis_size
    elif node.dim < 0:
        node.dim += axis_size

    end = node.dim
    node.dim = end - node.offset
    # Check the end index as well as the length: the length alone misses an out-of-range
    # end whenever begin > 0.
    for checked in (end, node.dim):
        if checked > axis_size:
            raise Error(
                '{0} node dimension value is bigger than the corresponding value in the input shape {1}. ' +
                '\nIn particular {2} is bigger than {3}. The Model Optimizer does not support this case. ' +
                '\nTo overcome, try to edit the original model "end" property of the {0} layer.',
                node.name, ','.join(str(i) for i in in_shape), str(checked), str(axis_size)
            )

    # Index the sliced axis directly so a negative axis works too; each output gets its own array.
    for i in range(len(node.out_nodes())):
        new_shape = np.array(in_shape, dtype=np.int64)
        new_shape[slice_axis] = node.dim
        node.out_node(i)['shape'] = new_shape