"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import numpy as np
-from mo.graph.graph import erase_node
from mo.utils.error import Error
+
def tf_strided_slice_infer(node):
+ if node.in_node(1).value is None or node.in_node(2).value is None:
+ raise Error('Strided slice layer supports only constant begin and end inputs')
begin_id = node.in_node(1).value
end_id = node.in_node(2).value
- stride = node.in_node(3).value
+ if len(node.in_nodes()) > 3:
+ if node.in_node(3).value is None:
+ raise Error('Strided slice layer supports only constant stride input')
+ stride = node.in_node(3).value
+ else:
+ stride = []
shape = node.in_node(0).shape
convert_negative_indices(begin_id, shape)
convert_negative_indices(end_id, shape)
- test_bit = lambda val, offset: ((1 << offset) & val != 0)
-
slice_idx = []
- shrink_axis_mask = []
- ellipsis_mask = []
- new_axis_mask = []
- dims = len(begin_id)
-
+ dims = np.amax(np.array([len(begin_id), len(end_id), len(stride),
+ len(node.shrink_axis_mask), len(node.new_axis_mask), len(node.ellipsis_mask),
+ len(node.begin_mask), len(node.end_mask)]))
+
+ # make mask correct length
+ def extend_mask(in_mask, fin_len, zeros=True):
+ mask = list(in_mask)
+ if len(mask) < fin_len:
+ if zeros:
+ mask.extend(np.zeros(dims-len(mask), dtype=np.int32))
+ else:
+ mask.extend(np.ones(dims-len(mask), dtype=np.int32))
+ return np.array(mask, dtype=np.int32)
+
+ for mask in {'new_axis_mask', 'shrink_axis_mask', 'ellipsis_mask'}:
+ node[mask] = extend_mask(node[mask], dims)
+ node.begin_mask = extend_mask(node.begin_mask, dims, False)
+ node.end_mask = extend_mask(node.end_mask, dims, False)
+
+ old_idx = 0
+ ellips_ext = 0
+ id_em = 0
for idx in range(dims):
- def_beg = 0 if stride[idx] > 0 else -1
- def_end = shape[idx] if stride[idx] > 0 else -shape[idx]-1
- l = begin_id[idx] if not test_bit(node.begin_mask, idx) else def_beg
- r = end_id[idx] if not test_bit(node.end_mask, idx) else def_end
-
- # Check shrink_axis_mask
- shrink_axis_mask.append(test_bit(node.shrink_axis_mask, idx))
- if shrink_axis_mask[idx]:
- l, r = l, l + 1
-
- # Check new_axis_mask
- new_axis_mask.append(test_bit(node.new_axis_mask, idx))
- if new_axis_mask[idx]:
+ if node.new_axis_mask[idx]:
slice_idx.append(np.newaxis)
-
- # Check ellipsis_mask
- ellipsis_mask.append(test_bit(node.ellipsis_mask, idx))
- if ellipsis_mask[idx]:
- shrink_axis_mask[idx] = False
- l, r = 0, shape[idx]
-
- slice_idx.append(slice(l, r, stride[idx]))
-
- # if masks length are less than input dims length than add slices and masks for such dims
- for idx in range(dims, len(shape)):
- slice_idx.append(slice(0, shape[idx], 1))
- shrink_axis_mask.append(False)
- new_axis_mask.append(False)
+ elif node.ellipsis_mask[idx]:
+ ellips_ext = len(shape) - (dims - np.count_nonzero(node.new_axis_mask) - 1)
+ id_em = idx
+ for i in range(0, ellips_ext):
+ slice_idx.append(slice(0, shape[old_idx], 1))
+ old_idx = old_idx + 1
+ else:
+ s = stride[idx] if len(stride) > idx else 1
+ def_beg = 0 if s > 0 else -1
+ def_end = shape[old_idx] if s > 0 else -shape[old_idx]-1
+ l = begin_id[idx] if node.begin_mask[idx] and idx < len(begin_id) else def_beg
+ r = end_id[idx] if node.end_mask[idx] and idx < len(end_id) else def_end
+
+ # Check shrink_axis_mask
+ if node.shrink_axis_mask[idx] and idx < len(shape):
+ slice_idx.append(slice(l, l+1, s))
+ else:
+ slice_idx.append(slice(l, r, s))
+ old_idx = old_idx + 1
value = node.in_node(0).value if node.in_node(0).value is not None else np.zeros(shape)
# fix for the warning: "FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated use
# `arr[tuple(seq)]` instead of `arr[seq]`"
value = value[tuple(slice_idx)]
- for idx, flag in reversed(list(enumerate(shrink_axis_mask))):
+ for idx, flag in reversed(list(enumerate(node.shrink_axis_mask))):
if flag:
- value = np.squeeze(value, idx)
+ if ellips_ext > 0 and idx > id_em:
+ idx = idx + ellips_ext - 1
+ try:
+ value = np.squeeze(value, idx)
+ except ValueError:
+ # ignore this error
+ continue
node['slices'] = np.array(slice_idx)
- node['shrink_axis_mask'] = np.array(shrink_axis_mask)
- node['new_axis_mask'] = np.array(new_axis_mask)
+ for attr in ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask'):
+ node[attr] = np.array(node[attr], dtype=np.int32)
node.out_node().value = np.array(value) if node.in_node(0).value is not None else None
- node.out_node().shape = np.array(value.shape)
+ node.out_node().shape = np.array(value.shape, dtype=np.int64)
+
+ # change precision to I32 for begin, end, stride inputs
+ for i in range(1, len(node.in_nodes())):
+ inp = node.in_node(i)
+ inp["force_precision"] = "I32"
- #remove inputs converted in attributes
- #for i in range(1,4):
- # node.graph.remove_edge(node.in_node(i).id, node.id)
def convert_negative_indices(indices: np.array, shape: np.array):
for ind, value in enumerate(indices):