CPU implementation of torch.cdist (#16168)
authorIgor Fedan <ifedan@fb.com>
Mon, 28 Jan 2019 17:14:07 +0000 (09:14 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 28 Jan 2019 17:16:32 +0000 (09:16 -0800)
Summary:
cdist is used for calculating distances between collections of observations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16168

Differential Revision: D13739147

Pulled By: ifedan

fbshipit-source-id: 9419c2c166891ac7db40672c72f17848f0b446f9

aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/Distance.cpp
aten/src/ATen/native/Distance.h
aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
aten/src/ATen/native/cuda/DistanceKernel.cu
aten/src/ATen/native/native_functions.yaml
test/common_utils.py
test/test_torch.py

index 514323c..1046121 100644 (file)
@@ -516,6 +516,7 @@ _(aten, orgqr) \
 _(aten, ormqr) \
 _(aten, pairwise_distance) \
 _(aten, pdist) \
+_(aten, cdist) \
 _(aten, permute) \
 _(aten, pin_memory) \
 _(aten, pinverse) \
@@ -905,6 +906,7 @@ _(attr, padding_mode) \
 _(attr, padding_value) \
 _(attr, params) \
 _(attr, pdist) \
+_(attr, cdist) \
 _(attr, periodic) \
 _(attr, pivot) \
 _(attr, pivots) \
index b5bf0a8..9146b4d 100644 (file)
@@ -8,6 +8,7 @@ namespace at { namespace native {
 
 DEFINE_DISPATCH(pdist_forward_stub);
 DEFINE_DISPATCH(pdist_backward_stub);
+DEFINE_DISPATCH(cdist_stub);
 
 Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
   return at::norm(x1 - x2 + eps, p, 1, keepdim);
@@ -22,6 +23,35 @@ Tensor pdist(const Tensor& self, const double p) {
   return at::_pdist_forward(self.contiguous(), p);
 }
 
+Tensor cdist(const Tensor& x1, const Tensor& x2, const double p) {
+  AT_CHECK(x1.dim() == 2, "cdist only supports 2D tensors, X1 got: ", x1.dim(), "D");
+  AT_CHECK(at::isFloatingType(x1.type().scalarType()), "cdist only supports floating-point dtypes, X1 got: ", x1.type().scalarType());
+  auto device1 = x1.type().device_type();
+  AT_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
+  AT_CHECK(x2.dim() == 2, "cdist only supports 2D tensors, X2 got: ", x2.dim(), "D");
+  AT_CHECK(at::isFloatingType(x1.type().scalarType()), "cdist only supports floating-point dtypes, X2 got: ", x2.type().scalarType());
+  auto device2 = x2.type().device_type();
+  AT_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
+  AT_CHECK(p >= 0, "cdist only supports non-negative p values");
+  AT_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+  AT_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+  int64_t c1 = x1.size(-1);
+  int64_t c2 = x2.size(-1);
+  AT_CHECK(c1 == c2, "X1 and X2 must have the same number of columns. X1: ", c1, " X2: ", c2);
+
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+  Tensor result = at::empty({r1, r2}, x1.options());
+  if (r1 > 0 && r2 > 0) {
+    if (c1 == 0) {
+      result.fill_(0);
+    } else {
+      cdist_stub(device1, result, x1.contiguous(), x2.contiguous(), p);
+    }
+  }
+  return result;
+}
+
 Tensor _pdist_forward(const Tensor& self, const double p) {
   AT_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
   auto device = self.type().device_type();
index 87cdc62..9c871ff 100644 (file)
@@ -7,8 +7,10 @@ namespace at { namespace native {
 
 using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
 using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
 
 DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
 DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
+DECLARE_DISPATCH(cdist_fn, cdist_stub);
 
 }} // namespace at::native
index 8b06938..aac1109 100644 (file)
@@ -149,6 +149,54 @@ struct PDist {
   }
 
   template <typename F>
+  static void run_cdist_parallel(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
+    const scalar_t * const t1_start = t1.data<scalar_t>();
+    const scalar_t * const t2_start = t2.data<scalar_t>();
+    int64_t r1 = t1.size(-2);
+    int64_t r2 = t2.size(-2);
+    int64_t m = t1.size(-1);
+
+    scalar_t * const res_start = result.data<scalar_t>();
+    int64_t total = r1 * r2;
+
+    parallel_for(0, total, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
+      const Vec pvec(p);
+      scalar_t * res = res_start + start;
+      const scalar_t * const res_end = res_start + end;
+
+      int64_t k = start;
+      while (res != res_end) {
+        int64_t i = k / r2;
+        int64_t j = k % r2;
+        const scalar_t * self_i = t1_start + i * m;
+        const scalar_t * self_j = t2_start + j * m;
+
+        *res = F::finish(vec256::map2_reduce_all<scalar_t>(
+                [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
+                F::red, self_i, self_j, m), p);
+
+        res += 1;
+        k++;
+      }
+    });
+  }
+
+  static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
+    if (p == 0.0) {
+      run_cdist_parallel<zdist_calc>(result, x1, x2, p);
+    } else if (p == 1.0) {
+      run_cdist_parallel<odist_calc>(result, x1, x2, p);
+    } else if (p == 2.0) {
+      run_cdist_parallel<tdist_calc>(result, x1, x2, p);
+    } else if (std::isinf(p)) {
+      run_cdist_parallel<idist_calc>(result, x1, x2, p);
+    } else {
+      run_cdist_parallel<pdist_calc>(result, x1, x2, p);
+    }
+  }
+
+  // This does a backward pass down a Vec column of the input
+  template <typename F>
   inline static void backward_down_column(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
     for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
 
@@ -233,9 +281,18 @@ static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const
   });
 }
 
+static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
+  AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] {
+    PDist<scalar_t>::apply_cdist(result, x1, x2, p);
+  });
+}
+
+
+
 }  // anonymous namespace
 
 REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
 REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
 
 }}  // namespace at::native
