[Topi] Fast mode in take op (#3325)
authorhlu1 <14827759+hlu1@users.noreply.github.com>
Tue, 11 Jun 2019 23:32:12 +0000 (16:32 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Tue, 11 Jun 2019 23:32:12 +0000 (16:32 -0700)
include/tvm/relay/attrs/transform.h
python/tvm/relay/op/transform.py
tests/python/relay/test_op_level3.py
topi/include/topi/transform.h
topi/python/topi/transform.py
topi/tests/python/test_topi_transform.py

index 1b82412..65febae 100644 (file)
@@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
     TVM_ATTR_FIELD(mode).set_default("clip")
         .describe("Specify how out-of-bound indices will behave."
                   "clip - clip to the range (default)"
-                  "wrap - wrap around the indices");
+                  "wrap - wrap around the indices"
+                  "fast - no clip or wrap around (user must make sure indices are in-bound)");
   }
 };
 
index 02fd492..dce2258 100644 (file)
@@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
         the flattened input array is used.
 
     mode : str, optional
-        Specifies how out-of-bound indices will behave [clip, wrap].
+        Specifies how out-of-bound indices will behave [clip, wrap, fast].
         clip: clip to the range (default).
         wrap: wrap around the indices.
+        fast: no clip or wrap around (user must make sure indices are in-bound).
 
     Returns
     -------
index 15cb326..a878d79 100644 (file)
@@ -269,7 +269,8 @@ def test_take():
 
         func = relay.Function([x, indices], z)
         x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
-        ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
+        np_mode = "raise" if mode == "fast" else mode
+        ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
 
         for target, ctx in ctx_list():
             for kind in ["graph", "debug"]:
@@ -291,6 +292,9 @@ def test_take():
     verify_take((3,4), [-1, 2], axis=0, mode="wrap")
     verify_take((3,4), [-1, 2], axis=1)
     verify_take((3,4), [-1, 2], axis=1, mode="wrap")
+    verify_take((3,3,3), [[11,25]], mode="fast")
+    verify_take((3,4), [0, 2], axis=0, mode="fast")
+    verify_take((3,4), [0, 2], axis=1, mode="fast")
 
 
 def test_split_infer_type():
index 4dba4ea..c992be6 100644 (file)
@@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
           auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
           return a(UnravelIndex(idx, a_shape));
         }, name, tag);
+  } else if (mode == "fast") {
+    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
+                    "Make sure input indices are in bound";
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          return a(UnravelIndex(indices(out_index), a_shape));
+        }, name, tag);
   } else {  // mode == "wrap"
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
@@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
           }
           return a(real_indices);
         }, name, tag);
+  } else if (mode == "fast") {
+    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
+                    "Make sure input indices are in bound";
+    return compute(
+        out_shape, [&](const Array<Var>& out_index) {
+          Array<Expr> indices_position;
+          for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
+            indices_position.push_back(out_index[j]);
+          }
+          Array<Expr> real_indices;
+          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          real_indices.push_back(indices(indices_position));
+          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          return a(real_indices);
+        }, name, tag);
   } else {  // mode == "wrap"
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
index 04af151..3d7293e 100644 (file)
@@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
         Specifies how out-of-bound indices will behave.
         clip - clip to the range (default)
         wrap - wrap around the indices
+        fast - no clip or wrap around (user must make sure indices are in-bound)
 
     Returns
     -------
index d29fb64..5682fde 100644 (file)
@@ -275,9 +275,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
         data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
 
         if axis is None:
-            out_npys = np.take(data_npy, indices_src, mode=mode)
+            np_mode = "raise" if mode == "fast" else mode
+            out_npys = np.take(data_npy, indices_src, mode=np_mode)
         else:
-            out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
+            np_mode = "raise" if mode == "fast" else mode
+            out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
         data_nd = tvm.nd.array(data_npy, ctx)
         indices_nd = tvm.nd.array(indices_src, ctx)
         out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
@@ -521,6 +523,9 @@ def test_take():
     verify_take((3,4), [-1, 2], axis=0, mode="wrap")
     verify_take((3,4), [-1, 2], axis=1)
     verify_take((3,4), [-1, 2], axis=1, mode="wrap")
+    verify_take((3,3,3), [[11,25]], mode="fast")
+    verify_take((3,4), [0, 2], axis=0, mode="fast")
+    verify_take((3,4), [0, 2], axis=1, mode="fast")
 
 def test_gather_nd():
     for indices_dtype in ['int32', 'float32']: