Revert batched pdist, improve existing kernel, add test (#15901)
authorGregory Chanan <gchanan@fb.com>
Thu, 17 Jan 2019 18:12:47 +0000 (10:12 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 18:44:43 +0000 (10:44 -0800)
Summary:
1) Reverts https://github.com/pytorch/pytorch/pull/12302 which added support for batched pdist. Except I kept the (non-batched) test improvements that came with that PR, because they are nice to have.  Motivation: https://github.com/pytorch/pytorch/issues/15511
2) For the non-batched pdist, improved the existing kernel by forcing fp64 math and properly checking cuda launch errors
3) Added a 'large tensor' test that at least on my machine, fails on the batch pdist implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15901

Reviewed By: ezyang

Differential Revision: D13616730

Pulled By: gchanan

fbshipit-source-id: 620d3f9b9acd492dc131bad9d2ff618d69fc2954

aten/src/ATen/native/Distance.cpp
aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
aten/src/ATen/native/cuda/DistanceKernel.cu
test/common_utils.py
test/test_autograd.py
test/test_nn.py
test/test_torch.py
torch/nn/functional.py

index 8f03d79..b5bf0a8 100644 (file)
@@ -14,12 +14,9 @@ Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double ep
 }
 
 // This is to guarantee that the contiguous memory is passed to the backward pass
