Add thrust support for nms (#5116)
authorLeyuan Wang <laurawly@gmail.com>
Mon, 23 Mar 2020 23:52:33 +0000 (16:52 -0700)
committerGitHub <noreply@github.com>
Mon, 23 Mar 2020 23:52:33 +0000 (08:52 +0900)
* add argsort_nms_thrust

* consider valid count in thrust nms sort

* make thrust optional

* typo

* typo

* fix pylint

* address some of the comments

* address more comments

* fix lint

* address more comments

* address more comments

cmake/config.cmake
src/runtime/contrib/thrust/thrust.cu
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/sort.py

index fd295aa..6ab362c 100644 (file)
@@ -148,7 +148,7 @@ set(USE_NNPACK OFF)
 # Possible values:
 # - ON: enable tflite with cmake's find search
 # - OFF: disable tflite
-# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library 
+# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library
 set(USE_TFLITE OFF)
 
 # /path/to/tensorflow: tensorflow root path when use tflite library
index fc9deac..c40235d 100644 (file)
@@ -28,6 +28,7 @@
 #include <dlpack/dlpack.h>
 #include <algorithm>
 #include <vector>
+#include <functional>
 
 namespace tvm {
 namespace contrib {
@@ -39,7 +40,8 @@ template<typename DataType, typename IndicesType>
 void thrust_sort(DLTensor* input,
                  DLTensor* out_values,
                  DLTensor* out_indices,
-                 bool is_ascend) {
+                 bool is_ascend,
+                 const std::function<int(int)> &get_sort_len) {
   thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
   thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
   thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));
@@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input,
   thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);
 
   for (int i = 0 ; i < n_iter; ++i) {
+    n_values = get_sort_len(i);
     thrust::sequence(indices_ptr, indices_ptr + n_values);
     if (is_ascend) {
       thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
@@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input,
   }
 }
 
-TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  CHECK_GE(args.num_args, 4);
-  DLTensor* input = args[0];
-  DLTensor* values_out = args[1];
-  DLTensor* indices_out = args[2];
-  bool is_ascend = args[3];
-
-  auto data_dtype = DLDataType2String(input->dtype);
-  auto out_dtype = DLDataType2String(indices_out->dtype);
-
+void thrust_sort_common(DLTensor* input,
+                        DLTensor* values_out,
+                        DLTensor* indices_out,
+                        bool is_ascend,
+                        const std::function<int(int)> &get_sort_len,
+                        std::string data_dtype,
+                        std::string out_dtype) {
   if (data_dtype == "float32") {
     if (out_dtype == "int32") {
-      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<float, float>(input, values_out, indices_out, is_ascend);
+      thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<float, double>(input, values_out, indices_out, is_ascend);
+      thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "float64") {
     if (out_dtype == "int32") {
-      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<double, float>(input, values_out, indices_out, is_ascend);
+      thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<double, double>(input, values_out, indices_out, is_ascend);
+      thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else if (data_dtype == "int32") {
     if (out_dtype == "int32") {
-      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   }  else if (data_dtype == "int64") {
     if (out_dtype == "int32") {
-      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "int64") {
-      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float32") {
-      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else if (out_dtype == "float64") {
-      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend);
+      thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
   } else {
     LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
   }
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  CHECK_GE(args.num_args, 5);
+  DLTensor* input = args[0];
+  DLTensor* valid_count = args[1];
+  DLTensor* values_out = args[2];
+  DLTensor* indices_out = args[3];
+  bool is_ascend = args[4];
+
+  auto data_dtype = DLDataType2String(input->dtype);
+  auto out_dtype = DLDataType2String(indices_out->dtype);
+
+  thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
+  auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
+  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
+                     data_dtype, out_dtype);
 });
 
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  CHECK_GE(args.num_args, 4);
+  DLTensor* input = args[0];
+  DLTensor* values_out = args[1];
+  DLTensor* indices_out = args[2];
+  bool is_ascend = args[3];
+
+  auto data_dtype = DLDataType2String(input->dtype);
+  auto out_dtype = DLDataType2String(indices_out->dtype);
+
+  int n_values = input->shape[input->ndim - 1];
+  auto get_sort_len = [=](int i) { return n_values; };
+  thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
+                     data_dtype, out_dtype);
+});
 }  // namespace contrib
 }  // namespace tvm
index e008dcd..d295116 100644 (file)
@@ -22,7 +22,7 @@ import tvm
 from tvm import te
 
 from tvm.tir import if_then_else
-from .sort import argsort
+from .sort import argsort, argsort_thrust
 from .. import tag
 
 
@@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     score_shape = (batch_size, num_anchors)
     score_tensor = te.compute(
         score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    sort_tensor = argsort(
-        score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+    if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True):
+        sort_tensor = argsort_thrust(
+            score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+    else:
+        sort_tensor = argsort(
+            score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
 
     sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                           "sort_tensor_buf", data_alignment=8)
index 5499683..a1c70c4 100644 (file)
@@ -24,6 +24,10 @@ from ..math import identity
 from ..transform import strided_slice, transpose
 from .. import tag
 
+def swap(arr, axis):
+    """ swap arr[axis] and arr[-1] """
+    return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
+
 def _schedule_sort(outs):
     """Schedule for argsort operator.
 
@@ -237,6 +241,64 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
 
     return ib.get()
 
+def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
+    """Performs sorting along the given axis and returns an array of indicies
+    having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    data: tvm.te.Tensor
+        The input array.
+
+    valid_count : tvm.te.Tensor, optional
+        The number of valid elements to be sorted.
+
+    axis : int, optional
+        Axis long which to sort the input tensor.
+
+    is_ascend : boolean, optional
+        Whether to sort in ascending or descending order.
+
+    dtype : string, optional
+        DType of the output indices.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        The output of this function.
+    """
+    ndim = len(data.shape)
+    if axis < 0:
+        axis = ndim + axis
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
+                                   data_alignment=8)
+    valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype,
+                                          "valid_count_buf", data_alignment=4)
+    out_bufs = [
+        tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
+        tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8)
+    ]
+    out = te.extern([data.shape, data.shape],
+                    [data, valid_count],
+                    lambda ins, outs: tvm.tir.call_packed(
+                        "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend),
+                    in_buffers=[data_buf, valid_count_buf],
+                    out_buffers=out_bufs,
+                    dtype=[data.dtype, "int32"],
+                    name="nms_argsort_gpu",
+                    tag="nms_argsort_gpu")
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        out = [transpose(o, axes) for o in out]
+
+    return out[1]
+
 def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
     """Performs sorting along the given axis and returns an array of indicies
     having same shape as an input array that index data in sorted order.
@@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
         The output of this function.
     """
     if valid_count is not None:
-        # TODO: implement argsort_nms with Thrust
-        out = argsort(data, valid_count, axis, is_ascend, dtype)
+        out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
     else:
         out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
     return out
@@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
     ndim = len(data.shape)
     axis = ndim + axis if axis < 0 else axis
 
-    def swap(arr):
-        """ swap arr[axis] and arr[-1] """
-        return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]]
-
     if axis != ndim - 1:
         # Prepare for sorting along axis -1.
-        axes = swap(list(range(ndim)))
+        axes = swap(list(range(ndim)), axis)
         data = transpose(data, axes)
 
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
@@ -483,7 +540,7 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
         out = [strided_slice(o, beg, end) for o in out]
 
     if axis != ndim - 1:
-        axes = swap(list(range(ndim)))
+        axes = swap(list(range(ndim)), axis)
         out = [transpose(o, axes) for o in out]
 
     if ret_type == "values":