_(aten, ormqr) \
_(aten, pairwise_distance) \
_(aten, pdist) \
+_(aten, cdist) \
_(aten, permute) \
_(aten, pin_memory) \
_(aten, pinverse) \
_(attr, padding_value) \
_(attr, params) \
_(attr, pdist) \
+_(attr, cdist) \
_(attr, periodic) \
_(attr, pivot) \
_(attr, pivots) \
DEFINE_DISPATCH(pdist_forward_stub);
DEFINE_DISPATCH(pdist_backward_stub);
+DEFINE_DISPATCH(cdist_stub);
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
  return at::norm(x1 - x2 + eps, p, 1, keepdim);
}

Tensor pdist(const Tensor& self, const double p) {
  // ... (input checks elided)
  return at::_pdist_forward(self.contiguous(), p);
}
+Tensor cdist(const Tensor& x1, const Tensor& x2, const double p) {
+ AT_CHECK(x1.dim() == 2, "cdist only supports 2D tensors, X1 got: ", x1.dim(), "D");
+ AT_CHECK(at::isFloatingType(x1.type().scalarType()), "cdist only supports floating-point dtypes, X1 got: ", x1.type().scalarType());
+ auto device1 = x1.type().device_type();
+ AT_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
+ AT_CHECK(x2.dim() == 2, "cdist only supports 2D tensors, X2 got: ", x2.dim(), "D");
+ AT_CHECK(at::isFloatingType(x2.type().scalarType()), "cdist only supports floating-point dtypes, X2 got: ", x2.type().scalarType());
+ auto device2 = x2.type().device_type();
+ AT_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
+ AT_CHECK(p >= 0, "cdist only supports non-negative p values");
+ AT_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+ AT_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+ int64_t c1 = x1.size(-1);
+ int64_t c2 = x2.size(-1);
+ AT_CHECK(c1 == c2, "X1 and X2 must have the same number of columns. X1: ", c1, " X2: ", c2);
+
+ int64_t r1 = x1.size(-2);
+ int64_t r2 = x2.size(-2);
+ Tensor result = at::empty({r1, r2}, x1.options());
+ if (r1 > 0 && r2 > 0) {
+ if (c1 == 0) {
+ result.fill_(0);
+ } else {
+ cdist_stub(device1, result, x1.contiguous(), x2.contiguous(), p);
+ }
+ }
+ return result;
+}
+
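For reference, the new operator returns an r1 x r2 matrix whose (i, j) entry is the p-norm distance between row i of X1 and row j of X2. A minimal Python sketch of the intended semantics, mirroring the brute_cdist test helper added below (illustrative only, not the C++ implementation):

# Usage sketch for the new op as exposed in Python (torch.cdist, per the tests below).
import torch

x1 = torch.randn(4, 3)   # r1 = 4 rows, m = 3 columns
x2 = torch.randn(5, 3)   # r2 = 5 rows, same number of columns
d = torch.cdist(x1, x2, p=2)                                    # shape (4, 5)
ref = torch.norm(x1[:, None, :] - x2[None, :, :], p=2, dim=-1)  # brute-force reference
assert torch.allclose(d, ref)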
Tensor _pdist_forward(const Tensor& self, const double p) {
AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
auto device = self.type().device_type();
using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
+DECLARE_DISPATCH(cdist_fn, cdist_stub);
}} // namespace at::native
}
template <typename F>
+ static void run_cdist_parallel(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
+ const scalar_t * const t1_start = t1.data<scalar_t>();
+ const scalar_t * const t2_start = t2.data<scalar_t>();
+ int64_t r1 = t1.size(-2);
+ int64_t r2 = t2.size(-2);
+ int64_t m = t1.size(-1);
+
+ scalar_t * const res_start = result.data<scalar_t>();
+ int64_t total = r1 * r2;
+
+ parallel_for(0, total, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
+ const Vec pvec(p);
+ scalar_t * res = res_start + start;
+ const scalar_t * const res_end = res_start + end;
+
+ int64_t k = start;
+ while (res != res_end) {
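+ // Map the flat output index k to the pair (row i of t1, row j of t2).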
+ int64_t i = k / r2;
+ int64_t j = k % r2;
+ const scalar_t * self_i = t1_start + i * m;
+ const scalar_t * self_j = t2_start + j * m;
+
+ *res = F::finish(vec256::map2_reduce_all<scalar_t>(
+ [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
+ F::red, self_i, self_j, m), p);
+
+ res += 1;
+ k++;
+ }
+ });
+ }
+
+ static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
+ if (p == 0.0) {
+ run_cdist_parallel<zdist_calc>(result, x1, x2, p);
+ } else if (p == 1.0) {
+ run_cdist_parallel<odist_calc>(result, x1, x2, p);
+ } else if (p == 2.0) {
+ run_cdist_parallel<tdist_calc>(result, x1, x2, p);
+ } else if (std::isinf(p)) {
+ run_cdist_parallel<idist_calc>(result, x1, x2, p);
+ } else {
+ run_cdist_parallel<pdist_calc>(result, x1, x2, p);
+ }
+ }
+
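The CPU path dispatches to a specialized functor per value of p (zdist_calc, odist_calc, tdist_calc, idist_calc, pdist_calc above, shared with pdist). A minimal Python sketch of the reference formulas each case reduces to, assuming the standard p-norm definitions (illustrative only; the C++ functors operate on Vec256 lanes):

# Reference formulas for the per-p specializations (sketch, not the C++ code).
import torch

def row_pair_dist(a, b, p):
    d = (a - b).abs()
    if p == 0:                     # zdist_calc: count of differing coordinates
        return (d != 0).sum(dim=-1).to(a.dtype)
    if p == float('inf'):          # idist_calc: largest coordinate difference
        return d.max(dim=-1)[0]
    return d.pow(p).sum(dim=-1).pow(1.0 / p)   # covers p = 1, p = 2, and the general case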
+ // This does a backward pass down a Vec column of the input
+ template <typename F>
inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
});
}
+static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
+ AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] {
+ PDist<scalar_t>::apply_cdist(result, x1, x2, p);
+ });
+}
+
} // anonymous namespace
REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
}} // namespace at::native
};
template <typename scalar_t, typename F>
+__device__ static inline scalar_t reduce_agg(scalar_t agg) {
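+ // Warp-level reduction: fold lanes together with shuffle-down.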
+ for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+ F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+ }
+
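+ // Block-level reduction through shared memory. The buffer is larger than strictly necessary, but keeps the code simple.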
+ __shared__ scalar_t shared[forward_threads];
+ int lane = threadIdx.x % warpSize;
+ int warp_id = threadIdx.x / warpSize;
+ if (lane == 0) {
+ shared[warp_id] = agg;
+ }
+
+ __syncthreads();
+ agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
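+ // Only the first warp folds the per-warp partial results.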
+ if (warp_id == 0) {
+ for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
+ F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+ }
+ }
+ return agg;
+}
+
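reduce_agg folds each thread's partial aggregate down to thread 0 in two stages: a warp-level shuffle reduction, then a block-level reduction through shared memory. A minimal Python sketch of the same tree-reduction pattern, shown for a plain sum (the CUDA code applies F::agg, which is addition for finite p and max for p = inf):

# Tree reduction over a power-of-two number of lanes, as in a warp.
def tree_reduce(vals):
    vals = list(vals)
    n = len(vals)
    assert n > 0 and n & (n - 1) == 0          # power-of-two lane count
    offset = n // 2
    while offset > 0:
        for lane in range(offset):
            vals[lane] += vals[lane + offset]  # analog of agg += WARP_SHFL_DOWN(agg, offset)
        offset //= 2
    return vals[0]                             # the result ends up in lane 0

assert tree_reduce([1, 2, 3, 4]) == 10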
+template <typename scalar_t, typename F>
__global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
const double n2, const double n2_squared_minus_1) {
const int k = blockIdx.x;
F::inc(agg, std::abs(*a - *b), p);
}
- // Reduce warps
- for (int offset = warpSize / 2; offset > 0; offset /= 2) {
- F::agg(agg, WARP_SHFL_DOWN(agg, offset));
- }
-
- // Reduce block
- // This shared memory is significantly larger than necessary, but the
- // assumption is that it's not a bottleneck, and this is simple
- __shared__ scalar_t shared[forward_threads];
- int lane = threadIdx.x % warpSize;
- int warp_id = threadIdx.x / warpSize;
- if (lane == 0) {
- shared[warp_id] = agg;
- }
- __syncthreads();
- agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
- if (warp_id == 0) {
- // Only reduce theads with nonzero data
- for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
- F::agg(agg, WARP_SHFL_DOWN(agg, offset));
- }
- }
+ agg = reduce_agg<scalar_t, F>(agg);
if (threadIdx.x == 0) {
result[k] = F::finish(agg, p);
}
}
}
+template <typename scalar_t, typename F>
+__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m) {
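+ // One block per output cell: block k computes the distance between row i of x1 and row j of x2.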
+ const int k = blockIdx.x;
+ const int64_t i = k / r2;
+ const int64_t j = k % r2;
+ const int stride = blockDim.x;
+
+ const scalar_t * const start = x1 + i * m;
+ const scalar_t * const end = start + m;
+ const scalar_t * a = start + threadIdx.x;
+ const scalar_t * b = x2 + j * m + threadIdx.x;
+
+ scalar_t agg = 0.0;
+ for (; a < end; a += stride, b += stride) {
+ F::inc(agg, std::abs(*a - *b), p);
+ }
+ agg = reduce_agg<scalar_t, F>(agg);
+ if (threadIdx.x == 0) {
+ result[k] = F::finish(agg, p);
+ }
+}
+
+void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
+ int64_t r1 = x1.size(-2);
+ int64_t r2 = x2.size(-2);
+ int64_t m = x1.size(-1);
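+ // Launch one block per (row of x1, row of x2) pair; each block's threads stride over the m columns.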
+ const dim3 grid(r1*r2);
+ const dim3 block(forward_threads);
+
+ AT_DISPATCH_FLOATING_TYPES(x1.type(), "cdist_cuda", [&] {
+ if (p == 0.0) {
+ cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+ } else if (p == 1.0) {
+ cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+ } else if (p == 2.0) {
+ cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+ } else if (std::isinf(p)) {
+ cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+ } else {
+ cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+ }
+ });
+}
+
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
const dim3 grid(result.numel());
const dim3 block(forward_threads);
REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
}} // at::native
- func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
matches_jit_signature: True
+- func: cdist(Tensor x1, Tensor x2, float p=2) -> Tensor
+
- func: pdist(Tensor self, float p=2) -> Tensor
matches_jit_signature: True
inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
return unroll[..., inds.cumsum(0)]
+def brute_cdist(x, y, p=2):
+ r1 = x.shape[-2]
+ r2 = y.shape[-2]
+ if r1 == 0 or r2 == 0:
+ return torch.empty(r1, r2, device=x.device)
+ return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
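brute_cdist relies on broadcasting: x gains a singleton axis just before its column dimension and y gains one just before its row dimension, so the difference broadcasts to shape (r1, r2, m) and the p-norm over the last dimension yields the distance matrix. A quick shape check (illustrative only):

# Shape check for the broadcasting used in brute_cdist.
import torch

x = torch.randn(4, 3)                                # r1 = 4, m = 3
y = torch.randn(5, 3)                                # r2 = 5
diff = x[..., None, :] - y[..., None, :, :]
assert diff.shape == (4, 5, 3)                       # broadcast to (r1, r2, m)
assert torch.norm(diff, p=2, dim=-1).shape == (4, 5)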
def do_test_dtypes(self, dtypes, layout, device):
for dtype in dtypes:
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
- IS_SANDCASTLE, load_tests, brute_pdist
+ IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
from multiprocessing.reduction import ForkingPickler
# load_tests from common_utils is used to automatically filter tests for
for dtype in [torch.float32, torch.float64]:
test_pdist_single((1000, 2), device, 2, dtype, False)
+
+ def test_cdist_empty(self):
+ devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+ for device in devices:
+ x = torch.randn((0, 5), device=device)
+ y = torch.randn((4, 5), device=device)
+ self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
+
+ x = torch.randn((2, 5), device=device)
+ y = torch.randn((0, 5), device=device)
+ self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+ x = torch.randn((2, 0), device=device)
+ y = torch.randn((3, 0), device=device)
+ self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
+
+ x = torch.randn((2, 0), device=device)
+ y = torch.randn((0, 0), device=device)
+ self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+ def test_cdist_norm(self):
+ devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+ for device in devices:
+ for r1 in [3, 4, 5, 6]:
+ for m in [2, 3, 4, 10]:
+ for r2 in [4, 6, 7, 8]:
+ for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+ x = torch.randn(r1, m, device=device)
+ y = torch.randn(r2, m, device=device)
+ actual = torch.cdist(x, y, p=p)
+ expected = brute_cdist(x, y, p=p)
+ self.assertTrue(torch.allclose(expected, actual))
+
+ def test_cdist_large(self):
+ devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+ for device in devices:
+ x = torch.randn(1000, 10, device=device)
+ y = torch.randn(1000, 10, device=device)
+ actual = torch.cdist(x, y, p=2)
+ expected = brute_cdist(x, y, p=2)
+ self.assertTrue(torch.allclose(expected, actual))
+
+ def test_cdist_non_contiguous(self):
+ devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+ for device in devices:
+ x = torch.randn(5, 7, device=device).t()
+ y = torch.randn(5, 3, device=device).t()
+ actual = torch.cdist(x, y, p=2)
+ expected = brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertTrue(torch.allclose(expected, actual))
+
+ x = torch.randn(7, 5, device=device)
+ y = torch.randn(5, 3, device=device).t()
+ actual = torch.cdist(x, y, p=2)
+ expected = brute_cdist(x, y, p=2)
+ self.assertTrue(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertTrue(torch.allclose(expected, actual))
+
+ x = torch.randn(5, 7, device=device).t()
+ y = torch.randn(3, 5, device=device)
+ actual = torch.cdist(x, y, p=2)
+ expected = brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertTrue(y.is_contiguous())
+ self.assertTrue(torch.allclose(expected, actual))
+
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_logsumexp(self):
from scipy.special import logsumexp