index cd0a3ad..37d5c68 100644 (file)
@@ -79,6 +79,29 @@ struct dists {
 };
 
 template <typename scalar_t, typename F>
+__device__ static inline scalar_t reduce_agg(scalar_t agg) {
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+  }
+
+  __shared__ scalar_t shared[forward_threads];
+  int lane = threadIdx.x % warpSize;
+  int warp_id = threadIdx.x / warpSize;
+  if (lane == 0) {
+    shared[warp_id] = agg;
+  }
+
+  __syncthreads();
+  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
+  if (warp_id == 0) {
+    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
+      F::agg(agg, WARP_SHFL_DOWN(agg, offset));
+    }
+  }
+  return agg;
+}
+
+template <typename scalar_t, typename F>
 __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
                                               const double n2, const double n2_squared_minus_1) {
   const int k = blockIdx.x;
@@ -97,28 +120,7 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t
     F::inc(agg, std::abs(*a - *b), p);
   }
 
-  // Reduce warps
-  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
-    F::agg(agg, WARP_SHFL_DOWN(agg, offset));
-  }
-
-  // Reduce block
-  // This shared memory is significantly larger than necessary, but the
-  // assumption is that it's not a bottleneck, and this is simple
-  __shared__ scalar_t shared[forward_threads];
-  int lane = threadIdx.x % warpSize;
-  int warp_id = threadIdx.x / warpSize;
-  if (lane == 0) {
-    shared[warp_id] = agg;
-  }
-  __syncthreads();
-  agg = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0.0;
-  if (warp_id == 0) {
-    // Only reduce theads with nonzero data
-    for (int offset = blockDim.x / warpSize / 2; offset > 0; offset /= 2) {
-      F::agg(agg, WARP_SHFL_DOWN(agg, offset));
-    }
-  }
+  agg = reduce_agg<scalar_t, F>(agg);
   if (threadIdx.x == 0) {
     result[k] = F::finish(agg, p);
   }