-// TODO This currently enforces that the entire array is contiguous, but the
-// batches don't really have to be for efficiency sake, meaning there may be a
-// better way to only force the last two dimensions as contiguous.
 Tensor pdist(const Tensor& self, const double p) {
-  AT_CHECK(self.dim() >= 2,
-      "pdist only supports at least 2D tensors, got: ", self.dim(), "D");
+  AT_CHECK(self.dim() == 2,
+      "pdist only supports 2D tensors, got: ", self.dim(), "D");
   AT_CHECK(at::isFloatingType(self.type().scalarType()), "pdist only supports floating-point dtypes");
   AT_CHECK(p >= 0, "pdist only supports non-negative p values");
   return at::_pdist_forward(self.contiguous(), p);
@@ -29,23 +26,17 @@ Tensor _pdist_forward(const Tensor& self, const double p) {
   AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
   auto device = self.type().device_type();
   AT_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device);
-
-  const auto batches = self.sizes().slice(0, self.dim() - 2);
-  int64_t b = at::tensor(batches).prod().item<int64_t>();
-  int64_t n = self.size(-2);
-  int64_t m = self.size(-1);
-  int64_t c = n * (n - 1) / 2;
-
-  std::vector<int64_t> result_sizes(batches.begin(), batches.end());
-  result_sizes.push_back(c);
-  Tensor result = at::empty(result_sizes, self.options());
-
-  if (n > 1) {
-    if (m == 0) {
+  Tensor result = at::empty({0}, self.options());
+  if (self.size(0) <= 1) {
+    result.resize_({0});
+  } else {
+    int64_t n = self.size(0);
+    int64_t c = n * (n - 1) / 2;
+    result.resize_({c});
+    if (self.size(1) == 0) {
       result.fill_(0);
     } else {
-      Tensor result_view = result.view({b, c});
-      pdist_forward_stub(device, result_view, self.view({b, n, m}), p);
+      pdist_forward_stub(device, result, self, p);
     }
   }
   return result;
@@ -57,14 +48,7 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
   auto device = self.type().device_type();
   AT_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device);
   Tensor result = at::empty_like(self);
-
-  int64_t b = at::tensor(self.sizes().slice(0, self.dim() - 2)).prod().item<int64_t>();
-  int64_t n = self.size(-2);
-  int64_t m = self.size(-1);
-  int64_t c = pdist.size(-1);
-
-  Tensor result_view = result.view({b, n, m});
-  pdist_backward_stub(device, result_view, grad.contiguous().view({b, c}), self.view({b, n, m}), p, pdist.view({b, c}));
+  pdist_backward_stub(device, result, grad, self, p, pdist);
   return result;
 }
 
index 3808719..620537f 100644 (file)
@@ -95,30 +95,27 @@ struct PDist {
   template <typename F>
   static void run_parallel(Tensor& result, const Tensor& self, const scalar_t p) {
     const scalar_t * const self_start = self.data<scalar_t>();
-    int64_t b = self.size(0);
-    int64_t n = self.size(1);
-    int64_t m = self.size(2);
+    const scalar_t * const self_end = self_start + self.numel();
+    int64_t n = self.size(0);
+    int64_t m = self.size(1);
 
     scalar_t * const res_start = result.data<scalar_t>();
-    int64_t combs = n * (n - 1) / 2;
+    int64_t combs = result.numel(); // n * (n - 1) / 2
     const Vec pvec(p);
 
     // We conceptually iterate over tuples of (i, j, k) where i is the first
     // vector from the input, j is the second, and k is the result index. This
     // parallelizes over the range of k and infers what i and j are from the
     // value of k.
-    parallel_for(0, combs * b, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t start, int64_t end) {
-      int64_t l = start / combs;
-      int64_t k = start % combs;
+    parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t k, int64_t end) {
       float n2 = n - .5;
       // The -1 accounts for floating point truncation issues
       int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
       int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
 
-      const scalar_t * self_i = self_start + (l * n + i) * m;
-      const scalar_t * self_j = self_start + (l * n + j) * m;
-      const scalar_t * self_end = self_start + (l + 1) * n * m;
-      scalar_t * res = res_start + start;
+      const scalar_t * self_i = self_start + i * m;
+      const scalar_t * self_j = self_start + j * m;
+      scalar_t * res = res_start + k;
       const scalar_t * const res_end = res_start + end;
 
       while (res != res_end) {
@@ -130,10 +127,6 @@ struct PDist {
         self_j += m;
         if (self_j == self_end) {
           self_i += m;
-          if (self_i + m == self_end) {
-            self_i += m;
-            self_end += n * m;
-          }
           self_j = self_i + m;
         }
       }
@@ -155,9 +148,8 @@ struct PDist {
     }
   }
 
-  // This does a backward pass down a Vec column of the input
   template <typename F>
-  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int count) {
+  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
     for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
 
       const Vec self_vec_i = Vec::loadu(self_i, count);
@@ -165,7 +157,7 @@ struct PDist {
 
       const scalar_t * self_j = self_i + m;
       scalar_t * res_j = res_i + m;
-      for (; self_j != self_end; self_j += m, res_j += m, grad_k += 1, dist_k += 1) {
+      for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
         const Vec self_vec_j = Vec::loadu(self_j, count);
         Vec res_vec_j = Vec::loadu(res_j, count);
 
@@ -182,11 +174,9 @@ struct PDist {
 
   template <typename F>
   static void run_backward_parallel(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
-    const int64_t b = self.size(0);
-    const int64_t n = self.size(1);
-    const int64_t m = self.size(2);
-    const int64_t combs = dist.size(1);
-    const int64_t remainder = m % Vec::size();
+    const int64_t n = self.size(0);
+    const int64_t m = self.size(1);
+    const int64_t gs = grad.stride(0);
     const Vec pvec(p);
 
     const scalar_t * const grad_start = grad.data<scalar_t>();
@@ -197,39 +187,18 @@ struct PDist {
     // The only way to parallelize and avoid locking requires parallelizing
     // over the columns of the input, i.e. we compute the gradient for the
     // first section of each vector independentaly of the second section, etc.
-    int64_t mv = (m + Vec::size() - 1) / Vec::size(); // number of Vecs in a row rounded up
-    at::parallel_for(0, b * mv, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t start, int64_t end) {
-      const int64_t l = start / mv;
-      const int64_t v = start % mv;
-
-      const scalar_t * self_l = self_start + l * n * m;
-      const scalar_t * self_v = self_l + v * Vec::size();
-
-      const scalar_t * dist_l = dist_start + l * combs;
-      const scalar_t * grad_l = grad_start + l * combs;
-
-      scalar_t * res_l = res_start + l * n * m;
-      scalar_t * res_v = res_l + v * Vec::size();
-
-      while (start != end) {
-        backward_down_column<F>(self_v, res_v, grad_l, dist_l, pvec, n, m, std::min(int(m - (self_v - self_l)), Vec::size()));
+    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
+      const scalar_t * self_l = self_start + l * Vec::size();
+      scalar_t * res_l = res_start + l * Vec::size();
 
-        start += 1;
-        self_v += Vec::size();
-        res_v += Vec::size();
-        if (self_v == self_l + mv * Vec::size()) {
-          // Reached the end of the row
-          self_l += n * m;
-          self_v = self_l;
-
-          res_l += n * m;
-          res_v = res_l;
-
-          dist_l += combs;
-          grad_l += combs;
-        }
+      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
+        backward_down_column<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
       }
     });
+    const int64_t remainder = m % Vec::size();
+    if (remainder) {
+      backward_down_column<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder);
+    }
   }
 
   // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
@@ -251,7 +220,7 @@ struct PDist {
 
 };
 
-static void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
+void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] {
     PDist<scalar_t>::apply(result, self, p);
   });
index 4b2555a..4536b4a 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/Exceptions.h>
 #include <THC/THCTensorMathReduce.cuh>
 #include <math.h>
 
@@ -78,20 +79,19 @@ struct dists {
 };
 
 template <typename scalar_t, typename F>
-__global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p) {
-  const int l = blockIdx.x;
-  const int k = blockIdx.y;
-  const int stride = blockDim.y;
+__global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
+                                              const double n2, const double n2_squared_minus_1) {
+  const int k = blockIdx.x;
+  const int stride = blockDim.x;
 
-  float n2 = n - .5;
   // The -1 accounts for floating point truncation issues
-  int64_t i = static_cast<int64_t>((n2 - device_sqrt<scalar_t>(n2 * n2 - 2 * k - 1)));
+  int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k)));
   int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
 
-  const scalar_t * const start = self + (l * n + i) * m;
+  const scalar_t * const start = self + i * m;
   const scalar_t * const end = start + m;
-  const scalar_t * a = start + threadIdx.y;
-  const scalar_t * b = self + (l * n + j) * m + threadIdx.y;
+  const scalar_t * a = start + threadIdx.x;
+  const scalar_t * b = self + j * m + threadIdx.x;
   scalar_t agg = 0.0;
   for (; a < end; a += stride, b += stride) {
     F::inc(agg, std::abs(*a - *b), p);
@@ -106,53 +106,50 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t
   // This shared memory is significantly larger than necessary, but the
   // assumption is that it's not a bottleneck, and this is simple
   __shared__ scalar_t shared[forward_threads];
-  int lane = threadIdx.y % warpSize;
-  int warp_id = threadIdx.y / warpSize;
+  int lane = threadIdx.x % warpSize;
+  int warp_id = threadIdx.x / warpSize;
   if (lane == 0) {
     shared[warp_id] = agg;
   }
   __syncthreads();
-  agg = (threadIdx.y < blockDim.y / warpSize) ? shared[lane] : 0.0;
+  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
   if (warp_id == 0) {
     // Only reduce theads with nonzero data
-    for (int offset = blockDim.y / warpSize / 2; offset > 0; offset /= 2) {
+    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
       F::agg(agg, WARP_SHFL_DOWN(agg, offset));
     }
   }
-  if (threadIdx.y == 0) {
-    const int64_t combs = n * (n - 1) / 2;
-    result[l * combs + k] = F::finish(agg, p);
+  if (threadIdx.x == 0) {
+    result[k] = F::finish(agg, p);
   }
 }
 
 template <typename scalar_t, typename F>
-__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, const int64_t n, const int64_t m, const scalar_t p) {
-  const int l = blockIdx.x;
-  const int k = blockIdx.z * blockDim.z + threadIdx.z;
-  const int init = blockIdx.y * blockDim.y + threadIdx.y;
-  const int stride = blockDim.y * gridDim.y;
-  const int combs = n * (n - 1) / 2;
+__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p,
+                                                       const double n2, const double n2_squared_minus_1) {
+  const int k = blockIdx.y * blockDim.y + threadIdx.y;
+  const int init = blockIdx.x * blockDim.x + threadIdx.x;
+  const int stride = blockDim.x * gridDim.x;
 
   if (k >= combs) {
     return;
   }
 
-  float n2 = n - .5;
   // The -1 accounts for floating point truncation issues
   int64_t i = static_cast<int64_t>((n2 - device_sqrt<scalar_t>(n2 * n2 - 2 * k - 1)));
   int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
   int64_t ib = j - i - 1;
   int64_t jb = n - 2 - i;
 
-  const scalar_t grad_k = grad[l * combs + k];
-  const scalar_t dist_k = dist[l * combs + k];
+  const scalar_t grad_k = grad[k * gs];
+  const scalar_t dist_k = dist[k];
 
-  const scalar_t * const start = self + (l * n + i) * m;
+  const scalar_t * const start = self + i * m;
   const scalar_t * const end = start + m;
   const scalar_t * self_i = start + init;
-  const scalar_t * self_j = self + (l * n + j) * m + init;
-  scalar_t * buff_i = buffer + ((l * (n - 1) + ib) * n + i) * m + init;
-  scalar_t * buff_j = buffer + ((l * (n - 1) + jb) * n + j) * m + init;
+  const scalar_t * self_j = self + j * m + init;
+  scalar_t * buff_i = buffer + (ib * n + i) * m + init;
+  scalar_t * buff_j = buffer + (jb * n + j) * m + init;
   for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) {
     const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
     *buff_i = res;
@@ -161,25 +158,29 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
 }
 
 void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
-  int64_t b = self.size(0);
-  int64_t n = self.size(1);
-  int64_t m = self.size(2);
-  const dim3 grid(b, result.size(1));
-  const dim3 block(1, forward_threads);
+  const dim3 grid(result.numel());
+  const dim3 block(forward_threads);
+  int64_t n = self.size(0);
+  int64_t m = self.size(1);
+  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
+  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
+  const double n2 = n - .5;
+  const double n2_squared_minus_1 = n2 * n2 - 1;
 
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] {
     if (p == 0.0) {
-      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p);
+      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     } else if (p == 1.0) {
-      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p);
+      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     } else if (p == 2.0) {
-      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p);
+      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     } else if (std::isinf(p)) {
-      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p);
+      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     } else {
-      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p);
+      pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     }
   });
