From: Leyuan Wang Date: Mon, 23 Mar 2020 23:52:33 +0000 (-0700) Subject: Add thrust support for nms (#5116) X-Git-Tag: upstream/0.7.0~1065 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=25d354218a72c0df6b26796f3574f45b5be99ede;p=platform%2Fupstream%2Ftvm.git Add thrust support for nms (#5116) * add argsort_nms_thrust * consider valid count in thrust nms sort * make thrust optional * typo * typo * fix pylint * address some of the comments * address more comments * fix lint * address more comments * address more comments --- diff --git a/cmake/config.cmake b/cmake/config.cmake index fd295aa..6ab362c 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -148,7 +148,7 @@ set(USE_NNPACK OFF) # Possible values: # - ON: enable tflite with cmake's find search # - OFF: disable tflite -# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library +# - /path/to/libtensorflow-lite.a: use specific path to tensorflow lite library set(USE_TFLITE OFF) # /path/to/tensorflow: tensorflow root path when use tflite library diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index fc9deac..c40235d 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -28,6 +28,7 @@ #include #include #include +#include namespace tvm { namespace contrib { @@ -39,7 +40,8 @@ template void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, - bool is_ascend) { + bool is_ascend, + const std::function &get_sort_len) { thrust::device_ptr data_ptr(static_cast(input->data)); thrust::device_ptr values_ptr(static_cast(out_values->data)); thrust::device_ptr indices_ptr(static_cast(out_indices->data)); @@ -53,6 +55,7 @@ void thrust_sort(DLTensor* input, thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr); for (int i = 0 ; i < n_iter; ++i) { + n_values = get_sort_len(i); thrust::sequence(indices_ptr, indices_ptr + n_values); if (is_ascend) { 
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr); @@ -65,69 +68,100 @@ void thrust_sort(DLTensor* input, } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_GE(args.num_args, 4); - DLTensor* input = args[0]; - DLTensor* values_out = args[1]; - DLTensor* indices_out = args[2]; - bool is_ascend = args[3]; - - auto data_dtype = DLDataType2String(input->dtype); - auto out_dtype = DLDataType2String(indices_out->dtype); - +void thrust_sort_common(DLTensor* input, + DLTensor* values_out, + DLTensor* indices_out, + bool is_ascend, + const std::function &get_sort_len, + std::string data_dtype, + std::string out_dtype) { if (data_dtype == "float32") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype 
== "float64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "int64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float32") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else if (out_dtype == "float64") { - thrust_sort(input, values_out, indices_out, is_ascend); + thrust_sort(input, values_out, indices_out, is_ascend, get_sort_len); } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } } else { LOG(FATAL) << "Unsupported input dtype: " << data_dtype; } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms") +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_GE(args.num_args, 5); + DLTensor* input = args[0]; + DLTensor* valid_count = 
args[1]; + DLTensor* values_out = args[2]; + DLTensor* indices_out = args[3]; + bool is_ascend = args[4]; + + auto data_dtype = DLDataType2String(input->dtype); + auto out_dtype = DLDataType2String(indices_out->dtype); + + thrust::device_ptr valid_count_ptr(static_cast(valid_count->data)); + auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; }; + thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, + data_dtype, out_dtype); }); + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_GE(args.num_args, 4); + DLTensor* input = args[0]; + DLTensor* values_out = args[1]; + DLTensor* indices_out = args[2]; + bool is_ascend = args[3]; + + auto data_dtype = DLDataType2String(input->dtype); + auto out_dtype = DLDataType2String(indices_out->dtype); + + int n_values = input->shape[input->ndim - 1]; + auto get_sort_len = [=](int i) { return n_values; }; + thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len, + data_dtype, out_dtype); +}); } // namespace contrib } // namespace tvm diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index e008dcd..d295116 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -22,7 +22,7 @@ import tvm from tvm import te from tvm.tir import if_then_else -from .sort import argsort +from .sort import argsort, argsort_thrust from .. 
import tag @@ -668,8 +668,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_shape = (batch_size, num_anchors) score_tensor = te.compute( score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - sort_tensor = argsort( - score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): + sort_tensor = argsort_thrust( + score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + else: + sort_tensor = argsort( + score_tensor, valid_count=valid_count, axis=1, is_ascend=False) sort_tensor_buf = tvm.tir.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8) diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 5499683..a1c70c4 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -24,6 +24,10 @@ from ..math import identity from ..transform import strided_slice, transpose from .. import tag +def swap(arr, axis): + """ swap arr[axis] and arr[-1] """ + return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]] + def _schedule_sort(outs): """Schedule for argsort operator. @@ -237,6 +241,64 @@ return ib.get() +def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"): + """Performs sorting along the given axis and returns an array of indices + having same shape as an input array that index data in sorted order. + + Parameters + ---------- + data: tvm.te.Tensor + The input array. + + valid_count : tvm.te.Tensor, optional + The number of valid elements to be sorted. + + axis : int, optional + Axis along which to sort the input tensor. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. + + dtype : string, optional + DType of the output indices. + + Returns + ------- + out : tvm.te.Tensor + The output of this function. 
+ """ + ndim = len(data.shape) + if axis < 0: + axis = ndim + axis + if axis != ndim - 1: + # Prepare for sorting along axis -1. + axes = swap(list(range(ndim)), axis) + data = transpose(data, axes) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", + data_alignment=8) + valid_count_buf = tvm.tir.decl_buffer(valid_count.shape, valid_count.dtype, + "valid_count_buf", data_alignment=4) + out_bufs = [ + tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8), + tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8) + ] + out = te.extern([data.shape, data.shape], + [data, valid_count], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend), + in_buffers=[data_buf, valid_count_buf], + out_buffers=out_bufs, + dtype=[data.dtype, "int32"], + name="nms_argsort_gpu", + tag="nms_argsort_gpu") + + if axis != ndim - 1: + axes = swap(list(range(ndim)), axis) + out = [transpose(o, axes) for o in out] + + return out[1] + def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -318,8 +380,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32" The output of this function. 
""" if valid_count is not None: - # TODO: implement argsort_nms with Thrust - out = argsort(data, valid_count, axis, is_ascend, dtype) + out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype) else: out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype) return out @@ -453,13 +514,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int ndim = len(data.shape) axis = ndim + axis if axis < 0 else axis - def swap(arr): - """ swap arr[axis] and arr[-1] """ - return arr[:axis] + [arr[-1]] + arr[axis+1:-1] + [arr[axis]] - if axis != ndim - 1: # Prepare for sorting along axis -1. - axes = swap(list(range(ndim))) + axes = swap(list(range(ndim)), axis) data = transpose(data, axes) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) @@ -483,7 +540,7 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int out = [strided_slice(o, beg, end) for o in out] if axis != ndim - 1: - axes = swap(list(range(ndim))) + axes = swap(list(range(ndim)), axis) out = [transpose(o, axes) for o in out] if ret_type == "values":