Implement kthvalue in ATen (#17544)
authorThomas Viehmann <tv.code@beamnet.de>
Sat, 2 Mar 2019 02:57:02 +0000 (18:57 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 2 Mar 2019 03:00:10 +0000 (19:00 -0800)
Summary:
The CPU version is based on the TH version.
The GPU version is based on #8406 by Pararth Shah (thank you).

CPU quickselect based on that in TH's THTensorMoreMath.cpp, but with C++ (quickselectnoindex will be achieved by a different swap)
CPU kthvalue is based on the THTensor function in the same file.
The dim_apply function is a C++ replacement for TH_TENSOR_DIM_APPLYx macros.
The CUDA kernel uses functions adapted from the THCTensorSortK implementation.
In particular radixSelect is from THCTensorTopK.cuh.
The CUDA launcher code replaces a bunch of macros with C++. It will be re-used in one of the following patches.

Plan for further PRs:
- This
- Sort
- TopK + Mode + Median in any order
- Rip out THC stuff.

There may be utility functions / structs in the SortingCommon.cuh that come into
relevance only with sort.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17544

Differential Revision: D14286934

Pulled By: ezyang

fbshipit-source-id: 35dbea050b097e88777ac5fa5c0f499d5e23c738

aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/Sorting.cpp [new file with mode: 0644]
aten/src/ATen/native/SortingUtils.h [new file with mode: 0644]
aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/cuda/SortingCommon.cuh [new file with mode: 0644]
aten/src/ATen/native/cuda/SortingKthValue.cu [new file with mode: 0644]
aten/src/ATen/native/cuda/SortingRadixSelect.cuh [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
test/data/test_cuda_ignores.txt

index ad2a0e2..2118d8f 100644 (file)
           default: "false"
 ]]
 [[
-  name: _th_kthvalue
-  backends:
-    - CPU
-  variants: function
-  cname: kthvalue
-  return: argument 0,1
-  scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1)
-  arguments:
-    - arg: THTensor* values
-      output: True
-    - arg: THIndexTensor* indices
-      output: True
-    - THTensor* self
-    - long k
-    - arg: long dim
-      wrap_dim: self
-      default: __last_dim
-    - arg: bool keepdim
-      default: "false"
-]]
-[[
   name: _th_mode
   variants: function
   cname: mode
diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp
new file mode 100644 (file)
index 0000000..4bc3c2c
--- /dev/null
@@ -0,0 +1,195 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <ATen/WrapDimUtils.h>
+#include <ATen/native/SortingUtils.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+// maybe these days, one should define a random access iterator and use
+// std::sort...
+/* Note from TH:
+
+   I cut and pasted (slightly adapted) the quicksort code from
+   Sedgewick's 1978 "Implementing Quicksort Programs" article
+   http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf
+
+   It is the state of the art existing implementation. The macros
+   are here to make as close a match as possible to the pseudocode of
+   Program 2 p.851
+
+   Note that other partition schemes exist, and are typically presented
+   in textbook, but those are less efficient. See e.g.
+   http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto
+
+   Julien, November 12th 2013
+*/
+
+constexpr int64_t MAX_LEVELS = 300;
+constexpr int64_t M_SMALL = 10; // Limit for small subfiles
+
+template <typename Fn>
+void dim_apply(TensorList tensors, int64_t dim, Fn f) {
+  AT_ASSERT(tensors.size() > 0);
+  auto t = tensors[0];
+  auto sizes = t.sizes();
+  int64_t ndim = t.dim();
+  int64_t itersize = 1;
+  for (int64_t i = 0; i < ndim; i++) {
+    if (i != dim) {
+      itersize *= t.size(i);
+    }
+  }
+  parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) {
+    std::vector<Tensor> narrowed_tensors;
+    narrowed_tensors.reserve(tensors.size());
+    for (int64_t it = i_begin; it < i_end; it++) {
+      narrowed_tensors.clear();
+      for (auto ti : tensors) {
+        int64_t i = it;
+        Tensor nt = ti;
+        for (size_t d = 0; d < ndim; d++) {
+          if (d != dim) {
+            // this could be avoided for slower-changing dimensions if done
+            // better
+            nt = nt.select((d > dim ? 1 : 0), i % sizes[d]);
+            i = i / sizes[d];
+          }
+        }
+        narrowed_tensors.emplace_back(nt);
+      }
+      f(it, narrowed_tensors);
+    }
+  });
+}
+
+template <typename scalar_t, typename Fn>
+void quick_select_template(
+    TensorAccessor<scalar_t, 1> arr,
+    int64_t k,
+    Fn swap_fn) {
+  int64_t P, L, R, i, j, swap;
+  scalar_t rswap, piv;
+  L = 0;
+  R = arr.size(0) - 1;
+
+  do {
+    if (R <= L) // One element only
+      return;
+
+    if (R == L + 1) { // Two elements only
+      if (arr[L] > arr[R]) {
+        swap_fn(L, R);
+      }
+      return;
+    }
+
+    // Use median of three for pivot choice
+    P = (L + R) >> 1;
+    swap_fn(P, L + 1);
+    if (arr[L + 1] > arr[R]) {
+      swap_fn(L + 1, R);
+    }
+    if (arr[L] > arr[R]) {
+      swap_fn(L, R);
+    }
+    if (arr[L + 1] > arr[L]) {
+      swap_fn(L + 1, L);
+    }
+
+    i = L + 1;
+    j = R;
+    piv = arr[L];
+    do {
+      do
+        i++;
+      while (arr[i] < piv);
+      do
+        j--;
+      while (arr[j] > piv);
+      if (j < i)
+        break;
+      swap_fn(i, j);
+    } while (1);
+    swap_fn(L, j);
+
+    // Re-set active partition
+    if (j <= k)
+      L = i;
+    if (j >= k)
+      R = j - 1;
+  } while (1);
+}
+
+} // namespace
+
+std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t k,
+    int64_t dim_,
+    bool keepdim) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  // FIXME: This seems bogus, I only do this because it was the old behaviour.
+  //        The reductions are fine, as long as the axis being reduced along
+  //        isn't of 0 elements (and the output has elements).
+  AT_CHECK(
+      self.numel() > 0,
+      "cannot perform reduction function kthvalue",
+      " on tensor with no elements because the operation does not have an identity");
+  AT_CHECK(
+      k > 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
+      "selected index k out of range");
+
+  _reduction_with_indices_allocate_or_resize_output(
+      values, indices, self, dim_, keepdim);
+  if (self.dim() == 0 && self.numel() == 1) {
+    values.copy_(self);
+    indices.zero_();
+    return std::forward_as_tuple(values, indices);
+  }
+  auto tmp_values = self.clone();
+  auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));
+  AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] {
+    dim_apply(
+        {tmp_values, tmp_indices, values, indices},
+        dim,
+        [&](int64_t i, TensorList tl) {
+          auto tmp_values = tl[0].accessor<scalar_t, 1>();
+          auto tmp_indices = tl[1].accessor<int64_t, 1>();
+          scalar_t* mode_value = tl[2].data<scalar_t>();
+          int64_t* mode_index = tl[3].data<int64_t>();
+          for (int64_t j = 0; j < tmp_indices.size(0); j++) {
+            tmp_indices[j] = j;
+          }
+          quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) {
+            std::swap(tmp_values[i], tmp_values[j]);
+            std::swap(tmp_indices[i], tmp_indices[j]);
+          });
+          *mode_value = tmp_values[k - 1];
+          *mode_index = tmp_indices[k - 1];
+        });
+  });
+  if (!keepdim) {
+    values.squeeze_(dim);
+    indices.squeeze_(dim);
+  }
+  return std::forward_as_tuple(values, indices);
+}
+
+std::tuple<Tensor, Tensor> kthvalue(
+    const Tensor& self,
+    int64_t k,
+    int64_t dim,
+    bool keepdim) {
+  Tensor values = at::empty({0}, self.options());
+  Tensor indices = at::empty({0}, self.options().dtype(kLong));
+  at::kthvalue_out(values, indices, self, k, dim, keepdim);
+  return std::make_tuple(values, indices);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h
new file mode 100644 (file)
index 0000000..9cc7afb
--- /dev/null
@@ -0,0 +1,48 @@
+#pragma once
+
+namespace at {
+namespace native {
+
+// ensure we get good values and indices for kthvalue, mode, median
+// this will always be with the reducing dim as 1-d
+static void _reduction_with_indices_allocate_or_resize_output(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim_,
+    bool keepdim) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  auto result_sizes = self.sizes().vec();
+  if (result_sizes.size() > 0) {
+    result_sizes[dim] = 1;
+  }
+  if (values.defined()) {
+    AT_CHECK(
+        self.type() == values.type(),
+        "output values must be of same type as input");
+    if (!keepdim && values.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      values.unsqueeze_(dim);
+    }
+    values.resize_(result_sizes);
+  } else {
+    values = at::empty(result_sizes, self.options());
+  }
+  if (indices.defined()) {
+    AT_CHECK(
+        indices.dtype() == kLong, "output indices must be of scalar type Long");
+    AT_CHECK(
+        indices.device() == self.device(),
+        "output indices must be on same device as input");
+    if (!keepdim && indices.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      indices.unsqueeze_(dim);
+    }
+    indices.resize_(result_sizes);
+  } else {
+    indices = at::empty(result_sizes, self.options().dtype(kLong));
+  }
+}
+
+} // namespace native
+} // namespace at
index 0f58f4f..eaac69b 100644 (file)
@@ -97,26 +97,6 @@ Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& o
   return ret;
 }
 