+  AT_CUDA_CHECK(cudaGetLastError());
 }
 
 void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
@@ -188,32 +189,39 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
     return;
   }
 
-  const int64_t b = self.size(0);
-  const int64_t n = self.size(1);
-  const int64_t m = self.size(2);
-  const int block_y = 64;
-  const int block_z = 4;
-  const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
-  const int grid_z = (dist.numel() + block_z - 1) / block_z;
-  const dim3 grid(b, grid_y, grid_z);
-  const dim3 block(1, block_y, block_z);
-
-  Tensor buffer = at::empty({b, n - 1, n, m}, result.options());
+  const int64_t n = result.size(0);
+  int64_t m = self.size(1);
+  const int block_x = 64;
+  // NB: be careful with changing block_y; as it's currently written, grid_y is limited to be 2^16.
+  // From binary search, block_y of 16 gives us max pdist dim0 of 1449,
+  //                     block_y of  4 gives us max pdist dim0 of  725.
+  const int block_y = 16;
+  const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
+  const int grid_y = (dist.numel() + block_y - 1) / block_y;
+  const dim3 grid(grid_x, grid_y);
+  const dim3 block(block_x, block_y);
+  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
+  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
+  const double n2 = n - .5;
+  const double n2_squared_minus_1 = n2 * n2 - 1;
+
+  Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] {
     if (p == 1.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     } else if (p < 2.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     } else if (p == 2.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     } else if (std::isinf(p)) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     } else {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     }
   });
+  AT_CUDA_CHECK(cudaGetLastError());
 
-  at::sum_out(result, buffer, 1);
+  at::sum_out(result, buffer, 0);
 }
 
 } // anonymous namespace
index e6aa375..d8940fb 100644 (file)
@@ -738,7 +738,7 @@ def brute_pdist(inp, p=2):
     k = n * (n - 1) // 2
     if k == 0:
         # torch complains about empty indices
-        return torch.empty(inp.shape[:-2] + (0,), device=inp.device)
+        return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
     square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
     unroll = square.view(square.shape[:-2] + (n * n,))
     inds = torch.ones(k, dtype=torch.int)
index 3e7879f..4a8d951 100644 (file)
@@ -2226,6 +2226,23 @@ class TestAutograd(TestCase):
         a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()
         gradcheck(lambda a: torch.pow(2, a), (a,))
 
+    # test for backward in https://github.com/pytorch/pytorch/issues/15511
+    @skipIfRocm
+    def test_pdist_large(self):
+        def func(x):
+            return torch.pdist(x, p=2)
+
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
+            # is currently limited to smaller sizes (see issue above); this is just testing
+            # a floor.
+            shape = (1000, 1)
+            x = torch.randn(shape, device=device).requires_grad_()
+            output = torch.pdist(x, p=2)
+            # just run a single backward, as gradcheck/gradgradcheck is expensive here
+            output.sum().backward()
+
     @skipIfNoLapack
     def test_pinverse(self):
         # Why is pinverse tested this way, and not ordinarily as other linear algebra methods?
index cb98e59..2722635 100644 (file)
@@ -5564,10 +5564,10 @@ class TestNN(NNTestCase):
 
     @skipIfRocm
     def test_pdist(self):
-        for device, trans, shape in itertools.product(device_(), [False, True], [(4, 5), (2, 3, 4)]):
-            inp = torch.randn(shape, dtype=torch.double, device=device, requires_grad=True)
+        for device, trans in itertools.product(device_(), [False, True]):
+            inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
             if trans:
