Using sqrt for better precision in cosine_similarity (#18250)
authorAiling Zhang <ailzhang@fb.com>
Fri, 22 Mar 2019 20:22:52 +0000 (13:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 20:33:30 +0000 (13:33 -0700)
Summary:
Addresses a review comment in #18168.
Testing in CI...
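
For context, a minimal sketch of what the change does, written against the Python API rather than ATen (the variable names and the default eps of 1e-8 are assumptions for illustration): the old path computed the reciprocal square root of w1*w2 and only afterwards capped it at 1/eps, so degenerate inputs passed through an intermediate inf; the new path clamps the squared-norm product to eps^2 before taking the square root, so the denominator is bounded below by eps and the division stays well behaved.

    # Illustrative sketch only, not the ATen implementation.
    import torch

    eps = 1e-8
    x1 = torch.randn(10, dtype=torch.float64)
    x2 = torch.zeros_like(x1)  # degenerate input: a zero vector

    w12 = (x1 * x2).sum()
    w1 = (x1 * x1).sum()
    w2 = (x2 * x2).sum()

    # Old: rsqrt(0) produces an intermediate inf, which the clamp then caps at 1/eps.
    old = w12 * (w1 * w2).rsqrt().clamp(max=1.0 / eps)

    # New: clamp the squared-norm product to eps**2 first, then sqrt and divide,
    # so the denominator never drops below eps and no inf is formed.
    new = w12 / (w1 * w2).clamp(min=eps * eps).sqrt()

    print(old.item(), new.item())  # both 0.0 in the forward pass for this input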
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18250

Differential Revision: D14568601

Pulled By: ailzhang

fbshipit-source-id: 39fbbdb08743b53fa665c7e88e4750cbe0976ec7

aten/src/ATen/native/Distance.cpp
test/test_nn.py

index ad08939..d8420fc 100644 (file)
@@ -104,8 +104,8 @@ Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double
   Tensor w12 = at::sum(x1 * x2, dim);
   Tensor w1 = at::sum(x1 * x1, dim);
   Tensor w2 = at::sum(x2 * x2, dim);
-  Tensor n12 = (w1 * w2).rsqrt_().clamp_max(1.0 / eps);
-  return w12.mul_(n12);
+  Tensor n12 = (w1 * w2).clamp_min_(eps * eps).sqrt_();
+  return w12.div_(n12);
 }
 
 }}  // namespace at::native
index 5fcbf3d..d8651ab 100644 (file)
@@ -5971,6 +5971,13 @@ class TestNN(NNTestCase):
         out = F.cosine_similarity(vv1, vv2)
         self.assertLessEqual(out, 1.0)
 
+        # Check dividing by 0.
+        input1 = torch.randn(10).requires_grad_()
+        input2 = torch.zeros_like(input1).requires_grad_()
+        torch.cosine_similarity(input1, input2, 0).sum().backward()
+        self.assertEqual(input1.grad, torch.zeros_like(input1))
+        self.assertEqual(input2.grad, input1 * 1e8)
+
     def test_grid_sample_error_checking(self):
         input = torch.empty(1, 1, 2, 2)
         grid = torch.empty(1, 1, 1, 2)
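
Note on the expectations in the new test (a sketch of the reasoning, assuming the default eps of 1e-8): with input2 = 0 both w12 and w2 are zero, so the clamp pins the denominator at sqrt(eps * eps) = eps and blocks any gradient flowing through w2. The only remaining path into input2 is through w12, giving d(out)/d(x2) = x1 / eps = x1 * 1e8, which is what `self.assertEqual(input2.grad, input1 * 1e8)` checks; the gradient with respect to input1 is x2 / eps = 0, matching the zero-gradient assertion.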