@@ -157,6 +159,50 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
   }
 }
 
+template <typename scalar_t, typename F>
+__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m) {
+  const int k = blockIdx.x;
+  const int64_t i = k / r2;
+  const int64_t j = k % r2;
+  const int stride = blockDim.x;
+
+  const scalar_t * const start = x1 + i * m;
+  const scalar_t * const end = start + m;
+  const scalar_t * a = start + threadIdx.x;
+  const scalar_t * b = x2 + j * m + threadIdx.x;
+
+  scalar_t agg = 0.0;
+  for (; a < end; a += stride, b += stride) {
+    F::inc(agg, std::abs(*a - *b), p);
+  }
+  agg = reduce_agg<scalar_t, F>(agg);
+  if (threadIdx.x == 0) {
+    result[k] = F::finish(agg, p);
+  }
+}
+
+void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+  int64_t m = x1.size(-1);
+  const dim3 grid(r1*r2);
+  const dim3 block(forward_threads);
+
+  AT_DISPATCH_FLOATING_TYPES(x1.type(), "cdist_cuda", [&] {
+    if (p == 0.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (p == 1.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (p == 2.0) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else if (std::isinf(p)) {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    } else {
+      cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
+    }
+  });
+}
+
 void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
   const dim3 grid(result.numel());
   const dim3 block(forward_threads);
@@ -228,5 +274,6 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
 
 REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
 REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
+REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
 
 }} // at::native
index cdc8ed9..08c3539 100644 (file)
 - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
   matches_jit_signature: True
 
+- func: cdist(Tensor x1, Tensor x2, double p=2) -> Tensor
+    
 - func: pdist(Tensor self, float p=2) -> Tensor
   matches_jit_signature: True
 
index b0071a4..f83f9e4 100644 (file)
@@ -745,6 +745,13 @@ def brute_pdist(inp, p=2):
     inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
     return unroll[..., inds.cumsum(0)]
 
+def brute_cdist(x, y, p=2):
+    r1 = x.shape[-2]
+    r2 = y.shape[-2]
+    if r1 == 0 or r2 == 0:
+        return torch.empty(r1, r2, device=x.device)
+    return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
 
 def do_test_dtypes(self, dtypes, layout, device):
     for dtype in dtypes:
index 4522317..70ed959 100644 (file)
@@ -27,7 +27,7 @@ from common_methods_invocations import tri_tests_args, run_additional_tri_tests,
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
 from multiprocessing.reduction import ForkingPickler
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -1231,6 +1231,75 @@ class _TestTorchMixin(object):
             for dtype in [torch.float32, torch.float64]:
                 test_pdist_single((1000, 2), device, 2, dtype, False)
 
+
+    def test_cdist_empty(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn((0, 5), device=device)
+            y = torch.randn((4, 5), device=device)
+            self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 5), device=device)
+            y = torch.randn((0, 5), device=device)
+            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 0), device=device)
+            y = torch.randn((3, 0), device=device)
+            self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
+
+            x = torch.randn((2, 0), device=device)
+            y = torch.randn((0, 0), device=device)
+            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
+
+    def test_cdist_norm(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            for r1 in [3, 4, 5, 6]:
+                for m in [2, 3, 4, 10]:
+                    for r2 in [4, 6, 7, 8]:
+                        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                            x = torch.randn(r1, m, device=device)
+                            y = torch.randn(r2, m, device=device)
+                            actual = torch.cdist(x, y, p=p)
+                            expected = brute_cdist(x, y, p=p)
+                            self.assertTrue(torch.allclose(expected, actual))
+
+    def test_cdist_large(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn(1000, 10, device=device)
+            y = torch.randn(1000, 10, device=device)
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertTrue(torch.allclose(expected, actual))
+
+    def test_cdist_non_contiguous(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2)
+            expected = brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertTrue(torch.allclose(expected, actual))
+
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_logsumexp(self):
         from scipy.special import logsumexp