From 8db44eda013d71ebc5837e45d5ed16b00993a0d2 Mon Sep 17 00:00:00 2001 From: Erik Brinkman Date: Thu, 20 Dec 2018 09:35:08 -0800 Subject: [PATCH] Add support for batched pdist (#12302) Summary: This updates pdist to work for batched inputs, and updates the documentation to reflect issues raised. closes #9406 Pull Request resolved: https://github.com/pytorch/pytorch/pull/12302 Reviewed By: ezyang Differential Revision: D13528485 Pulled By: erikbrinkman fbshipit-source-id: 63d93a6e1cc95b483fb58e9ff021758b341cd4de --- aten/src/ATen/native/Distance.cpp | 40 ++++++++---- aten/src/ATen/native/cpu/DistanceOpsKernel.cpp | 77 +++++++++++++++------- aten/src/ATen/native/cuda/DistanceKernel.cu | 88 ++++++++++++++------------ test/common_utils.py | 14 ++++ test/test_nn.py | 6 +- test/test_torch.py | 20 ++---- torch/nn/functional.py | 33 ++++++---- 7 files changed, 172 insertions(+), 106 deletions(-) diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b5bf0a8..8f03d79 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -14,9 +14,12 @@ Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double ep } // This is to guarantee that the contiguous memory is passed to the backward pass +// TODO This currently enforces that the entire array is contiguous, but the +// batches don't really have to be for efficiency sake, meaning there may be a +// better way to only force the last two dimensions as contiguous. Tensor pdist(const Tensor& self, const double p) { - AT_CHECK(self.dim() == 2, - "pdist only supports 2D tensors, got: ", self.dim(), "D"); + AT_CHECK(self.dim() >= 2, + "pdist only supports at least 2D tensors, got: ", self.dim(), "D"); AT_CHECK(at::isFloatingType(self.type().scalarType()), "pdist only supports floating-point dtypes"); AT_CHECK(p >= 0, "pdist only supports non-negative p values"); return at::_pdist_forward(self.contiguous(), p); @@ -26,17 +29,23 @@ Tensor _pdist_forward(const Tensor& self, const double p) { AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input"); auto device = self.type().device_type(); AT_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device); - Tensor result = at::empty({0}, self.options()); - if (self.size(0) <= 1) { - result.resize_({0}); - } else { - int64_t n = self.size(0); - int64_t c = n * (n - 1) / 2; - result.resize_({c}); - if (self.size(1) == 0) { + + const auto batches = self.sizes().slice(0, self.dim() - 2); + int64_t b = at::tensor(batches).prod().item(); + int64_t n = self.size(-2); + int64_t m = self.size(-1); + int64_t c = n * (n - 1) / 2; + + std::vector result_sizes(batches.begin(), batches.end()); + result_sizes.push_back(c); + Tensor result = at::empty(result_sizes, self.options()); + + if (n > 1) { + if (m == 0) { result.fill_(0); } else { - pdist_forward_stub(device, result, self, p); + Tensor result_view = result.view({b, c}); + pdist_forward_stub(device, result_view, self.view({b, n, m}), p); } } return result; @@ -48,7 +57,14 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c auto device = self.type().device_type(); AT_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device); Tensor result = at::empty_like(self); - pdist_backward_stub(device, result, grad, self, p, pdist); + + int64_t b = at::tensor(self.sizes().slice(0, self.dim() - 2)).prod().item(); + int64_t n = self.size(-2); + int64_t m = 
self.size(-1); + int64_t c = pdist.size(-1); + + Tensor result_view = result.view({b, n, m}); + pdist_backward_stub(device, result_view, grad.contiguous().view({b, c}), self.view({b, n, m}), p, pdist.view({b, c})); return result; } diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index 620537f..3808719 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -95,27 +95,30 @@ struct PDist { template static void run_parallel(Tensor& result, const Tensor& self, const scalar_t p) { const scalar_t * const self_start = self.data(); - const scalar_t * const self_end = self_start + self.numel(); - int64_t n = self.size(0); - int64_t m = self.size(1); + int64_t b = self.size(0); + int64_t n = self.size(1); + int64_t m = self.size(2); scalar_t * const res_start = result.data(); - int64_t combs = result.numel(); // n * (n - 1) / 2 + int64_t combs = n * (n - 1) / 2; const Vec pvec(p); // We conceptually iterate over tuples of (i, j, k) where i is the first // vector from the input, j is the second, and k is the result index. This // parallelizes over the range of k and infers what i and j are from the // value of k. - parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t k, int64_t end) { + parallel_for(0, combs * b, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t start, int64_t end) { + int64_t l = start / combs; + int64_t k = start % combs; float n2 = n - .5; // The -1 accounts for floating point truncation issues int64_t i = static_cast((n2 - std::sqrt(n2 * n2 - 2 * k - 1))); int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; - const scalar_t * self_i = self_start + i * m; - const scalar_t * self_j = self_start + j * m; - scalar_t * res = res_start + k; + const scalar_t * self_i = self_start + (l * n + i) * m; + const scalar_t * self_j = self_start + (l * n + j) * m; + const scalar_t * self_end = self_start + (l + 1) * n * m; + scalar_t * res = res_start + start; const scalar_t * const res_end = res_start + end; while (res != res_end) { @@ -127,6 +130,10 @@ struct PDist { self_j += m; if (self_j == self_end) { self_i += m; + if (self_i + m == self_end) { + self_i += m; + self_end += n * m; + } self_j = self_i + m; } } @@ -148,8 +155,9 @@ struct PDist { } } + // This does a backward pass down a Vec column of the input template - inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) { + inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int count) { for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) { const Vec self_vec_i = Vec::loadu(self_i, count); @@ -157,7 +165,7 @@ struct PDist { const scalar_t * self_j = self_i + m; scalar_t * res_j = res_i + m; - for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) { + for (; self_j != self_end; self_j += m, res_j += m, grad_k += 1, dist_k += 1) { const Vec self_vec_j = Vec::loadu(self_j, count); Vec res_vec_j = Vec::loadu(res_j, count); @@ -174,9 +182,11 @@ struct PDist { template static void run_backward_parallel(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) { - const int64_t n = self.size(0); - const int64_t m = 
self.size(1); - const int64_t gs = grad.stride(0); + const int64_t b = self.size(0); + const int64_t n = self.size(1); + const int64_t m = self.size(2); + const int64_t combs = dist.size(1); + const int64_t remainder = m % Vec::size(); const Vec pvec(p); const scalar_t * const grad_start = grad.data(); @@ -187,18 +197,39 @@ struct PDist { // The only way to parallelize and avoid locking requires parallelizing // over the columns of the input, i.e. we compute the gradient for the // first section of each vector independentaly of the second section, etc. - at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) { - const scalar_t * self_l = self_start + l * Vec::size(); - scalar_t * res_l = res_start + l * Vec::size(); + int64_t mv = (m + Vec::size() - 1) / Vec::size(); // number of Vecs in a row rounded up + at::parallel_for(0, b * mv, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t start, int64_t end) { + const int64_t l = start / mv; + const int64_t v = start % mv; + + const scalar_t * self_l = self_start + l * n * m; + const scalar_t * self_v = self_l + v * Vec::size(); + + const scalar_t * dist_l = dist_start + l * combs; + const scalar_t * grad_l = grad_start + l * combs; + + scalar_t * res_l = res_start + l * n * m; + scalar_t * res_v = res_l + v * Vec::size(); + + while (start != end) { + backward_down_column(self_v, res_v, grad_l, dist_l, pvec, n, m, std::min(int(m - (self_v - self_l)), Vec::size())); - for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) { - backward_down_column(self_l, res_l, grad_start, dist_start, pvec, n, m, gs); + start += 1; + self_v += Vec::size(); + res_v += Vec::size(); + if (self_v == self_l + mv * Vec::size()) { + // Reached the end of the row + self_l += n * m; + self_v = self_l; + + res_l += n * m; + res_v = res_l; + + dist_l += combs; + grad_l += combs; + } } }); - const int64_t remainder = m % Vec::size(); - if (remainder) { - backward_down_column(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder); - } } // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous @@ -220,7 +251,7 @@ struct PDist { }; -void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) { +static void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) { AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] { PDist::apply(result, self, p); }); diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu index a5fecf8..4b2555a 100644 --- a/aten/src/ATen/native/cuda/DistanceKernel.cu +++ b/aten/src/ATen/native/cuda/DistanceKernel.cu @@ -79,18 +79,19 @@ struct dists { template __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p) { - const int k = blockIdx.x; - const int stride = blockDim.x; + const int l = blockIdx.x; + const int k = blockIdx.y; + const int stride = blockDim.y; float n2 = n - .5; // The -1 accounts for floating point truncation issues int64_t i = static_cast((n2 - device_sqrt(n2 * n2 - 2 * k - 1))); int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; - const scalar_t * const start = self + i * m; + const scalar_t * const start = self + (l * n + i) * m; const scalar_t * const end = start + m; - const scalar_t * a = start + threadIdx.x; - const scalar_t * b = self + j * m + 
threadIdx.x; + const scalar_t * a = start + threadIdx.y; + const scalar_t * b = self + (l * n + j) * m + threadIdx.y; scalar_t agg = 0.0; for (; a < end; a += stride, b += stride) { F::inc(agg, std::abs(*a - *b), p); @@ -105,29 +106,32 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t // This shared memory is significantly larger than necessary, but the // assumption is that it's not a bottleneck, and this is simple __shared__ scalar_t shared[forward_threads]; - int lane = threadIdx.x % warpSize; - int warp_id = threadIdx.x / warpSize; + int lane = threadIdx.y % warpSize; + int warp_id = threadIdx.y / warpSize; if (lane == 0) { shared[warp_id] = agg; } __syncthreads(); - agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0; + agg = (threadIdx.y < blockDim.y / warpSize) ? shared[lane] : 0.0; if (warp_id == 0) { // Only reduce theads with nonzero data - for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) { + for (int offset = blockDim.y / warpSize / 2; offset > 0; offset /= 2) { F::agg(agg, WARP_SHFL_DOWN(agg, offset)); } } - if (threadIdx.x == 0) { - result[k] = F::finish(agg, p); + if (threadIdx.y == 0) { + const int64_t combs = n * (n - 1) / 2; + result[l * combs + k] = F::finish(agg, p); } } template -__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p) { - const int k = blockIdx.y * blockDim.y + threadIdx.y; - const int init = blockIdx.x * blockDim.x + threadIdx.x; - const int stride = blockDim.x * gridDim.x; +__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, const int64_t n, const int64_t m, const scalar_t p) { + const int l = blockIdx.x; + const int k = blockIdx.z * blockDim.z + threadIdx.z; + const int init = blockIdx.y * blockDim.y + threadIdx.y; + const int stride = blockDim.y * gridDim.y; + const int combs = n * (n - 1) / 2; if (k >= combs) { return; @@ -140,15 +144,15 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const int64_t ib = j - i - 1; int64_t jb = n - 2 - i; - const scalar_t grad_k = grad[k * gs]; - const scalar_t dist_k = dist[k]; + const scalar_t grad_k = grad[l * combs + k]; + const scalar_t dist_k = dist[l * combs + k]; - const scalar_t * const start = self + i * m; + const scalar_t * const start = self + (l * n + i) * m; const scalar_t * const end = start + m; const scalar_t * self_i = start + init; - const scalar_t * self_j = self + j * m + init; - scalar_t * buff_i = buffer + (ib * n + i) * m + init; - scalar_t * buff_j = buffer + (jb * n + j) * m + init; + const scalar_t * self_j = self + (l * n + j) * m + init; + scalar_t * buff_i = buffer + ((l * (n - 1) + ib) * n + i) * m + init; + scalar_t * buff_j = buffer + ((l * (n - 1) + jb) * n + j) * m + init; for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) { const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p); *buff_i = res; @@ -157,10 +161,11 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const } void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { - const dim3 grid(result.numel()); - const dim3 block(forward_threads); - int64_t n = self.size(0); - int64_t m = self.size(1); + int64_t b = self.size(0); + int64_t n = self.size(1); + int64_t m = 
self.size(2); + const dim3 grid(b, result.size(1)); + const dim3 block(1, forward_threads); AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] { if (p == 0.0) { @@ -183,31 +188,32 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor return; } - const int64_t n = result.size(0); - int64_t m = self.size(1); - const int block_x = 64; - const int block_y = 4; - const int grid_x = (m + block_x * 8 - 1) / (block_x * 8); - const int grid_y = (dist.numel() + block_y - 1) / block_y; - const dim3 grid(grid_x, grid_y); - const dim3 block(block_x, block_y); - - Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options()); + const int64_t b = self.size(0); + const int64_t n = self.size(1); + const int64_t m = self.size(2); + const int block_y = 64; + const int block_z = 4; + const int grid_y = (m + block_y * 8 - 1) / (block_y * 8); + const int grid_z = (dist.numel() + block_z - 1) / block_z; + const dim3 grid(b, grid_y, grid_z); + const dim3 block(1, block_y, block_z); + + Tensor buffer = at::empty({b, n - 1, n, m}, result.options()); AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] { if (p == 1.0) { - pdist_backward_kernel_cuda_impl::one><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p); + pdist_backward_kernel_cuda_impl::one><<>>(buffer.data(), grad.data(), self.data(), dist.data(), n, m, p); } else if (p < 2.0) { - pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p); + pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data(), grad.data(), self.data(), dist.data(), n, m, p); } else if (p == 2.0) { - pdist_backward_kernel_cuda_impl::two><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p); + pdist_backward_kernel_cuda_impl::two><<>>(buffer.data(), grad.data(), self.data(), dist.data(), n, m, p); } else if (std::isinf(p)) { - pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p); + pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data(), grad.data(), self.data(), dist.data(), n, m, p); } else { - pdist_backward_kernel_cuda_impl::p><<>>(buffer.data(), grad.data(), self.data(), dist.data(), grad.stride(0), n, m, dist.numel(), p); + pdist_backward_kernel_cuda_impl::p><<>>(buffer.data(), grad.data(), self.data(), dist.data(), n, m, p); } }); - at::sum_out(result, buffer, 0); + at::sum_out(result, buffer, 1); } } // anonymous namespace diff --git a/test/common_utils.py b/test/common_utils.py index ba4dbb5..871e870 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -725,6 +725,20 @@ def random_fullrank_matrix_distinct_singular_value(l, *batches, **kwargs): return torch.stack(all_matrices).reshape(*(batches + (l, l))) +def brute_pdist(inp, p=2): + """Computes the same as torch.pdist using primitives""" + n = inp.shape[-2] + k = n * (n - 1) // 2 + if k == 0: + # torch complains about empty indices + return torch.empty(inp.shape[:-2] + (0,), device=inp.device) + square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1) + unroll = square.view(square.shape[:-2] + (n * n,)) + inds = torch.ones(k, dtype=torch.int) + inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int) + return unroll[..., inds.cumsum(0)] + + def do_test_dtypes(self, dtypes, layout, device): for dtype in dtypes: if dtype != torch.float16: 
diff --git a/test/test_nn.py b/test/test_nn.py index de31d73..d62742f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5294,10 +5294,10 @@ class TestNN(NNTestCase): @skipIfRocm def test_pdist(self): - for device, trans in itertools.product(device_(), [False, True]): - inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True) + for device, trans, shape in itertools.product(device_(), [False, True], [(4, 5), (2, 3, 4)]): + inp = torch.randn(shape, dtype=torch.double, device=device, requires_grad=True) if trans: - inp = inp.transpose(0, 1) + inp = inp.transpose(-2, -1) for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]: self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,))) diff --git a/test/test_torch.py b/test/test_torch.py index d2718b5..0386ba2 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -25,7 +25,7 @@ from torch import multiprocessing as mp from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \ TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \ IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \ - IS_SANDCASTLE, load_tests + IS_SANDCASTLE, load_tests, brute_pdist from multiprocessing.reduction import ForkingPickler # load_tests from common_utils is used to automatically filter tests for @@ -1124,27 +1124,19 @@ class _TestTorchMixin(object): x = torch.randn(shape, device=device) self.assertEqual(torch.zeros(3, device=device), torch.pdist(x)) - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_pdist_scipy(self): - from scipy.spatial.distance import pdist + def test_pdist_norm(self): devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] for device in devices: - for shape in [(4, 5), (3, 2), (2, 1)]: + for shape in [(4, 5), (3, 2), (2, 1), (2, 3, 4)]: for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: for trans in [False, True]: x = torch.randn(shape, device=device) if trans: - x.transpose_(0, 1) + x.transpose_(-2, -1) actual = torch.pdist(x, p=p) - # pdist doesn't handle 0 or inf norm properly - if p == 0: - expected = pdist(x.cpu(), 'hamming') * x.shape[1] - elif p == float('inf'): - expected = pdist(x.cpu(), lambda a, b: np.abs(a - b).max()) - else: - expected = pdist(x.cpu(), 'minkowski', p=p) + expected = brute_pdist(x, p=p) self.assertEqual(expected.shape, actual.shape) - self.assertTrue(np.allclose(expected, actual.cpu().numpy())) + self.assertTrue(torch.allclose(expected, actual)) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") def test_logsumexp(self): diff --git a/torch/nn/functional.py b/torch/nn/functional.py index f56ca3c..975839f 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2729,21 +2729,28 @@ pdist = _add_docstr(torch.pdist, r""" pdist(input, p=2) -> Tensor Computes the p-norm distance between every pair of row vectors in the input. -This is identical to the upper triangular portion, excluding the diagonal, of -`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster -if the rows are contiguous. - -If input has shape :math:`N \times M` then the output will have shape -:math:`\frac{1}{2} N (N - 1)`. - -This function is equivalent to `scipy.spatial.distance.pdist(input, -'minkowski', p=p)` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is -equivalent to `scipy.spatial.distance.pdist(input, 'hamming') * M`. -When :math:`p = \infty`, the closest scipy function is -`scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())`. 
+If the input tensor has shape ... B x N x M, the result will be a tensor with
+shape ... B x N * (N - 1) / 2. Every dimension prior to the last two is
+treated as a batch dimension, giving independent batches of N vectors, each
+with M elements. If we use ordinal numbers to refer to the vectors in a batch,
+the last dimension of the output will be ordered as:
+
+```
+[dist(1, 2), dist(1, 3), ..., dist(1, N), dist(2, 3), ..., dist(N-1, N)]
+```
+
+The square version of pdist, which includes the redundant distances and the
+diagonal, can be computed with
+`torch.norm(input[..., None, :] - input[..., None, :, :], p=p, dim=-1)`.
+Appropriately selecting and flattening the upper triangular portion of the
+last two dimensions will produce results identical to pdist.
+
+This function is similar to
+`scipy.spatial.distance.pdist(input, 'minkowski', p=p)`
+if :math:`p \in (0, \infty)`.
 
 Args:
-    input: input tensor of shape :math:`N \times M`.
+    input: input tensor of shape :math:`... \times N \times M`.
     p: p value for the p-norm distance to calculate between each vector pair
         :math:`\in [0, \infty]`.
 """)
-- 
2.7.4
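Two small, self-contained sketches (not part of the diff) may help when reviewing this change.

First, a sanity check of the batched semantics the patch adds: on a `... x N x M` input, `torch.pdist` should agree with running the pre-existing 2-D `pdist` on every batch independently. This assumes a PyTorch build that already includes this change; the tensor shapes are arbitrary.

```python
import torch

# Two batch dimensions (2 x 3), N = 4 vectors of M = 5 elements each.
x = torch.randn(2, 3, 4, 5, dtype=torch.double)

batched = torch.pdist(x, p=2)  # expected shape: 2 x 3 x 6, since 4 * 3 / 2 == 6

# Reference: apply the original 2-D pdist to each batch and stack the results.
flat = x.reshape(-1, 4, 5)
reference = torch.stack([torch.pdist(row, p=2) for row in flat]).reshape(2, 3, 6)

assert batched.shape == (2, 3, 6)
assert torch.allclose(batched, reference)
```

Second, both the CPU and CUDA forward kernels recover the pair (i, j) from a flat result index k using the closed-form expression that appears in the diff. The hypothetical helper below is a plain-Python rendering of that arithmetic (0-based indices, unlike the 1-based ordinals used in the docstring), handy for convincing yourself the mapping enumerates pairs in the documented order:

```python
import math

def pair_from_index(k, n):
    # Mirrors the kernel arithmetic: the -1 under the square root guards
    # against floating point truncation, as the source comment notes.
    n2 = n - 0.5
    i = int(n2 - math.sqrt(n2 * n2 - 2 * k - 1))
    j = k - n * i + i * (i + 1) // 2 + i + 1
    return i, j

n = 5
pairs = [pair_from_index(k, n) for k in range(n * (n - 1) // 2)]
assert pairs == [(i, j) for i in range(n) for j in range(i + 1, n)]
```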