Add OpInfo for `nn.functional.cosine_similarity` (#62959)
authorKushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Tue, 31 Aug 2021 16:45:09 +0000 (09:45 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 17:31:36 +0000 (10:31 -0700)
Summary:
Please see https://github.com/facebookresearch/functorch/issues/78 and https://github.com/pytorch/pytorch/issues/54261.

Notes:

* Some redundant tests from `test_nn.py` have been removed. I'm unsure about precision checks if they can be removed as well.
* Broadcasting is also checked in the OpInfo for `cosine_similarity`.

cc: mruberry zou3519 Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62959

Reviewed By: heitorschueroff

Differential Revision: D30520176

Pulled By: zou3519

fbshipit-source-id: 14e902eb4bcce875edab28a1669a2ea021052b9b

test/test_nn.py
torch/testing/_internal/common_methods_invocations.py

index 96321ba..5008c72 100644 (file)
@@ -9617,25 +9617,6 @@ class TestNN(NNTestCase):
         test_huber_loss_zero_delta()
 
     def test_cosine_similarity(self):
-        input1 = torch.randn(4, 4, requires_grad=True)
-        input2 = torch.randn(4, 4, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y), (input1, input2)))
-
-        input1 = torch.randn(4, 5, 6, requires_grad=True)
-        input2 = torch.randn(4, 5, 6, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
-        input1 = torch.randn((), requires_grad=True)
-        input2 = torch.randn((), requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
-        # Check broadcasting
-        input1 = torch.randn(2, 1, 3, requires_grad=True)
-        input2 = torch.randn(1, 2, 3, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
         # Check cosine_similarity input/output shapes
         input_size = (1, 3, 2, 1)
         expected_size = (1, 2, 1)
@@ -9662,7 +9643,6 @@ class TestNN(NNTestCase):
         with self.assertRaises(RuntimeError):
             F.cosine_similarity(input1, input2)
 
-
         # Check type promotion, issue #61454
         input = torch.tensor(12.)
         out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
index e7d9380..3579310 100644 (file)
@@ -1249,6 +1249,26 @@ def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
                             dim=(0, 1))))
         return inputs
 
+def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input_shape, dict of dim and eps
+    cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
+        ((S, S), {'dim': 1}),
+        ((S, 2), {'dim': -1}),
+        ((S,), {'dim': 0, 'eps': 0.5}),
+        ((), {'dim': 0}),
+        ((S, S, M), {'dim': 2}),
+        ((S, S), {})
+    )
+
+    def generator():
+        for input_shape, kwargs in cases:
+            yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
+        # Test for Broadcasting
+        yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+
+    return list(generator())
 
 def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -7175,6 +7195,13 @@ op_db: List[OpInfo] = [
                # FIXME: aminmax does not check for safe casting to output
                SkipInfo('TestCommon', 'test_out'),
            )),
+    OpInfo('nn.functional.cosine_similarity',
+           aten_name="cosine_similarity",
+           dtypes=floating_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_cosine_similarity),
     OpInfo('nn.functional.adaptive_avg_pool2d',
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),