TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
- "wrap - wrap around the indices");
+ "wrap - wrap around the indices"
+ "fast - no clip or wrap around (user must make sure indices are in-bound)");
}
};
the flattened input array is used.
mode : str, optional
- Specifies how out-of-bound indices will behave [clip, wrap].
+ Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default).
wrap: wrap around the indices.
+ fast: no clip or wrap around (user must make sure indices are in-bound).
Returns
-------
func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
- ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
+ np_mode = "raise" if mode == "fast" else mode
+ ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
+ verify_take((3,3,3), [[11,25]], mode="fast")
+ verify_take((3,4), [0, 2], axis=0, mode="fast")
+ verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_split_infer_type():
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
+ } else if (mode == "fast") {
+ LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
+ "Make sure input indices are in bound";
+ return compute(
+ out_shape, [&](const Array<Var>& out_index) {
+ return a(UnravelIndex(indices(out_index), a_shape));
+ }, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
}
return a(real_indices);
}, name, tag);
+ } else if (mode == "fast") {
+ LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
+ "Make sure input indices are in bound";
+ return compute(
+ out_shape, [&](const Array<Var>& out_index) {
+ Array<Expr> indices_position;
+ for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+ indices_position.push_back(out_index[j]);
+ }
+ Array<Expr> real_indices;
+ for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+ real_indices.push_back(out_index[j]);
+ }
+ real_indices.push_back(indices(indices_position));
+ for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+ real_indices.push_back(out_index[j]);
+ }
+ return a(real_indices);
+ }, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
+ fast - no clip or wrap around (user must make sure indices are in-bound)
Returns
-------
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
if axis is None:
- out_npys = np.take(data_npy, indices_src, mode=mode)
+ np_mode = "raise" if mode == "fast" else mode
+ out_npys = np.take(data_npy, indices_src, mode=np_mode)
else:
- out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
+ np_mode = "raise" if mode == "fast" else mode
+ out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
+ verify_take((3,3,3), [[11,25]], mode="fast")
+ verify_take((3,4), [0, 2], axis=0, mode="fast")
+ verify_take((3,4), [0, 2], axis=1, mode="fast")
def test_gather_nd():
for indices_dtype in ['int32', 'float32']: