Add support for batched pdist (#12302)
authorErik Brinkman <erk@fb.com>
Thu, 20 Dec 2018 17:35:08 +0000 (09:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 17:41:08 +0000 (09:41 -0800)
Summary:
This updates pdist to work for batched inputs, and updates the
documentation to reflect issues raised.

closes #9406
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12302

Reviewed By: ezyang

Differential Revision: D13528485

Pulled By: erikbrinkman

fbshipit-source-id: 63d93a6e1cc95b483fb58e9ff021758b341cd4de

aten/src/ATen/native/Distance.cpp
aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
aten/src/ATen/native/cuda/DistanceKernel.cu
test/common_utils.py
test/test_nn.py
test/test_torch.py
torch/nn/functional.py

index b5bf0a8..8f03d79 100644 (file)
@@ -14,9 +14,12 @@ Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double ep
 }
 
 // This is to guarantee that the contiguous memory is passed to the backward pass
+// TODO This currently enforces that the entire array is contiguous, but the
+// batches don't really have to be for efficiency sake, meaning there may be a
+// better way to only force the last two dimensions as contiguous.
 Tensor pdist(const Tensor& self, const double p) {
-  AT_CHECK(self.dim() == 2,
-      "pdist only supports 2D tensors, got: ", self.dim(), "D");
+  AT_CHECK(self.dim() >= 2,
+      "pdist only supports at least 2D tensors, got: ", self.dim(), "D");
   AT_CHECK(at::isFloatingType(self.type().scalarType()), "pdist only supports floating-point dtypes");
   AT_CHECK(p >= 0, "pdist only supports non-negative p values");
   return at::_pdist_forward(self.contiguous(), p);
@@ -26,17 +29,23 @@ Tensor _pdist_forward(const Tensor& self, const double p) {
   AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
   auto device = self.type().device_type();
   AT_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device);
-  Tensor result = at::empty({0}, self.options());
-  if (self.size(0) <= 1) {
-    result.resize_({0});
-  } else {
-    int64_t n = self.size(0);
-    int64_t c = n * (n - 1) / 2;
-    result.resize_({c});
-    if (self.size(1) == 0) {
+
+  const auto batches = self.sizes().slice(0, self.dim() - 2);
+  int64_t b = at::tensor(batches).prod().item<int64_t>();
+  int64_t n = self.size(-2);
+  int64_t m = self.size(-1);
+  int64_t c = n * (n - 1) / 2;
+
+  std::vector<int64_t> result_sizes(batches.begin(), batches.end());
+  result_sizes.push_back(c);
+  Tensor result = at::empty(result_sizes, self.options());
+
+  if (n > 1) {
+    if (m == 0) {
       result.fill_(0);
     } else {
-      pdist_forward_stub(device, result, self, p);
+      Tensor result_view = result.view({b, c});
+      pdist_forward_stub(device, result_view, self.view({b, n, m}), p);
     }
   }
   return result;
@@ -48,7 +57,14 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
   auto device = self.type().device_type();
   AT_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device);
   Tensor result = at::empty_like(self);
-  pdist_backward_stub(device, result, grad, self, p, pdist);
+
+  int64_t b = at::tensor(self.sizes().slice(0, self.dim() - 2)).prod().item<int64_t>();
+  int64_t n = self.size(-2);
+  int64_t m = self.size(-1);
+  int64_t c = pdist.size(-1);
+
+  Tensor result_view = result.view({b, n, m});
+  pdist_backward_stub(device, result_view, grad.contiguous().view({b, c}), self.view({b, n, m}), p, pdist.view({b, c}));
   return result;
 }
 
index 620537f..3808719 100644 (file)
@@ -95,27 +95,30 @@ struct PDist {
   template <typename F>
   static void run_parallel(Tensor& result, const Tensor& self, const scalar_t p) {
     const scalar_t * const self_start = self.data<scalar_t>();
-    const scalar_t * const self_end = self_start + self.numel();
-    int64_t n = self.size(0);
-    int64_t m = self.size(1);
+    int64_t b = self.size(0);
+    int64_t n = self.size(1);
+    int64_t m = self.size(2);
 
     scalar_t * const res_start = result.data<scalar_t>();
-    int64_t combs = result.numel(); // n * (n - 1) / 2
+    int64_t combs = n * (n - 1) / 2;
     const Vec pvec(p);
 
     // We conceptually iterate over tuples of (i, j, k) where i is the first
     // vector from the input, j is the second, and k is the result index. This
     // parallelizes over the range of k and infers what i and j are from the
     // value of k.
-    parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t k, int64_t end) {
+    parallel_for(0, combs * b, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t start, int64_t end) {
+      int64_t l = start / combs;
+      int64_t k = start % combs;
       float n2 = n - .5;
       // The -1 accounts for floating point truncation issues
       int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
       int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
 
-      const scalar_t * self_i = self_start + i * m;
-      const scalar_t * self_j = self_start + j * m;
-      scalar_t * res = res_start + k;
+      const scalar_t * self_i = self_start + (l * n + i) * m;
+      const scalar_t * self_j = self_start + (l * n + j) * m;
+      const scalar_t * self_end = self_start + (l + 1) * n * m;
+      scalar_t * res = res_start + start;
       const scalar_t * const res_end = res_start + end;
 
       while (res != res_end) {
@@ -127,6 +130,10 @@ struct PDist {
         self_j += m;
         if (self_j == self_end) {
           self_i += m;
+          if (self_i + m == self_end) {
+            self_i += m;
+            self_end += n * m;
+          }
           self_j = self_i + m;
         }
       }
@@ -148,8 +155,9 @@ struct PDist {
     }
   }
 
+  // This does a backward pass down a Vec column of the input
   template <typename F>
-  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
+  inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int count) {
     for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
 
       const Vec self_vec_i = Vec::loadu(self_i, count);
@@ -157,7 +165,7 @@ struct PDist {
 
       const scalar_t * self_j = self_i + m;
       scalar_t * res_j = res_i + m;
-      for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
+      for (; self_j != self_end; self_j += m, res_j += m, grad_k += 1, dist_k += 1) {
         const Vec self_vec_j = Vec::loadu(self_j, count);
         Vec res_vec_j = Vec::loadu(res_j, count);
 
@@ -174,9 +182,11 @@ struct PDist {
 
   template <typename F>
   static void run_backward_parallel(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
-    const int64_t n = self.size(0);
-    const int64_t m = self.size(1);
-    const int64_t gs = grad.stride(0);
+    const int64_t b = self.size(0);
+    const int64_t n = self.size(1);
+    const int64_t m = self.size(2);
+    const int64_t combs = dist.size(1);
+    const int64_t remainder = m % Vec::size();
     const Vec pvec(p);
 
     const scalar_t * const grad_start = grad.data<scalar_t>();
@@ -187,18 +197,39 @@ struct PDist {
     // The only way to parallelize and avoid locking requires parallelizing
     // over the columns of the input, i.e. we compute the gradient for the
     // first section of each vector independentaly of the second section, etc.
-    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
-      const scalar_t * self_l = self_start + l * Vec::size();
-      scalar_t * res_l = res_start + l * Vec::size();
+    int64_t mv = (m + Vec::size() - 1) / Vec::size(); // number of Vecs in a row rounded up
+    at::parallel_for(0, b * mv, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t start, int64_t end) {
+      const int64_t l = start / mv;
+      const int64_t v = start % mv;
+
+      const scalar_t * self_l = self_start + l * n * m;
+      const scalar_t * self_v = self_l + v * Vec::size();
+
+      const scalar_t * dist_l = dist_start + l * combs;
+      const scalar_t * grad_l = grad_start + l * combs;
+
+      scalar_t * res_l = res_start + l * n * m;
+      scalar_t * res_v = res_l + v * Vec::size();
+
+      while (start != end) {
+        backward_down_column<F>(self_v, res_v, grad_l, dist_l, pvec, n, m, std::min(int(m - (self_v - self_l)), Vec::size()));
 
-      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
-        backward_down_column<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
+        start += 1;
+        self_v += Vec::size();
+        res_v += Vec::size();
+        if (self_v == self_l + mv * Vec::size()) {
+          // Reached the end of the row
+          self_l += n * m;
+          self_v = self_l;
+
+          res_l += n * m;
+          res_v = res_l;
+
+          dist_l += combs;
+          grad_l += combs;
+        }
       }
     });
-    const int64_t remainder = m % Vec::size();
-    if (remainder) {
-      backward_down_column<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder);
-    }
   }
 
   // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
@@ -220,7 +251,7 @@ struct PDist {
 
 };
 
-void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
+static void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] {
     PDist<scalar_t>::apply(result, self, p);
   });
index a5fecf8..4b2555a 100644 (file)
@@ -79,18 +79,19 @@ struct dists {
 
 template <typename scalar_t, typename F>
 __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p) {
-  const int k = blockIdx.x;
-  const int stride = blockDim.x;
+  const int l = blockIdx.x;
+  const int k = blockIdx.y;
+  const int stride = blockDim.y;
 
   float n2 = n - .5;
   // The -1 accounts for floating point truncation issues
   int64_t i = static_cast<int64_t>((n2 - device_sqrt<scalar_t>(n2 * n2 - 2 * k - 1)));
   int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
 
-  const scalar_t * const start = self + i * m;
+  const scalar_t * const start = self + (l * n + i) * m;
   const scalar_t * const end = start + m;
-  const scalar_t * a = start + threadIdx.x;
-  const scalar_t * b = self + j * m + threadIdx.x;
+  const scalar_t * a = start + threadIdx.y;
+  const scalar_t * b = self + (l * n + j) * m + threadIdx.y;
   scalar_t agg = 0.0;
   for (; a < end; a += stride, b += stride) {
     F::inc(agg, std::abs(*a - *b), p);
@@ -105,29 +106,32 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t
   // This shared memory is significantly larger than necessary, but the
   // assumption is that it's not a bottleneck, and this is simple
   __shared__ scalar_t shared[forward_threads];
-  int lane = threadIdx.x % warpSize;
-  int warp_id = threadIdx.x / warpSize;
+  int lane = threadIdx.y % warpSize;
+  int warp_id = threadIdx.y / warpSize;
   if (lane == 0) {
     shared[warp_id] = agg;
   }
   __syncthreads();
-  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
+  agg = (threadIdx.y < blockDim.y / warpSize) ? shared[lane] : 0.0;
   if (warp_id == 0) {
     // Only reduce theads with nonzero data
-    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
+    for (int offset = blockDim.y / warpSize / 2; offset > 0; offset /= 2) {
       F::agg(agg, WARP_SHFL_DOWN(agg, offset));
     }
   }
-  if (threadIdx.x == 0) {
-    result[k] = F::finish(agg, p);
+  if (threadIdx.y == 0) {
+    const int64_t combs = n * (n - 1) / 2;
+    result[l * combs + k] = F::finish(agg, p);
   }
 }
 
 template <typename scalar_t, typename F>
-__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p) {
-  const int k = blockIdx.y * blockDim.y + threadIdx.y;
-  const int init = blockIdx.x * blockDim.x + threadIdx.x;
-  const int stride = blockDim.x * gridDim.x;
+__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, const int64_t n, const int64_t m, const scalar_t p) {
+  const int l = blockIdx.x;
+  const int k = blockIdx.z * blockDim.z + threadIdx.z;
+  const int init = blockIdx.y * blockDim.y + threadIdx.y;
+  const int stride = blockDim.y * gridDim.y;
+  const int combs = n * (n - 1) / 2;
 
   if (k >= combs) {
     return;
@@ -140,15 +144,15 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
   int64_t ib = j - i - 1;
   int64_t jb = n - 2 - i;
 
-  const scalar_t grad_k = grad[k * gs];
-  const scalar_t dist_k = dist[k];
+  const scalar_t grad_k = grad[l * combs + k];
+  const scalar_t dist_k = dist[l * combs + k];
 
-  const scalar_t * const start = self + i * m;
+  const scalar_t * const start = self + (l * n + i) * m;
   const scalar_t * const end = start + m;
   const scalar_t * self_i = start + init;
-  const scalar_t * self_j = self + j * m + init;
-  scalar_t * buff_i = buffer + (ib * n + i) * m + init;
-  scalar_t * buff_j = buffer + (jb * n + j) * m + init;
+  const scalar_t * self_j = self + (l * n + j) * m + init;
+  scalar_t * buff_i = buffer + ((l * (n - 1) + ib) * n + i) * m + init;
+  scalar_t * buff_j = buffer + ((l * (n - 1) + jb) * n + j) * m + init;
   for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) {
     const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
     *buff_i = res;
@@ -157,10 +161,11 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
 }
 
 void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
-  const dim3 grid(result.numel());
-  const dim3 block(forward_threads);
-  int64_t n = self.size(0);
-  int64_t m = self.size(1);
+  int64_t b = self.size(0);
+  int64_t n = self.size(1);
+  int64_t m = self.size(2);
+  const dim3 grid(b, result.size(1));
+  const dim3 block(1, forward_threads);
 
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] {
     if (p == 0.0) {
@@ -183,31 +188,32 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
     return;
   }
 
-  const int64_t n = result.size(0);
-  int64_t m = self.size(1);
-  const int block_x = 64;
-  const int block_y = 4;
-  const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
-  const int grid_y = (dist.numel() + block_y - 1) / block_y;
-  const dim3 grid(grid_x, grid_y);
-  const dim3 block(block_x, block_y);
-
-  Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
+  const int64_t b = self.size(0);
+  const int64_t n = self.size(1);
+  const int64_t m = self.size(2);
+  const int block_y = 64;
+  const int block_z = 4;
+  const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
+  const int grid_z = (dist.numel() + block_z - 1) / block_z;
+  const dim3 grid(b, grid_y, grid_z);
+  const dim3 block(1, block_y, block_z);
+
+  Tensor buffer = at::empty({b, n - 1, n, m}, result.options());
   AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] {
     if (p == 1.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
     } else if (p < 2.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
     } else if (p == 2.0) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
     } else if (std::isinf(p)) {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
     } else {
-      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+      pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
     }
   });
 
-  at::sum_out(result, buffer, 0);
+  at::sum_out(result, buffer, 1);
 }
 
 } // anonymous namespace
index ba4dbb5..871e870 100644 (file)
@@ -725,6 +725,20 @@ def random_fullrank_matrix_distinct_singular_value(l, *batches, **kwargs):
         return torch.stack(all_matrices).reshape(*(batches + (l, l)))
 
 
+def brute_pdist(inp, p=2):
+    """Computes the same as torch.pdist using primitives"""
+    n = inp.shape[-2]
+    k = n * (n - 1) // 2
+    if k == 0:
+        # torch complains about empty indices
+        return torch.empty(inp.shape[:-2] + (0,), device=inp.device)
+    square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
+    unroll = square.view(square.shape[:-2] + (n * n,))
+    inds = torch.ones(k, dtype=torch.int)
+    inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
+    return unroll[..., inds.cumsum(0)]
+
+
 def do_test_dtypes(self, dtypes, layout, device):
     for dtype in dtypes:
         if dtype != torch.float16:
index de31d73..d62742f 100644 (file)
@@ -5294,10 +5294,10 @@ class TestNN(NNTestCase):
 
     @skipIfRocm
     def test_pdist(self):
-        for device, trans in itertools.product(device_(), [False, True]):
-            inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
+        for device, trans, shape in itertools.product(device_(), [False, True], [(4, 5), (2, 3, 4)]):
+            inp = torch.randn(shape, dtype=torch.double, device=device, requires_grad=True)
             if trans:
-                inp = inp.transpose(0, 1)
+                inp = inp.transpose(-2, -1)
             for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
                 self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
 
index d2718b5..0386ba2 100644 (file)
@@ -25,7 +25,7 @@ from torch import multiprocessing as mp
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests
+    IS_SANDCASTLE, load_tests, brute_pdist
 from multiprocessing.reduction import ForkingPickler
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -1124,27 +1124,19 @@ class _TestTorchMixin(object):
             x = torch.randn(shape, device=device)
             self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
 
-    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
-    def test_pdist_scipy(self):
-        from scipy.spatial.distance import pdist
+    def test_pdist_norm(self):
         devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         for device in devices:
-            for shape in [(4, 5), (3, 2), (2, 1)]:
+            for shape in [(4, 5), (3, 2), (2, 1), (2, 3, 4)]:
                 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                     for trans in [False, True]:
                         x = torch.randn(shape, device=device)
                         if trans:
-                            x.transpose_(0, 1)
+                            x.transpose_(-2, -1)
                         actual = torch.pdist(x, p=p)
-                        # pdist doesn't handle 0 or inf norm properly
-                        if p == 0:
-                            expected = pdist(x.cpu(), 'hamming') * x.shape[1]
-                        elif p == float('inf'):
-                            expected = pdist(x.cpu(), lambda a, b: np.abs(a - b).max())
-                        else:
-                            expected = pdist(x.cpu(), 'minkowski', p=p)
+                        expected = brute_pdist(x, p=p)
                         self.assertEqual(expected.shape, actual.shape)
-                        self.assertTrue(np.allclose(expected, actual.cpu().numpy()))
+                        self.assertTrue(torch.allclose(expected, actual))
 
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_logsumexp(self):
index f56ca3c..975839f 100644 (file)
@@ -2729,21 +2729,28 @@ pdist = _add_docstr(torch.pdist, r"""
 pdist(input, p=2) -> Tensor
 
 Computes the p-norm distance between every pair of row vectors in the input.
-This is identical to the upper triangular portion, excluding the diagonal, of
-`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
-if the rows are contiguous.
-
-If input has shape :math:`N \times M` then the output will have shape
-:math:`\frac{1}{2} N (N - 1)`.
-
-This function is equivalent to `scipy.spatial.distance.pdist(input,
-'minkowski', p=p)` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
-equivalent to `scipy.spatial.distance.pdist(input, 'hamming') * M`.
-When :math:`p = \infty`, the closest scipy function is
-`scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())`.
+If the input tensor has shape ... B x N x M, the result will be tensor with
+shape ... B x N * (N - 1) / 2. Every dimension prior to the last two are
+treated as independent batches of N vectors, each with M elements. If we use
+ordinal numbers to refer to the vectors in a batch, the last dimension of the
+output will be ordered as:
+
+```
+[dist(1, 2), dist(1, 2), ..., dist(1, N), dist(2, 3), ..., dist(N-1, N)]
+```
+
+The square verion of pdist that has redundant distances and the diagonal can be
+be computed with
+`torch.norm(inpup[..., None, :] - input[..., None, :, :], p=p, dim=-1)`.
+Appropriately selecting and flattening the upper triangular of the
+last two dimensions will produce identical results as pdist.
+
+This function is similar to
+`scipy.spatial.distance.pdist(input, 'minkowski', p=p)`
+if :math:`p \in (0, \infty)`.
 
 Args:
-    input: input tensor of shape :math:`N \times M`.
+    input: input tensor of shape :math:`... \times N \times M`.
     p: p value for the p-norm distance to calculate between each vector pair
         :math:`\in [0, \infty]`.
 """)