-                inp = inp.transpose(-2, -1)
+                inp = inp.transpose(0, 1)
             for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
                 self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
 
@@ -5579,11 +5579,13 @@ class TestNN(NNTestCase):
             for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
                 self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
 
+    @skipIfRocm
     def test_pdist_empty_row(self):
         for device in device_():
             inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True)
             self.assertTrue(gradcheck(F.pdist, (inp,)))
 
+    @skipIfRocm
     def test_pdist_empty_col(self):
         for device in device_():
             inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True)
@@ -5594,6 +5596,7 @@ class TestNN(NNTestCase):
         inp = torch.randn(4, 5, requires_grad=True)
         gradgradcheck(F.pdist, (inp,))
 
+    @skipIfRocm
     @unittest.expectedFailure
     def test_pdist_cuda_gradgrad_unimplemented(self):
         inp = torch.randn(4, 5, device='cuda', requires_grad=True)
index 4719532..19d903d 100644 (file)
@@ -1169,18 +1169,27 @@ class _TestTorchMixin(object):
             self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
 
     def test_pdist_norm(self):
+        def test_pdist_single(shape, device, p, dtype, trans):
+            x = torch.randn(shape, dtype=dtype, device=device)
+            if trans:
+                x.transpose_(-2, -1)
+            actual = torch.pdist(x, p=p)
+            expected = brute_pdist(x, p=p)
+            self.assertEqual(expected.shape, actual.shape)
+            self.assertTrue(torch.allclose(expected, actual))
+
         devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         for device in devices:
-            for shape in [(4, 5), (3, 2), (2, 1), (2, 3, 4)]:
+            for shape in [(4, 5), (3, 2), (2, 1)]:
                 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                     for trans in [False, True]:
-                        x = torch.randn(shape, device=device)
-                        if trans:
-                            x.transpose_(-2, -1)
-                        actual = torch.pdist(x, p=p)
-                        expected = brute_pdist(x, p=p)
-                        self.assertEqual(expected.shape, actual.shape)
-                        self.assertTrue(torch.allclose(expected, actual))
+                        for dtype in [torch.float32, torch.float64]:
+                            test_pdist_single(shape, device, p, dtype, trans)
+
+            # do a simplified comparison with big inputs, see:
+            # https://github.com/pytorch/pytorch/issues/15511
+            for dtype in [torch.float32, torch.float64]:
+                test_pdist_single((1000, 2), device, 2, dtype, False)
 
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_logsumexp(self):
index 7c94060..36d50fe 100644 (file)
@@ -2807,28 +2807,21 @@ pdist = _add_docstr(torch.pdist, r"""
 pdist(input, p=2) -> Tensor
 
 Computes the p-norm distance between every pair of row vectors in the input.
-If the input tensor has shape ... B x N x M, the result will be tensor with
-shape ... B x N * (N - 1) / 2. Every dimension prior to the last two are
-treated as independent batches of N vectors, each with M elements. If we use
-ordinal numbers to refer to the vectors in a batch, the last dimension of the
-output will be ordered as:
-
-```
-[dist(1, 2), dist(1, 2), ..., dist(1, N), dist(2, 3), ..., dist(N-1, N)]
-```
-
-The square verion of pdist that has redundant distances and the diagonal can be
-be computed with
-`torch.norm(inpup[..., None, :] - input[..., None, :, :], p=p, dim=-1)`.
-Appropriately selecting and flattening the upper triangular of the
-last two dimensions will produce identical results as pdist.
-
-This function is similar to
-`scipy.spatial.distance.pdist(input, 'minkowski', p=p)`
-if :math:`p \in (0, \infty)`.
+This is identical to the upper triangular portion, excluding the diagonal, of
+`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
+if the rows are contiguous.
+
+If input has shape :math:`N \times M` then the output will have shape
+:math:`\frac{1}{2} N (N - 1)`.
+
+This function is equivalent to `scipy.spatial.distance.pdist(input,
+'minkowski', p=p)` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
+equivalent to `scipy.spatial.distance.pdist(input, 'hamming') * M`.
+When :math:`p = \infty`, the closest scipy function is
+`scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())`.
 
 Args:
-    input: input tensor of shape :math:`... \times N \times M`.
+    input: input tensor of shape :math:`N \times M`.
     p: p value for the p-norm distance to calculate between each vector pair
         :math:`\in [0, \infty]`.
 """)