From 2c41fd2f038e90539479ab08370916c1ecd95d2b Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Tue, 11 Jun 2019 16:32:12 -0700 Subject: [PATCH] [Topi] Fast mode in take op (#3325) --- include/tvm/relay/attrs/transform.h | 3 ++- python/tvm/relay/op/transform.py | 3 ++- tests/python/relay/test_op_level3.py | 6 +++++- topi/include/topi/transform.h | 26 ++++++++++++++++++++++++++ topi/python/topi/transform.py | 1 + topi/tests/python/test_topi_transform.py | 9 +++++++-- 6 files changed, 43 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1b82412..65febae 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(mode).set_default("clip") .describe("Specify how out-of-bound indices will behave." "clip - clip to the range (default)" - "wrap - wrap around the indices"); + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 02fd492..dce2258 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"): the flattened input array is used. mode : str, optional - Specifies how out-of-bound indices will behave [clip, wrap]. + Specifies how out-of-bound indices will behave [clip, wrap, fast]. clip: clip to the range (default). wrap: wrap around the indices. + fast: no clip or wrap around (user must make sure indices are in-bound). Returns ------- diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 15cb326..a878d79 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -269,7 +269,8 @@ def test_take(): func = relay.Function([x, indices], z) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) - ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode) + np_mode = "raise" if mode == "fast" else mode + ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: @@ -291,6 +292,9 @@ def test_take(): verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1, mode="wrap") + verify_take((3,3,3), [[11,25]], mode="fast") + verify_take((3,4), [0, 2], axis=0, mode="fast") + verify_take((3,4), [0, 2], axis=1, mode="fast") def test_split_infer_type(): diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 4dba4ea..c992be6 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a, auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); }, name, tag); + } else if (mode == "fast") { + LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " + "Make sure input indices are in bound"; + return compute( + out_shape, [&](const Array& out_index) { + return a(UnravelIndex(indices(out_index), a_shape)); + }, name, tag); } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { @@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a, } return a(real_indices); }, name, tag); + } else if (mode == "fast") { + LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " + "Make sure input indices are in bound"; + return compute( + out_shape, [&](const Array& out_index) { + Array indices_position; + for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + indices_position.push_back(out_index[j]); + } + Array real_indices; + for (size_t j = 0; j < static_cast(axis); ++j) { + real_indices.push_back(out_index[j]); + } + real_indices.push_back(indices(indices_position)); + for (size_t j = axis + indices_len; j < out_index.size(); ++j) { + real_indices.push_back(out_index[j]); + } + return a(real_indices); + }, name, tag); } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 04af151..3d7293e 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"): Specifies how out-of-bound indices will behave. clip - clip to the range (default) wrap - wrap around the indices + fast - no clip or wrap around (user must make sure indices are in-bound) Returns ------- diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index d29fb64..5682fde 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) if axis is None: - out_npys = np.take(data_npy, indices_src, mode=mode) + np_mode = "raise" if mode == "fast" else mode + out_npys = np.take(data_npy, indices_src, mode=np_mode) else: - out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode) + np_mode = "raise" if mode == "fast" else mode + out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode) data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) @@ -521,6 +523,9 @@ def test_take(): verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) verify_take((3,4), [-1, 2], axis=1, mode="wrap") + verify_take((3,3,3), [[11,25]], mode="fast") + verify_take((3,4), [0, 2], axis=0, mode="fast") + verify_take((3,4), [0, 2], axis=1, mode="fast") def test_gather_nd(): for indices_dtype in ['int32', 'float32']: -- 2.7.4