}
// This is to guarantee that the contiguous memory is passed to the backward pass
+// TODO This currently enforces that the entire array is contiguous, but the
+// batches don't really have to be for efficiency sake, meaning there may be a
+// better way to only force the last two dimensions as contiguous.
Tensor pdist(const Tensor& self, const double p) {
- AT_CHECK(self.dim() == 2,
- "pdist only supports 2D tensors, got: ", self.dim(), "D");
+ AT_CHECK(self.dim() >= 2,
+ "pdist only supports at least 2D tensors, got: ", self.dim(), "D");
AT_CHECK(at::isFloatingType(self.type().scalarType()), "pdist only supports floating-point dtypes");
AT_CHECK(p >= 0, "pdist only supports non-negative p values");
return at::_pdist_forward(self.contiguous(), p);
AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
auto device = self.type().device_type();
AT_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device);
- Tensor result = at::empty({0}, self.options());
- if (self.size(0) <= 1) {
- result.resize_({0});
- } else {
- int64_t n = self.size(0);
- int64_t c = n * (n - 1) / 2;
- result.resize_({c});
- if (self.size(1) == 0) {
+
+ const auto batches = self.sizes().slice(0, self.dim() - 2);
+ int64_t b = at::tensor(batches).prod().item<int64_t>();
+ int64_t n = self.size(-2);
+ int64_t m = self.size(-1);
+ int64_t c = n * (n - 1) / 2;
+
+ std::vector<int64_t> result_sizes(batches.begin(), batches.end());
+ result_sizes.push_back(c);
+ Tensor result = at::empty(result_sizes, self.options());
+
+ if (n > 1) {
+ if (m == 0) {
result.fill_(0);
} else {
- pdist_forward_stub(device, result, self, p);
+ Tensor result_view = result.view({b, c});
+ pdist_forward_stub(device, result_view, self.view({b, n, m}), p);
}
}
return result;
auto device = self.type().device_type();
AT_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device);
Tensor result = at::empty_like(self);
- pdist_backward_stub(device, result, grad, self, p, pdist);
+
+ int64_t b = at::tensor(self.sizes().slice(0, self.dim() - 2)).prod().item<int64_t>();
+ int64_t n = self.size(-2);
+ int64_t m = self.size(-1);
+ int64_t c = pdist.size(-1);
+
+ Tensor result_view = result.view({b, n, m});
+ pdist_backward_stub(device, result_view, grad.contiguous().view({b, c}), self.view({b, n, m}), p, pdist.view({b, c}));
return result;
}
template <typename F>
static void run_parallel(Tensor& result, const Tensor& self, const scalar_t p) {
const scalar_t * const self_start = self.data<scalar_t>();
- const scalar_t * const self_end = self_start + self.numel();
- int64_t n = self.size(0);
- int64_t m = self.size(1);
+ int64_t b = self.size(0);
+ int64_t n = self.size(1);
+ int64_t m = self.size(2);
scalar_t * const res_start = result.data<scalar_t>();
- int64_t combs = result.numel(); // n * (n - 1) / 2
+ int64_t combs = n * (n - 1) / 2;
const Vec pvec(p);
// We conceptually iterate over tuples of (i, j, k) where i is the first
// vector from the input, j is the second, and k is the result index. This
// parallelizes over the range of k and infers what i and j are from the
// value of k.
- parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t k, int64_t end) {
+ parallel_for(0, combs * b, internal::GRAIN_SIZE / (16 * m), [=, &pvec](int64_t start, int64_t end) {
+ int64_t l = start / combs;
+ int64_t k = start % combs;
float n2 = n - .5;
// The -1 accounts for floating point truncation issues
int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
- const scalar_t * self_i = self_start + i * m;
- const scalar_t * self_j = self_start + j * m;
- scalar_t * res = res_start + k;
+ const scalar_t * self_i = self_start + (l * n + i) * m;
+ const scalar_t * self_j = self_start + (l * n + j) * m;
+ const scalar_t * self_end = self_start + (l + 1) * n * m;
+ scalar_t * res = res_start + start;
const scalar_t * const res_end = res_start + end;
while (res != res_end) {
self_j += m;
if (self_j == self_end) {
self_i += m;
+ if (self_i + m == self_end) {
+ self_i += m;
+ self_end += n * m;
+ }
self_j = self_i + m;
}
}
}
}
+ // This does a backward pass down a Vec column of the input
template <typename F>
- inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
+ inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int count) {
for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
const Vec self_vec_i = Vec::loadu(self_i, count);
const scalar_t * self_j = self_i + m;
scalar_t * res_j = res_i + m;
- for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
+ for (; self_j != self_end; self_j += m, res_j += m, grad_k += 1, dist_k += 1) {
const Vec self_vec_j = Vec::loadu(self_j, count);
Vec res_vec_j = Vec::loadu(res_j, count);
template <typename F>
static void run_backward_parallel(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
- const int64_t n = self.size(0);
- const int64_t m = self.size(1);
- const int64_t gs = grad.stride(0);
+ const int64_t b = self.size(0);
+ const int64_t n = self.size(1);
+ const int64_t m = self.size(2);
+ const int64_t combs = dist.size(1);
+ const int64_t remainder = m % Vec::size();
const Vec pvec(p);
const scalar_t * const grad_start = grad.data<scalar_t>();
// The only way to parallelize and avoid locking requires parallelizing
// over the columns of the input, i.e. we compute the gradient for the
// first section of each vector independentaly of the second section, etc.
- at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t l, int64_t end) {
- const scalar_t * self_l = self_start + l * Vec::size();
- scalar_t * res_l = res_start + l * Vec::size();
+ int64_t mv = (m + Vec::size() - 1) / Vec::size(); // number of Vecs in a row rounded up
+ at::parallel_for(0, b * mv, internal::GRAIN_SIZE / (8 * n * n), [=, &pvec](int64_t start, int64_t end) {
+ const int64_t l = start / mv;
+ const int64_t v = start % mv;
+
+ const scalar_t * self_l = self_start + l * n * m;
+ const scalar_t * self_v = self_l + v * Vec::size();
+
+ const scalar_t * dist_l = dist_start + l * combs;
+ const scalar_t * grad_l = grad_start + l * combs;
+
+ scalar_t * res_l = res_start + l * n * m;
+ scalar_t * res_v = res_l + v * Vec::size();
+
+ while (start != end) {
+ backward_down_column<F>(self_v, res_v, grad_l, dist_l, pvec, n, m, std::min(int(m - (self_v - self_l)), Vec::size()));
- for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
- backward_down_column<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
+ start += 1;
+ self_v += Vec::size();
+ res_v += Vec::size();
+ if (self_v == self_l + mv * Vec::size()) {
+ // Reached the end of the row
+ self_l += n * m;
+ self_v = self_l;
+
+ res_l += n * m;
+ res_v = res_l;
+
+ dist_l += combs;
+ grad_l += combs;
+ }
}
});
- const int64_t remainder = m % Vec::size();
- if (remainder) {
- backward_down_column<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, pvec, n, m, gs, remainder);
- }
}
// Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
};
-void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
+static void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] {
PDist<scalar_t>::apply(result, self, p);
});
template <typename scalar_t, typename F>
__global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p) {
- const int k = blockIdx.x;
- const int stride = blockDim.x;
+ const int l = blockIdx.x;
+ const int k = blockIdx.y;
+ const int stride = blockDim.y;
float n2 = n - .5;
// The -1 accounts for floating point truncation issues
int64_t i = static_cast<int64_t>((n2 - device_sqrt<scalar_t>(n2 * n2 - 2 * k - 1)));
int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
- const scalar_t * const start = self + i * m;
+ const scalar_t * const start = self + (l * n + i) * m;
const scalar_t * const end = start + m;
- const scalar_t * a = start + threadIdx.x;
- const scalar_t * b = self + j * m + threadIdx.x;
+ const scalar_t * a = start + threadIdx.y;
+ const scalar_t * b = self + (l * n + j) * m + threadIdx.y;
scalar_t agg = 0.0;
for (; a < end; a += stride, b += stride) {
F::inc(agg, std::abs(*a - *b), p);
// This shared memory is significantly larger than necessary, but the
// assumption is that it's not a bottleneck, and this is simple
__shared__ scalar_t shared[forward_threads];
- int lane = threadIdx.x % warpSize;
- int warp_id = threadIdx.x / warpSize;
+ int lane = threadIdx.y % warpSize;
+ int warp_id = threadIdx.y / warpSize;
if (lane == 0) {
shared[warp_id] = agg;
}
__syncthreads();
- agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
+ agg = (threadIdx.y < blockDim.y / warpSize) ? shared[lane] : 0.0;
if (warp_id == 0) {
// Only reduce theads with nonzero data
- for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
+ for (int offset = blockDim.y / warpSize / 2; offset > 0; offset /= 2) {
F::agg(agg, WARP_SHFL_DOWN(agg, offset));
}
}
- if (threadIdx.x == 0) {
- result[k] = F::finish(agg, p);
+ if (threadIdx.y == 0) {
+ const int64_t combs = n * (n - 1) / 2;
+ result[l * combs + k] = F::finish(agg, p);
}
}
template <typename scalar_t, typename F>
-__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p) {
- const int k = blockIdx.y * blockDim.y + threadIdx.y;
- const int init = blockIdx.x * blockDim.x + threadIdx.x;
- const int stride = blockDim.x * gridDim.x;
+__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, const int64_t n, const int64_t m, const scalar_t p) {
+ const int l = blockIdx.x;
+ const int k = blockIdx.z * blockDim.z + threadIdx.z;
+ const int init = blockIdx.y * blockDim.y + threadIdx.y;
+ const int stride = blockDim.y * gridDim.y;
+ const int combs = n * (n - 1) / 2;
if (k >= combs) {
return;
int64_t ib = j - i - 1;
int64_t jb = n - 2 - i;
- const scalar_t grad_k = grad[k * gs];
- const scalar_t dist_k = dist[k];
+ const scalar_t grad_k = grad[l * combs + k];
+ const scalar_t dist_k = dist[l * combs + k];
- const scalar_t * const start = self + i * m;
+ const scalar_t * const start = self + (l * n + i) * m;
const scalar_t * const end = start + m;
const scalar_t * self_i = start + init;
- const scalar_t * self_j = self + j * m + init;
- scalar_t * buff_i = buffer + (ib * n + i) * m + init;
- scalar_t * buff_j = buffer + (jb * n + j) * m + init;
+ const scalar_t * self_j = self + (l * n + j) * m + init;
+ scalar_t * buff_i = buffer + ((l * (n - 1) + ib) * n + i) * m + init;
+ scalar_t * buff_j = buffer + ((l * (n - 1) + jb) * n + j) * m + init;
for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) {
const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
*buff_i = res;
}
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
- const dim3 grid(result.numel());
- const dim3 block(forward_threads);
- int64_t n = self.size(0);
- int64_t m = self.size(1);
+ int64_t b = self.size(0);
+ int64_t n = self.size(1);
+ int64_t m = self.size(2);
+ const dim3 grid(b, result.size(1));
+ const dim3 block(1, forward_threads);
AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] {
if (p == 0.0) {
return;
}
- const int64_t n = result.size(0);
- int64_t m = self.size(1);
- const int block_x = 64;
- const int block_y = 4;
- const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
- const int grid_y = (dist.numel() + block_y - 1) / block_y;
- const dim3 grid(grid_x, grid_y);
- const dim3 block(block_x, block_y);
-
- Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
+ const int64_t b = self.size(0);
+ const int64_t n = self.size(1);
+ const int64_t m = self.size(2);
+ const int block_y = 64;
+ const int block_z = 4;
+ const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
+ const int grid_z = (dist.numel() + block_z - 1) / block_z;
+ const dim3 grid(b, grid_y, grid_z);
+ const dim3 block(1, block_y, block_z);
+
+ Tensor buffer = at::empty({b, n - 1, n, m}, result.options());
AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] {
if (p == 1.0) {
- pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+ pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
} else if (p < 2.0) {
- pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+ pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
} else if (p == 2.0) {
- pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+ pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
} else if (std::isinf(p)) {
- pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+ pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
} else {
- pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p);
+ pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), n, m, p);
}
});
- at::sum_out(result, buffer, 0);
+ at::sum_out(result, buffer, 1);
}
} // anonymous namespace
return torch.stack(all_matrices).reshape(*(batches + (l, l)))
+def brute_pdist(inp, p=2):
+ """Computes the same as torch.pdist using primitives"""
+ n = inp.shape[-2]
+ k = n * (n - 1) // 2
+ if k == 0:
+ # torch complains about empty indices
+ return torch.empty(inp.shape[:-2] + (0,), device=inp.device)
+ square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
+ unroll = square.view(square.shape[:-2] + (n * n,))
+ inds = torch.ones(k, dtype=torch.int)
+ inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
+ return unroll[..., inds.cumsum(0)]
+
+
def do_test_dtypes(self, dtypes, layout, device):
for dtype in dtypes:
if dtype != torch.float16:
@skipIfRocm
def test_pdist(self):
- for device, trans in itertools.product(device_(), [False, True]):
- inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
+ for device, trans, shape in itertools.product(device_(), [False, True], [(4, 5), (2, 3, 4)]):
+ inp = torch.randn(shape, dtype=torch.double, device=device, requires_grad=True)
if trans:
- inp = inp.transpose(0, 1)
+ inp = inp.transpose(-2, -1)
for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
- IS_SANDCASTLE, load_tests
+ IS_SANDCASTLE, load_tests, brute_pdist
from multiprocessing.reduction import ForkingPickler
# load_tests from common_utils is used to automatically filter tests for
x = torch.randn(shape, device=device)
self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
- @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
- def test_pdist_scipy(self):
- from scipy.spatial.distance import pdist
+ def test_pdist_norm(self):
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
- for shape in [(4, 5), (3, 2), (2, 1)]:
+ for shape in [(4, 5), (3, 2), (2, 1), (2, 3, 4)]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
for trans in [False, True]:
x = torch.randn(shape, device=device)
if trans:
- x.transpose_(0, 1)
+ x.transpose_(-2, -1)
actual = torch.pdist(x, p=p)
- # pdist doesn't handle 0 or inf norm properly
- if p == 0:
- expected = pdist(x.cpu(), 'hamming') * x.shape[1]
- elif p == float('inf'):
- expected = pdist(x.cpu(), lambda a, b: np.abs(a - b).max())
- else:
- expected = pdist(x.cpu(), 'minkowski', p=p)
+ expected = brute_pdist(x, p=p)
self.assertEqual(expected.shape, actual.shape)
- self.assertTrue(np.allclose(expected, actual.cpu().numpy()))
+ self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_logsumexp(self):
pdist(input, p=2) -> Tensor
Computes the p-norm distance between every pair of row vectors in the input.
-This is identical to the upper triangular portion, excluding the diagonal, of
-`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
-if the rows are contiguous.
-
-If input has shape :math:`N \times M` then the output will have shape
-:math:`\frac{1}{2} N (N - 1)`.
-
-This function is equivalent to `scipy.spatial.distance.pdist(input,
-'minkowski', p=p)` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
-equivalent to `scipy.spatial.distance.pdist(input, 'hamming') * M`.
-When :math:`p = \infty`, the closest scipy function is
-`scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())`.
+If the input tensor has shape ... B x N x M, the result will be tensor with
+shape ... B x N * (N - 1) / 2. Every dimension prior to the last two are
+treated as independent batches of N vectors, each with M elements. If we use
+ordinal numbers to refer to the vectors in a batch, the last dimension of the
+output will be ordered as:
+
+```
+[dist(1, 2), dist(1, 2), ..., dist(1, N), dist(2, 3), ..., dist(N-1, N)]
+```
+
+The square verion of pdist that has redundant distances and the diagonal can be
+be computed with
+`torch.norm(inpup[..., None, :] - input[..., None, :, :], p=p, dim=-1)`.
+Appropriately selecting and flattening the upper triangular of the
+last two dimensions will produce identical results as pdist.
+
+This function is similar to
+`scipy.spatial.distance.pdist(input, 'minkowski', p=p)`
+if :math:`p \in (0, \infty)`.
Args:
- input: input tensor of shape :math:`N \times M`.
+ input: input tensor of shape :math:`... \times N \times M`.
p: p value for the p-norm distance to calculate between each vector pair
:math:`\in [0, \infty]`.
""")