Fix cosine similarity dim checks (#66214)
authorNatalia Gimelshein <ngimel@fb.com>
Fri, 8 Oct 2021 14:22:40 +0000 (07:22 -0700)
committerGitHub <noreply@github.com>
Fri, 8 Oct 2021 14:22:40 +0000 (07:22 -0700)
* fix cosine similarity dimensionality check

* fix shapes in the doc

aten/src/ATen/native/Distance.cpp
test/test_nn.py
torch/nn/functional.py
torch/testing/_internal/common_methods_invocations.py

index 7974840..9105c83 100644 (file)
@@ -240,14 +240,11 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c
 }
 
 Tensor cosine_similarity(const Tensor& x1, const Tensor& x2, int64_t dim, double eps) {
-  TORCH_CHECK(x1.ndimension() == x2.ndimension(), "cosine_similarity requires both inputs to have the same number of dimensions, but x1 has ",
-              x1.ndimension(), " and x2 has ", x2.ndimension());
-  TORCH_CHECK(x1.ndimension() == 0 || x1.size(dim) == x2.size(dim), "cosine_similarity requires both inputs to have the same size at dimension ", dim, "but x1 has ",
-  x1.size(dim), " and x2 has ", x2.size(dim));
+  auto common_size = at::infer_size_dimvector(x1.sizes(), x2.sizes());
   auto commonDtype = at::result_type(x1, x2);
   TORCH_CHECK(at::isFloatingType(commonDtype), "expected common dtype to be floating point, yet common dtype is ", commonDtype);
-  Tensor x1_ = x1.to(commonDtype);
-  Tensor x2_ = x2.to(commonDtype);
+  Tensor x1_ = x1.to(commonDtype).expand(common_size);
+  Tensor x2_ = x2.to(commonDtype).expand(common_size);
   // Follow scipy impl to improve numerical precision
   // Use x / sqrt(x * x) instead of x / (sqrt(x) * sqrt(x))
   Tensor w12 = at::sum(x1_ * x2_, dim);
index 92357d9..b6dd466 100644 (file)
@@ -9704,12 +9704,6 @@ class TestNN(NNTestCase):
         self.assertEqual(input1.grad, torch.zeros_like(input1))
         self.assertEqual(input2.grad, input1 * 1e8)
 
-        # Check error when inputs are not the same shape
-        input1 = torch.randn(2, 2, 1)
-        input2 = torch.randn(2, 1, 3)
-        with self.assertRaises(RuntimeError):
-            F.cosine_similarity(input1, input2)
-
         # Check type promotion, issue #61454
         input = torch.tensor(12.)
         out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
index 4b0449c..5f14f31 100644 (file)
@@ -4256,7 +4256,10 @@ cosine_similarity = _add_docstr(
     r"""
 cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
 
-Returns cosine similarity between x1 and x2, computed along dim.
+Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
+to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
+squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 fewer dimension.
 
 .. math ::
     \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
@@ -4265,16 +4268,11 @@ Supports :ref:`type promotion <type-promotion-doc>`.
 
 Args:
     x1 (Tensor): First input.
-    x2 (Tensor): Second input (with the same number of dimensions as x1, matching x1 size at dimension `dim`,
-        and broadcastable with x1 at other dimensions).
-    dim (int, optional): Dimension of vectors. Default: 1
+    x2 (Tensor): Second input.
+    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
     eps (float, optional): Small value to avoid division by zero.
         Default: 1e-8
 
-Shape:
-    - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
-    - Output: :math:`(\ast_1, \ast_2)`
-
 Example::
 
     >>> input1 = torch.randn(100, 128)
index d817c9e..41abeb7 100644 (file)
@@ -1256,6 +1256,8 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
             yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
         # Test for Broadcasting
         yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+        yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
+        yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
 
     return list(generator())