-std::tuple<Tensor, Tensor> kthvalue(const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
-  Tensor values = at::empty({0}, self.options());
-  Tensor indices = at::empty({0}, self.options().dtype(kLong));
-  return at::native::kthvalue_out(values, indices, self, k, dim, keepdim);
-}
-
-std::tuple<Tensor &,Tensor &> kthvalue_out(Tensor& values, Tensor& indices,
-                                           const Tensor& self, int64_t k, int64_t dim, bool keepdim) {
-  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
-           "kthvalue only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "kthvalue")) {
-    AT_ASSERT(values.dim() == 0);
-    indices.resize_({}).fill_(0);
-    return std::forward_as_tuple(values, indices);
-  } else {
-    return at::legacy::th::_th_kthvalue_out(values, indices, self, k, dim, keepdim);
-  }
-}
-
 std::tuple<Tensor, Tensor> median(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor values = at::empty({0}, self.options());
   Tensor indices = at::empty({0}, self.options().dtype(kLong));
diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh
new file mode 100644 (file)
index 0000000..1f7988d
--- /dev/null
@@ -0,0 +1,226 @@
+#include <ATen/ATen.h>
+#include <ATen/native/SortingUtils.h>
+#include <assert.h>
+#include <c10/macros/Macros.h>
+#include <stdlib.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
+#include <THC/THCNumerics.cuh>
+#include <THC/THCScanUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh> // AddOp
+
+namespace at {
+namespace native {
+
+#if defined(__HIP_PLATFORM_HCC__)
+constexpr int WARP_SIZE = 64;
+constexpr int MAX_BLOCK_SIZE = 256;
+
+#else
+constexpr int WARP_SIZE = 32;
+constexpr int MAX_BLOCK_SIZE = 1024;
+#endif
+
+// Maximum size per grid dimension that we assume (compute capability >= 2.0)
+constexpr int64_t MAX_GRID_SIZE = 65535LL;
+
+static bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
+  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
+    return false;
+  }
+
+  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+  int64_t gridY = 1;
+  int64_t gridZ = 1;
+
+  if (gridTiles > MAX_GRID_SIZE) {
+    gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
+    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+
+    if (gridTiles > MAX_GRID_SIZE) {
+      gridTiles = cuda::ATenCeilDiv(gridTiles, MAX_GRID_SIZE);
+      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+    }
+  }
+
+  grid = dim3(gridX, gridY, gridZ);
+  return true;
+}
+
+template <typename scalar_t, bool handleNaN = false>
+struct ThrustGTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && THCNumerics<scalar_t>::isnan(lhs) &&
+            !THCNumerics<scalar_t>::isnan(rhs)) ||
+        THCNumerics<scalar_t>::gt(lhs, rhs);
+  }
+};
+
+template <typename scalar_t, bool handleNaN = false>
+struct ThrustLTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && THCNumerics<scalar_t>::isnan(rhs) &&
+            !THCNumerics<scalar_t>::isnan(lhs)) ||
+        THCNumerics<scalar_t>::lt(lhs, rhs);
+  }
+};
+
+template <typename index_t>
+__device__ __forceinline__ index_t getLinearBlockId() {
+  return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
+      blockIdx.x;
+}
+
+// `base` is the base address of a tensor
+// For each slice (defined as a linear point of `out`, from 0 ->
+// (sliceSize - 1) * sliceStride, we fill that slice from `0` to
+// `sliceSize - 1`.
+template <typename index_t, int Dim>
+__global__ void fillSliceWithIndex_kernel(
+    cuda::detail::TensorInfo<int64_t, index_t> out,
+    index_t totalSlices,
+    index_t sliceSize,
+    index_t sliceStride) {
+  index_t slice = getLinearBlockId<index_t>();
+
+  if (slice >= totalSlices) {
+    return;
+  }
+
+  const uint64_t offset =
+      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, out);
+  int64_t* base = &out.data[offset];
+
+  for (int64_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
+    // Torch indices are 1-based (hence the +1)
+    base[i * sliceStride] = i;
+  }
+}
+
+// For slice sorting in Thrust; extracts a slice index from a linear
+// index and uses that for comparison
+struct SliceComp {
+  SliceComp(int64_t size) : sliceSize(size) {}
+
+  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
+    // Since the slices are guaranteed to be innermost,
+    // the segment is just via int64_t division
+    int64_t segA = a / sliceSize;
+    int64_t segB = b / sliceSize;
+    return segA < segB;
+  }
+
+  const int64_t sliceSize;
+};
+
+// For sorting in Thurst; extracts a within-slice index from a linear index
+struct GlobalIndexToPerSliceIndex {
+  GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
+
+  __device__ inline void operator()(int64_t& v) const {
+    v = v % sliceSize;
+  }
+
+  const int64_t sliceSize;
+};
+
+// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks
+static uint64_t nextHighestPowerOf2(uint64_t n) {
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+#ifndef _MSC_VER
+  n |= n >> 32;
+#endif
+  n++;
+
+  return n;
+}
+
+
+template <typename scalar_t, typename index_t, typename Launcher>
+void run_launcher(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim,
+    Launcher l) {
+  auto self_info = cuda::detail::getTensorInfo<scalar_t, index_t>(self);
+  auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
+  auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);
+
+  int64_t slice_size = self.size(dim);
+  /* We use these structures solely to find the offset to */
+  /* each slice we are operating on */
+  self_info.reduceDim(dim);
+  values_info.reduceDim(dim);
+  indices_info.reduceDim(dim);
+
+  /* Collapse all other dims */
+  int collapse_self_dim = self_info.collapseDims(dim);
+  int collapse_values_dim = values_info.collapseDims(dim);
+  int collapse_indices_dim = indices_info.collapseDims(dim);
+
+  int64_t num_slices = 1;
+  for (int i = 0; i < self_info.dims; ++i) {
+    num_slices *= self_info.sizes[i];
+  }
+
+  /* This is used as a template parameter to calculate indices. */
+  /* We only specialize it if all collapsed dim sizes are the */
+  /* same; otherwise, we use -1 which is the specialization */
+  /* parameter for arbitrary dimensions */
+  int all_dims = self_info.dims;
+  if (values_info.dims != all_dims || indices_info.dims != all_dims) {
+    all_dims = -1;
+  }
+
+  if (all_dims == 1) {
+    l.template launch<scalar_t, index_t, 1>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else if (all_dims == 2) {
+    l.template launch<scalar_t, index_t, 2>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else if (all_dims == 3) {
+    l.template launch<scalar_t, index_t, 3>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else {
+    l.template launch<scalar_t, index_t, -1>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu
new file mode 100644 (file)
index 0000000..d3c488e
--- /dev/null
@@ -0,0 +1,249 @@
+#include <ATen/ATen.h>
+#include <ATen/native/SortingUtils.h>
+#include <assert.h>
+#include <c10/macros/Macros.h>
+#include <stdlib.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <THC/THCDeviceUtils.cuh> // only for THCRoundUp?
+#include <THC/THCNumerics.cuh>
+#include <THC/THCScanUtils.cuh>
+#include <THC/THCTensorMathReduce.cuh> // AddOp
+
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/extrema.h>
+#include <thrust/inner_product.h>
+#include <thrust/sequence.h>
+#include <THC/THCThrustAllocator.cuh>
+#include <ATen/native/cuda/SortingCommon.cuh>
+#include <ATen/native/cuda/SortingRadixSelect.cuh>
+
+namespace at {
+namespace native {
+
+namespace {
+
+
+template <typename scalar_t, typename index_t, int Dim>
+__global__ void gatherKthValue(
+    cuda::detail::TensorInfo<scalar_t, index_t> input,
+    index_t inputSliceSize,
+    index_t k,
+
+    index_t numInputSlices,
+    index_t inputWithinSliceStride,
+
+    cuda::detail::TensorInfo<scalar_t, index_t> kthValue,
+    cuda::detail::TensorInfo<int64_t, index_t> indices) {
+  // Indices are limited to integer fp precision, so counts can fit in
+  // int32, regardless of index_t
+  __shared__ int smem[WARP_SIZE]; // one per each warp, up to warp limit
+
+  index_t slice = getLinearBlockId<index_t>();
+  if (slice >= numInputSlices) {
+    return;
+  }
+
+  // Find the start offset for our slice
+  index_t sliceStartIndex =
+      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, input);
+  index_t kthValueSliceStartIndex =
+      cuda::detail::IndexToOffset<scalar_t, index_t, Dim>::get(slice, kthValue);
+  index_t indicesSliceStartIndex =
+      cuda::detail::IndexToOffset<int64_t, index_t, Dim>::get(slice, indices);
+
+  scalar_t* inputSliceStart = &input.data[sliceStartIndex];
+  scalar_t* kthValueSliceStart = &kthValue.data[kthValueSliceStartIndex];
+  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];
+
+  // Find the k-th highest element in our input
+  scalar_t kValue = static_cast<scalar_t>(0);
+  radixSelect<
+      scalar_t,
+      typename TopKTypeConfig<scalar_t>::RadixType,
+      index_t,
+      false>(
+      inputSliceStart,
+      k,
+      inputSliceSize,
+      inputWithinSliceStride,
+      smem,
+      &kValue);
+
+  // Find the index of the k-th highest element
+  index_t kValueIndex = 0;
+  bool foundKValue = false;
+
+  for (index_t i = threadIdx.x; i < inputSliceSize; i += blockDim.x) {
+    bool inRange = (i < inputSliceSize);
+    scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
+                         : static_cast<scalar_t>(0);
+    bool isKValue = inRange && (THCNumerics<scalar_t>::eq(v, kValue));
+
+    if (isKValue) {
+      kValueIndex = i;
+      foundKValue = true;
+      break;
+    }
+  }
+
+  if (foundKValue) {
+    kthValueSliceStart[0] = kValue;
+    indicesSliceStart[0] = kValueIndex;
+  }
+}
+
+struct KthValueLauncher {
+  int64_t k;
+
+  KthValueLauncher(int64_t k) : k(k) {}
+
+  template <typename scalar_t, typename index_t, int all_dims>
+  inline void launch(
+      cuda::detail::TensorInfo<scalar_t, index_t> values_info,
+      int collapse_values_dim,
+      cuda::detail::TensorInfo<int64_t, index_t> indices_info,
+      int collapse_indices_dim,
+      cuda::detail::TensorInfo<scalar_t, index_t> self_info,
+      int collapse_self_dim,
+      int64_t num_slices,
+      int64_t slice_size) {
+    dim3 grid;
+    if (!getGridFromTiles(num_slices, grid)) {
+      AT_ERROR("slices are too many");
+    }
+
+    dim3 block(
+        std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024));
+    auto stream = at::cuda::getCurrentCUDAStream();
+    gatherKthValue<scalar_t, index_t, all_dims><<<grid, block, 0, stream>>>(
+        self_info,
+        slice_size,
+        k,
+        num_slices,
+        /* The actual dimension that the k-selection is running in */
+        /* may have changed from collapseDims() */
+        self_info.strides[collapse_self_dim],
+        values_info,
+        indices_info);
+  }
+};
+
+template <typename scalar_t>
+void kthvalue_cuda_template(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t k,
+    int64_t dim_,
+    bool keepdim) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim());
+  int64_t slicesize = self.size(dim);
+  // FIXME: This seems bogus, I only do this because it was the old behaviour.
+  //        The reductions are fine, as long as the axis being reduced along
+  //        isn't of 0 elements (and the output has elements).
+  AT_CHECK(
+      self.numel() > 0,
+      "cannot perform reduction function kthvalue",
+      " on tensor with no elements because the operation does not have an identity");
+  AT_CHECK(k >= 1 && k <= slicesize, "selected number k out of range");
+
+  _reduction_with_indices_allocate_or_resize_output(
+      values, indices, self, dim, keepdim);
+  if (self.dim() == 0 && self.numel() == 1) {
+    values.copy_(self);
+    indices.zero_();
+    return;
+  }
+
+  AT_CHECK(
+      self.dim() <= MAX_TENSORINFO_DIMS,
+      "cannot operate on more than ",
+      MAX_TENSORINFO_DIMS,
+      " dimensions");
+
+  // Based on required index size, run the algorithm with the
+  // appropriate index type
+  if (cuda::detail::canUse32BitIndexMath(self) &&
+      cuda::detail::canUse32BitIndexMath(values) &&
+      cuda::detail::canUse32BitIndexMath(indices)) {
+    run_launcher<scalar_t, uint32_t>(
+        values, indices, self, dim, KthValueLauncher(k));
+  } else {
+    run_launcher<scalar_t, uint64_t>(
+        values, indices, self, dim, KthValueLauncher(k));
+  }
+
+  if (!keepdim) {
+    values.squeeze_(dim);
+    indices.squeeze_(dim);
+  }
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+// this does not reduce to median with dim beause we don't want to copy twice
+template <typename scalar_t>
+Tensor median_cuda_template(const Tensor& self) {
+  AT_CHECK(self.numel() > 0, "median cannot be called with empty tensor");
+  if (self.dim() == 0 && self.numel() == 1) {
+    return self.clone();
+  }
+  auto self_copy = self.clone().view(-1);
+  auto values = at::empty({1}, self.options());
+  auto indices = at::empty({1}, self.options().dtype(kLong));
+  AT_CHECK(
+      self.dim() <= MAX_TENSORINFO_DIMS,
+      "cannot operate on more than ",
+      MAX_TENSORINFO_DIMS,
+      " dimensions");
+
+  // Based on required index size, run the algorithm with the
+  // appropriate index type
+  if (cuda::detail::canUse32BitIndexMath(self) &&
+      cuda::detail::canUse32BitIndexMath(values) &&
+      cuda::detail::canUse32BitIndexMath(indices)) {
+    run_launcher<scalar_t, uint32_t>(
+        values,
+        indices,
+        self_copy,
+        0,
+        KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
+  } else {
+    run_launcher<scalar_t, uint64_t>(
+        values,
+        indices,
+        self_copy,
+        0,
+        KthValueLauncher((self_copy.size(0) + 1) / 2)); // KthValue is 1-based
+  }
+  return values.view({});
+}
+
+} // namespace
+
+std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t k,
+    int64_t dim,
+    bool keepdim) {
+  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] {
+    kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
+  });
+  return std::forward_as_tuple(values, indices);
+}
+
+Tensor median_cuda(const Tensor& self) {
+  return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] {
+    return median_cuda_template<scalar_t>(self);
+  });
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh
new file mode 100644 (file)
index 0000000..2c5d81a
--- /dev/null
@@ -0,0 +1,392 @@
+namespace at {
+namespace native {
+
+template <typename scalar_t>
+struct TopKTypeConfig {};
+
+template <>
+struct TopKTypeConfig<float> {
+  typedef uint32_t RadixType;
+
+  // Converts a float to an integer representation with the same
+  // sorting; i.e., for floats f1, f2:
+  // if f1 < f2 then convert(f1) < convert(f2)
+  // We use this to enable radix selection of floating-point values.
+  // This also gives a relative order for NaNs, but that's ok, as they
+  // will all be adjacent
+  static inline __device__ RadixType convert(float v) {
+    RadixType x = __float_as_int(v);
+    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
+
+    return (x ^ mask);
+  }
+
+  static inline __device__ float deconvert(RadixType v) {
+    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
+
+    return __int_as_float(v ^ mask);
+  }
+};
+
+template <>
+struct TopKTypeConfig<uint8_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(uint8_t v) {
+    return v;
+  }
+
+  static inline __device__ uint8_t deconvert(RadixType v) {
+    return v;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int8_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int8_t v) {
+    return 128u + v;
+  }
+
+  static inline __device__ int8_t deconvert(RadixType v) {
+    return v - 128;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int16_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int16_t v) {
+    assert(sizeof(short) == 2);
+    return 32768u + v;
+  }
+
+  static inline __device__ int16_t deconvert(RadixType v) {
+    return v - 32768;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int32_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int32_t v) {
+    assert(sizeof(int) == 4);
+    return 2147483648u + v;
+  }
+
+  static inline __device__ int32_t deconvert(RadixType v) {
+    return v - 2147483648u;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int64_t> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType convert(int64_t v) {
+    assert(sizeof(int64_t) == 8);
+    return 9223372036854775808ull + v;
+  }
+
+  static inline __device__ int64_t deconvert(RadixType v) {
+    return v - 9223372036854775808ull;
+  }
+};
+
+template <>
+struct TopKTypeConfig<double> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType convert(double v) {
+    RadixType x = __double_as_longlong(v);
+    RadixType mask = -((x >> 63)) | 0x8000000000000000;
+    return (x ^ mask);
+  }
+
+  static inline __device__ double deconvert(RadixType v) {
+    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
+    return __longlong_as_double(v ^ mask);
+  }
+};
+
+template <>
+struct TopKTypeConfig<at::Half> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(at::Half v) {
+#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
+    RadixType x = __half_as_ushort(v);
+    RadixType mask = -((x >> 15)) | 0x8000;
+    return (x ^ mask);
+#else
+    assert(false);
+    return 0u;
+#endif
+  }
+
+  static inline __device__ at::Half deconvert(RadixType v) {
+#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
+    RadixType mask = ((v >> 15) - 1) | 0x8000;
+    return __ushort_as_half(v ^ mask);
+#else
+    assert(false);
+    return static_cast<at::Half>(0);
+#endif
+  }
+};
+
+// This function counts the distribution of all input values in a
+// slice we are selecting by radix digit at `radixDigitPos`, but only
+// those that pass the filter `((v & desiredMask) == desired)`.
+// This produces and broadcasts the seen counts for a single block only.
+// `smem` must have at least `RadixSize` elements.
+template <
+    typename scalar_t,
+    typename bitwise_t,
+    typename index_t,
+    typename CountType,
+    int RadixSize,
+    int RadixBits>
+__device__ void countRadixUsingMask(
+    CountType counts[RadixSize],
+    CountType* smem,
+    bitwise_t desired,
+    bitwise_t desiredMask,
+    int radixDigitPos,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    scalar_t* data) {
+  // Clear out per-thread counts from a previous round
+#pragma unroll
+  for (int i = 0; i < RadixSize; ++i) {
+    counts[i] = 0;
+  }
+
+  if (threadIdx.x < RadixSize) {
+    smem[threadIdx.x] = 0;
+  }
+  __syncthreads();
+
+  // Scan over all the data. Upon a read, the warp will accumulate
+  // counts per each digit in the radix using warp voting.
+  for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
+    bitwise_t val =
+        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));
+
+    bool hasVal = ((val & desiredMask) == desired);
+    bitwise_t digitInRadix =
+        Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);
+
+#pragma unroll
+    for (uint32_t j = 0; j < RadixSize; ++j) {
+      bool vote = hasVal && (digitInRadix == j);
+#if defined(__HIP_PLATFORM_HCC__)
+      counts[j] += __popcll(WARP_BALLOT(vote));
+#else
+      counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
+#endif
+    }
+  }
+
+  // Now, for each warp, sum values
+  if (getLaneId() == 0) {
+#pragma unroll
+    for (uint32_t i = 0; i < RadixSize; ++i) {
+      atomicAdd(&smem[i], counts[i]);
+    }
+  }
+
+  __syncthreads();
+
+  // For each thread, read in the total counts
+#pragma unroll
+  for (uint32_t i = 0; i < RadixSize; ++i) {
+    counts[i] = smem[i];
+  }
+
+  __syncthreads();
+}
+
+// Over what radix we are selecting values
+constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
+constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
+constexpr int RADIX_MASK = (RADIX_SIZE - 1);
+
+// This finds the unique value `v` that matches the pattern
+// ((v & desired) == desiredMask) in our sorted int format
+template <typename scalar_t, typename bitwise_t, typename index_t>
+__device__ scalar_t findPattern(
+    scalar_t* smem,
+    scalar_t* data,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    bitwise_t desired,
+    bitwise_t desiredMask) {
+  if (threadIdx.x < WARP_SIZE) {
+    smem[threadIdx.x] = static_cast<scalar_t>(0);
+  }
+  __syncthreads();
+
+  // All threads participate in the loop, in order to sync on the flag
+  index_t numIterations =
+      THCRoundUp(sliceSize, static_cast<index_t>(blockDim.x));
+  for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
+    bool inRange = (i < sliceSize);
+    scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
+                         : static_cast<scalar_t>(0);
+
+    if (inRange &&
+        ((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
+      // There should not be conflicts if we are using findPattern,
+      // since the result is unique
+      smem[0] = static_cast<scalar_t>(1);
+      smem[1] = v; // can't use val as the flag, since it could be 0
+    }
+
+    __syncthreads();
+
+    scalar_t found = smem[0];
+    scalar_t val = smem[1];
+
+    __syncthreads();
+
+    // Check to see if a thread found the value
+    if (THCNumerics<scalar_t>::ne(found, static_cast<scalar_t>(0))) {
+      // all threads return this value
+      return val;
+    }
+  }
+
+  // should not get here
+  assert(false);
+  return static_cast<scalar_t>(0);
+}
+
+// Returns the top-Kth element found in the data using radix selection
+template <typename scalar_t, typename bitwise_t, typename index_t, bool Order>
+__device__ void radixSelect(
+    scalar_t* data,
+    index_t k,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    int* smem,
+    scalar_t* topK) {
+  // Per-thread buckets into which we accumulate digit counts in our
+  // radix
+  int counts[RADIX_SIZE];
+
+  // We only consider elements x such that (x & desiredMask) == desired
+  // Initially, we consider all elements of the array, so the above
+  // statement is true regardless of input.
+  bitwise_t desired = 0;
+  bitwise_t desiredMask = 0;
+
+  // We are looking for the top kToFind-th element when iterating over
+  // digits; this count gets reduced by elimination when counting
+  // successive digits
+  int kToFind = k;
+
+  // We start at the most significant digit in our radix, scanning
+  // through to the least significant digit
+#pragma unroll
+  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
+       digitPos -= RADIX_BITS) {
+    // Count radix distribution for the current position and reduce
+    // across all threads
+    countRadixUsingMask<
+        scalar_t,
+        bitwise_t,
+        index_t,
+        int,
+        RADIX_SIZE,
+        RADIX_BITS>(
+        counts,
+        smem,
+        desired,
+        desiredMask,
+        digitPos,
+        sliceSize,
+        withinSliceStride,
+        data);
+
+    auto found_unique = [&](int i, int count) -> bool {
+      /* All threads have the same value in counts here, so all */
+      /* threads will return from the function. */
+      if (count == 1 && kToFind == 1) {
+        /* There is a unique answer. */
+        desired =
+            Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
+        desiredMask = Bitfield<bitwise_t>::setBitfield(
+            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
+
+        /* The answer is now the unique element v such that: */
+        /* (v & desiredMask) == desired */
+        /* However, we do not yet know what the actual element is. We */
+        /* need to perform a search through the data to find the */
+        /* element that matches this pattern. */
+        *topK = findPattern<scalar_t, bitwise_t, index_t>(
+            (scalar_t*)smem,
+            data,
+            sliceSize,
+            withinSliceStride,
+            desired,
+            desiredMask);
+        return true;
+      }
+      return false;
+    };
+    auto found_non_unique = [&](int i, int count) -> bool {
+      if (count >= kToFind) {
+        desired =
+            Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
+        desiredMask = Bitfield<bitwise_t>::setBitfield(
+            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
+
+        /* The top-Kth element v must now be one such that: */
+        /* (v & desiredMask == desired) */
+        /* but we haven't narrowed it down; we must check the next */
+        /* least-significant digit */
+        return true;
+      }
+      kToFind -= count;
+      return false; // continue the loop
+    };
+
+    // All threads participate in the comparisons below to know the
+    // final result
+    if (Order) {
+      // Process in descending order
+#pragma unroll
+      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    } else {
+      // Process in ascending order
+#pragma unroll
+      for (int i = 0; i < RADIX_SIZE; ++i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    }
+  } // end digitPos for
+
+  // There is no unique result, but there is a non-unique result
+  // matching `desired` exactly
+  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
+}
+} // namespace native
+} // namespace at
index 11d50a0..bb6f8cc 100644 (file)
   variants: function, method
 
 - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices)
