From ceece5dd0f112f202901924d08419cdcb3d9c15a Mon Sep 17 00:00:00 2001
From: Igor Fedan
Date: Mon, 28 Jan 2019 09:14:07 -0800
Subject: [PATCH] CPU implementation of torch.cdist (#16168)

Summary:
cdist computes the p-norm distance between every pair of row vectors taken from two collections of observations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16168

Differential Revision: D13739147

Pulled By: ifedan

fbshipit-source-id: 9419c2c166891ac7db40672c72f17848f0b446f9
---
 aten/src/ATen/core/aten_interned_strings.h     |  2 +
 aten/src/ATen/native/Distance.cpp              | 30 +++++++++
 aten/src/ATen/native/Distance.h                |  2 +
 aten/src/ATen/native/cpu/DistanceOpsKernel.cpp | 57 ++++++++++++++++
 aten/src/ATen/native/cuda/DistanceKernel.cu    | 91 +++++++++++++++++++-------
 aten/src/ATen/native/native_functions.yaml     |  2 +
 test/common_utils.py                           |  7 ++
 test/test_torch.py                             | 71 +++++++++++++++++++-
 8 files changed, 239 insertions(+), 23 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 514323c..1046121 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -516,6 +516,7 @@ _(aten, orgqr) \
 _(aten, ormqr) \
 _(aten, pairwise_distance) \
 _(aten, pdist) \
+_(aten, cdist) \
 _(aten, permute) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
@@ -905,6 +906,7 @@ _(attr, padding_mode) \
 _(attr, padding_value) \
 _(attr, params) \
 _(attr, pdist) \
+_(attr, cdist) \
 _(attr, periodic) \
 _(attr, pivot) \
 _(attr, pivots) \
diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp
index b5bf0a8..9146b4d 100644
--- a/aten/src/ATen/native/Distance.cpp
+++ b/aten/src/ATen/native/Distance.cpp
@@ -8,6 +8,7 @@ namespace at { namespace native {
 
 DEFINE_DISPATCH(pdist_forward_stub);
 DEFINE_DISPATCH(pdist_backward_stub);
+DEFINE_DISPATCH(cdist_stub);
 
 Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
   return at::norm(x1 - x2 + eps, p, 1, keepdim);
@@ -22,6 +23,35 @@ Tensor pdist(const Tensor& self, const double p) {
   return at::_pdist_forward(self.contiguous(), p);
 }
 
+Tensor cdist(const Tensor& x1, const Tensor& x2, const double p) {
+  AT_CHECK(x1.dim() == 2, "cdist only supports 2D tensors, X1 got: ", x1.dim(), "D");
+  AT_CHECK(at::isFloatingType(x1.type().scalarType()), "cdist only supports floating-point dtypes, X1 got: ", x1.type().scalarType());
+  auto device1 = x1.type().device_type();
+  AT_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
+  AT_CHECK(x2.dim() == 2, "cdist only supports 2D tensors, X2 got: ", x2.dim(), "D");
+  AT_CHECK(at::isFloatingType(x2.type().scalarType()), "cdist only supports floating-point dtypes, X2 got: ", x2.type().scalarType());
+  auto device2 = x2.type().device_type();
+  AT_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
+  AT_CHECK(p >= 0, "cdist only supports non-negative p values");
+  AT_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+  AT_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+  int64_t c1 = x1.size(-1);
+  int64_t c2 = x2.size(-1);
+  AT_CHECK(c1 == c2, "X1 and X2 must have the same number of columns. X1: ", c1, " X2: ", c2);
+
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+  Tensor result = at::empty({r1, r2}, x1.options());
+  if (r1 > 0 && r2 > 0) {
+    if (c1 == 0) {
+      result.fill_(0);
+    } else {
+      cdist_stub(device1, result, x1.contiguous(), x2.contiguous(), p);
+    }
+  }
+  return result;
+}
+
 Tensor _pdist_forward(const Tensor& self, const double p) {
   AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
   auto device = self.type().device_type();
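For reference, the cdist entry point above amounts to the following pure-PyTorch sketch: validate the inputs, then fill an r1 x r2 matrix whose (i, j) entry is the p-norm distance between row i of X1 and row j of X2. This is illustrative only; the reference_cdist name and the Python formulation are not part of the patch.

import torch

def reference_cdist(x1, x2, p=2.0):
    # Same preconditions the C++ wrapper enforces with AT_CHECK.
    assert x1.dim() == 2 and x2.dim() == 2, "cdist only supports 2D tensors"
    assert x1.size(1) == x2.size(1), "X1 and X2 must have the same number of columns"
    assert p >= 0, "cdist only supports non-negative p values"
    r1, r2 = x1.size(0), x2.size(0)
    result = torch.empty(r1, r2, dtype=x1.dtype, device=x1.device)
    if r1 == 0 or r2 == 0:
        return result            # nothing to compute for empty row sets
    if x1.size(1) == 0:
        return result.fill_(0)   # zero columns: every distance is 0
    for i in range(r1):
        for j in range(r2):
            diff = (x1[i] - x2[j]).abs()
            if p == 0:
                result[i, j] = (diff != 0).sum()   # count of differing coordinates
            elif p == float('inf'):
                result[i, j] = diff.max()          # Chebyshev distance
            else:
                result[i, j] = diff.pow(p).sum().pow(1.0 / p)
    return result

The dedicated CPU and CUDA kernels below compute the same values without Python loops or per-pair temporaries.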
diff --git a/aten/src/ATen/native/Distance.h b/aten/src/ATen/native/Distance.h
index 87cdc62..9c871ff 100644
--- a/aten/src/ATen/native/Distance.h
+++ b/aten/src/ATen/native/Distance.h
@@ -7,8 +7,10 @@ namespace at { namespace native {
 
 using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
 using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
 
 DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
 DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
+DECLARE_DISPATCH(cdist_fn, cdist_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
index 8b06938..aac1109 100644
--- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -149,6 +149,54 @@ struct PDist {
   }
 
   template <typename F>
+  static void run_cdist_parallel(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
+    const scalar_t * const t1_start = t1.data<scalar_t>();
+    const scalar_t * const t2_start = t2.data<scalar_t>();
+    int64_t r1 = t1.size(-2);
+    int64_t r2 = t2.size(-2);
+    int64_t m = t1.size(-1);
+
+    scalar_t * const res_start = result.data<scalar_t>();
+    int64_t total = r1 * r2;
+
+    parallel_for(0, total, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
+      const Vec pvec(p);
+      scalar_t * res = res_start + start;
+      const scalar_t * const res_end = res_start + end;
+
+      int64_t k = start;
+      while (res != res_end) {
+        int64_t i = k / r2;
+        int64_t j = k % r2;
+        const scalar_t * self_i = t1_start + i * m;
+        const scalar_t * self_j = t2_start + j * m;
+
+        *res = F::finish(vec256::map2_reduce_all<scalar_t>(
+          [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
+          F::red, self_i, self_j, m), p);
+
+        res += 1;
+        k++;
+      }
+    });
+  }
+
+  static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
+    if (p == 0.0) {
+      run_cdist_parallel<zdist_calc<Vec>>(result, x1, x2, p);
+    } else if (p == 1.0) {
+      run_cdist_parallel<odist_calc<Vec>>(result, x1, x2, p);
+    } else if (p == 2.0) {
+      run_cdist_parallel<tdist_calc<Vec>>(result, x1, x2, p);
+    } else if (std::isinf(p)) {
+      run_cdist_parallel<idist_calc<Vec>>(result, x1, x2, p);
+    } else {
+      run_cdist_parallel<pdist_calc<Vec>>(result, x1, x2, p);
+    }
+  }
+
+  // This does a backward pass down a Vec column of the input
+  template <typename F>
   inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
     for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
@@ -233,9 +281,18 @@ static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const
   });
 }
 
+static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
+  AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] {
+    PDist<scalar_t>::apply_cdist(result, x1, x2, p);
+  });
+}
+
+
+
 } // anonymous namespace
 
 REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
 REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
 
 }} // namespace at::native
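The CPU kernel above parallelizes over a flat index k covering all r1 * r2 row pairs and recovers (i, j) as k / r2 and k % r2, so a worker chunk can start at any offset into the output without per-row bookkeeping. The sketch below mimics that flat-index chunking in plain Python; it is illustrative only, the chunk size is invented, and the p == 0 case is omitted for brevity.

import torch

def cdist_by_flat_chunks(x1, x2, p=2.0, chunk=4):
    r1, r2, m = x1.size(0), x2.size(0), x1.size(1)
    out = torch.empty(r1, r2, dtype=x1.dtype)
    total = r1 * r2
    for start in range(0, total, chunk):      # parallel_for hands ranges like this to worker threads
        for k in range(start, min(start + chunk, total)):
            i, j = k // r2, k % r2            # flat index -> (row of x1, row of x2)
            diff = (x1[i] - x2[j]).abs()
            out[i, j] = diff.max() if p == float('inf') else diff.pow(p).sum().pow(1.0 / p)
    return out

# Cross-check against a brute-force broadcast version:
# x1, x2 = torch.randn(5, 3, dtype=torch.float64), torch.randn(7, 3, dtype=torch.float64)
# assert torch.allclose(cdist_by_flat_chunks(x1, x2), torch.norm(x1[:, None, :] - x2[None, :, :], dim=-1))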
diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu
index cd0a3ad..37d5c68 100644
--- a/aten/src/ATen/native/cuda/DistanceKernel.cu
+++ b/aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -79,6 +79,29 @@ struct dists {
 };
 
 template <typename scalar_t, typename F>
+__device__ static inline scalar_t reduce_agg(scalar_t agg) {
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+  }
+
+  __shared__ scalar_t shared[forward_threads];
+  int lane = threadIdx.x % warpSize;
+  int warp_id = threadIdx.x / warpSize;
+  if (lane == 0) {
+    shared[warp_id] = agg;
+  }
+
+  __syncthreads();
+  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
+  if (warp_id == 0) {
+    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
+      F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+    }
+  }
+  return agg;
+}
+
+template <typename scalar_t, typename F>
 __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
                                               const double n2, const double n2_squared_minus_1) {
   const int k = blockIdx.x;
@@ -97,28 +120,7 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t
     F::inc(agg, std::abs(*a - *b), p);
   }
 
-  // Reduce warps
-  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
-    F::agg(agg, WARP_SHFL_DOWN(agg, offset));
-  }
-
-  // Reduce block
-  // This shared memory is significantly larger than necessary, but the
-  // assumption is that it's not a bottleneck, and this is simple
-  __shared__ scalar_t shared[forward_threads];
-  int lane = threadIdx.x % warpSize;
-  int warp_id = threadIdx.x / warpSize;
-  if (lane == 0) {
-    shared[warp_id] = agg;
-  }
-  __syncthreads();
-  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
-  if (warp_id == 0) {
-    // Only reduce theads with nonzero data
-    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
-      F::agg(agg, WARP_SHFL_DOWN(agg, offset));
-    }
-  }
+  agg = reduce_agg<scalar_t, F>(agg);
   if (threadIdx.x == 0) {
     result[k] = F::finish(agg, p);
   }
@@ -157,6 +159,50 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
   }
 }
 
+template <typename scalar_t, typename F>
+__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m) {
+  const int k = blockIdx.x;
+  const int64_t i = k / r2;
+  const int64_t j = k % r2;
+  const int stride = blockDim.x;
+
+  const scalar_t * const start = x1 + i * m;
+  const scalar_t * const end = start + m;
+  const scalar_t * a = start + threadIdx.x;
+  const scalar_t * b = x2 + j * m + threadIdx.x;
+
+  scalar_t agg = 0.0;
+  for (; a < end; a += stride, b += stride) {
+    F::inc(agg, std::abs(*a - *b), p);
+  }
+  agg = reduce_agg<scalar_t, F>(agg);
+  if (threadIdx.x == 0) {
+    result[k] = F::finish(agg, p);
+  }
+}
+
+void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+  int64_t m = x1.size(-1);
+  const dim3 grid(r1*r2);
+  const dim3 block(forward_threads);
+
+  AT_DISPATCH_FLOATING_TYPES(x1.type(), "cdist_cuda", [&] {
+    if (p == 0.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (p == 1.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (p == 2.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (std::isinf(p)) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    }
+  });
+}
+
 void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
   const dim3 grid(result.numel());
   const dim3 block(forward_threads);
@@ -228,5 +274,6 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
 
 REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
 REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
 
 }} // at::native
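On the CUDA side, each block owns one (i, j) pair: its threads stride over the m columns accumulating partial aggregates, and the new reduce_agg helper folds those partials together, first with warp shuffles and then with a shared-memory pass across warps. The Python model below captures the strided accumulation plus the offset-halving tree reduction for the general-p functor; the thread count and the function name are invented for illustration, and it assumes num_threads is a power of two.

def block_distance(row_a, row_b, p=2.0, num_threads=8):
    # Phase 1: strided accumulation, one partial aggregate per simulated thread,
    # mirroring "for (; a < end; a += stride, b += stride) F::inc(...)".
    partial = [0.0] * num_threads
    for t in range(num_threads):
        for idx in range(t, len(row_a), num_threads):
            partial[t] += abs(row_a[idx] - row_b[idx]) ** p
    # Phase 2: offset-halving tree reduction, the role played by WARP_SHFL_DOWN
    # and the shared-memory pass in reduce_agg.
    offset = num_threads // 2
    while offset > 0:
        for t in range(offset):
            partial[t] += partial[t + offset]
        offset //= 2
    return partial[0] ** (1.0 / p)   # F::finish for the general-p case

# block_distance([0.0, 3.0, 0.0], [4.0, 0.0, 0.0], p=2.0) == 5.0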
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cdc8ed9..08c3539 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1484,6 +1484,8 @@
 - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
   matches_jit_signature: True
 
+- func: cdist(Tensor x1, Tensor x2, float p=2) -> Tensor
+
 - func: pdist(Tensor self, float p=2) -> Tensor
   matches_jit_signature: True
 
diff --git a/test/common_utils.py b/test/common_utils.py
index b0071a4..f83f9e4 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -745,6 +745,13 @@ def brute_pdist(inp, p=2):
     inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
     return unroll[..., inds.cumsum(0)]
 
+def brute_cdist(x, y, p=2):
+    r1 = x.shape[-2]
+    r2 = y.shape[-2]
+    if r1 == 0 or r2 == 0:
+        return torch.empty(r1, r2, device=x.device)
+    return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
 
 def do_test_dtypes(self, dtypes, layout, device):
     for dtype in dtypes:
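The brute_cdist helper added to test/common_utils.py relies on broadcasting: with x of shape (r1, m) and y of shape (r2, m), x[..., None, :] has shape (r1, 1, m) and y[..., None, :, :] has shape (1, r2, m), so the subtraction broadcasts to (r1, r2, m) and the norm over dim=-1 collapses it to the (r1, r2) distance matrix. A small shape walk-through, with example sizes chosen here only for illustration:

import torch

x = torch.arange(6, dtype=torch.float64).reshape(3, 2)   # r1=3, m=2
y = torch.zeros(4, 2, dtype=torch.float64)               # r2=4, m=2
diff = x[..., None, :] - y[..., None, :, :]               # (3, 1, 2) - (1, 4, 2)
print(diff.shape)                                         # torch.Size([3, 4, 2])
dist = torch.norm(diff, p=2, dim=-1)
print(dist.shape)                                         # torch.Size([3, 4])
# dist[i, j] is the Euclidean distance between x[i] and y[j];
# e.g. dist[1, 0] == math.hypot(2.0, 3.0)

This materializes an (r1, r2, m) temporary, which is acceptable for a test oracle but is exactly what the dedicated kernels avoid.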
diff --git a/test/test_torch.py b/test/test_torch.py
index 4522317..70ed959 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -27,7 +27,7 @@ from common_methods_invocations import tri_tests_args, run_additional_tri_tests
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
 from multiprocessing.reduction import ForkingPickler
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -1231,6 +1231,75 @@ class _TestTorchMixin(object):
             for dtype in [torch.float32, torch.float64]:
                 test_pdist_single((1000, 2), device, 2, dtype, False)
 
+
+    def test_cdist_empty(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn((0, 5), device=device)
+            y = torch.randn((4, 5), device=device)
+            self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 5), device=device)
+            y = torch.randn((0, 5), device=device)
+            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 0), device=device)
+            y = torch.randn((3, 0), device=device)
+            self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 0), device=device)
+            y = torch.randn((0, 0), device=device)
+            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+    def test_cdist_norm(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            for r1 in [3, 4, 5, 6]:
+                for m in [2, 3, 4, 10]:
+                    for r2 in [4, 6, 7, 8]:
+                        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                            x = torch.randn(r1, m, device=device)
+                            y = torch.randn(r2, m, device=device)
+                            actual = torch.cdist(x, y, p=p)
+                            expected = brute_cdist(x, y, p=p)
+                            self.assertTrue(torch.allclose(expected, actual))
+
+    def test_cdist_large(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn(1000, 10, device=device)
+            y = torch.randn(1000, 10, device=device)
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertTrue(torch.allclose(expected, actual))
+
+    def test_cdist_non_contiguous(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_logsumexp(self):
         from scipy.special import logsumexp
-- 
2.7.4
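As an additional manual sanity check (not part of the patch or its test suite, which compares against the brute_cdist helper above), the operator can also be checked against the SciPy function of the same name; this assumes SciPy is installed and uses its Minkowski and Chebyshev metrics as references.

import numpy as np
import torch
from scipy.spatial.distance import cdist as scipy_cdist

x = torch.randn(8, 4, dtype=torch.float64)
y = torch.randn(6, 4, dtype=torch.float64)

# Finite p: compare against the Minkowski metric.
for p in (1.0, 2.0, 3.0):
    ours = torch.cdist(x, y, p=p).numpy()
    ref = scipy_cdist(x.numpy(), y.numpy(), metric='minkowski', p=p)
    assert np.allclose(ours, ref), p

# p = inf corresponds to the Chebyshev (max-abs) distance.
assert np.allclose(torch.cdist(x, y, p=float('inf')).numpy(),
                   scipy_cdist(x.numpy(), y.numpy(), metric='chebyshev'))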