+  dispatch:
+    CPU: kthvalue_out_cpu
+    CUDA: kthvalue_out_cuda
 
 - func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
   matches_jit_signature: True
index 46016e9..afad1c5 100644 (file)
@@ -2,7 +2,6 @@
 # These are skipped by test_cuda.py
 torch.ByteTensor.dist
 torch.ByteTensor.dot
-torch.ByteTensor.kthvalue
 torch.ByteTensor.lerp
 torch.ByteTensor.lerp_
 torch.ByteTensor.mean
@@ -13,7 +12,6 @@ torch.ByteTensor.std
 torch.ByteTensor.var
 torch.CharTensor.dist
 torch.CharTensor.dot
-torch.CharTensor.kthvalue
 torch.CharTensor.lerp
 torch.CharTensor.lerp_
 torch.CharTensor.mean
@@ -22,8 +20,6 @@ torch.CharTensor.renorm
 torch.CharTensor.renorm_
 torch.CharTensor.std
 torch.CharTensor.var
-torch.DoubleTensor.kthvalue
-torch.FloatTensor.kthvalue
 torch.HalfTensor.chunk_
 torch.HalfTensor.clone_
 torch.HalfTensor.contiguous_
@@ -47,7 +43,6 @@ torch.HalfTensor.inverse_
 torch.HalfTensor.is_contiguous_
 torch.HalfTensor.is_same_size_
 torch.HalfTensor.is_set_to_
-torch.HalfTensor.kthvalue
 torch.HalfTensor.kthvalue_
 torch.HalfTensor.max_
 torch.HalfTensor.mean_
@@ -87,7 +82,6 @@ torch.HalfTensor.zeros
 torch.HalfTensor.zeros_
 torch.IntTensor.dist
 torch.IntTensor.dot
-torch.IntTensor.kthvalue
 torch.IntTensor.lerp
 torch.IntTensor.lerp_
 torch.IntTensor.mean
@@ -98,7 +92,6 @@ torch.IntTensor.std
 torch.IntTensor.var
 torch.LongTensor.dist
 torch.LongTensor.dot
-torch.LongTensor.kthvalue
 torch.LongTensor.lerp
 torch.LongTensor.lerp_
 torch.LongTensor.mean
@@ -109,7 +102,6 @@ torch.LongTensor.std
 torch.LongTensor.var
 torch.ShortTensor.dist
 torch.ShortTensor.dot
-torch.ShortTensor.kthvalue
 torch.ShortTensor.lerp
 torch.ShortTensor.lerp_
 torch.ShortTensor.mean