From ca8dd296ee42fd68b8c9360d10916e02e009eeff Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Tue, 31 Aug 2021 09:45:09 -0700
Subject: [PATCH] Add OpInfo for `nn.functional.cosine_similarity` (#62959)

Summary:
Please see https://github.com/facebookresearch/functorch/issues/78 and
https://github.com/pytorch/pytorch/issues/54261.

Notes:

* Some redundant tests from `test_nn.py` have been removed. I'm unsure whether the precision checks can be removed as well.
* Broadcasting is also checked in the OpInfo for `cosine_similarity`.

cc: mruberry zou3519 Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62959

Reviewed By: heitorschueroff

Differential Revision: D30520176

Pulled By: zou3519

fbshipit-source-id: 14e902eb4bcce875edab28a1669a2ea021052b9b
---
 test/test_nn.py                              | 20 ----------------
 .../_internal/common_methods_invocations.py  | 27 ++++++++++++++++++++++
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/test/test_nn.py b/test/test_nn.py
index 96321ba..5008c72 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9617,25 +9617,6 @@ class TestNN(NNTestCase):
         test_huber_loss_zero_delta()
 
     def test_cosine_similarity(self):
-        input1 = torch.randn(4, 4, requires_grad=True)
-        input2 = torch.randn(4, 4, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y), (input1, input2)))
-
-        input1 = torch.randn(4, 5, 6, requires_grad=True)
-        input2 = torch.randn(4, 5, 6, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
-        input1 = torch.randn((), requires_grad=True)
-        input2 = torch.randn((), requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
-        # Check broadcasting
-        input1 = torch.randn(2, 1, 3, requires_grad=True)
-        input2 = torch.randn(1, 2, 3, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
-
         # Check cosine_similarity input/output shapes
         input_size = (1, 3, 2, 1)
         expected_size = (1, 2, 1)
@@ -9662,7 +9643,6 @@ class TestNN(NNTestCase):
         with self.assertRaises(RuntimeError):
             F.cosine_similarity(input1, input2)
 
-
         # Check type promotion, issue #61454
         input = torch.tensor(12.)
         out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e7d9380..3579310 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1249,6 +1249,26 @@ def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
                 dim=(0, 1))))
     return inputs
 
+def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input_shape, dict of dim and eps
+    cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
+        ((S, S), {'dim': 1}),
+        ((S, 2), {'dim': -1}),
+        ((S,), {'dim': 0, 'eps': 0.5}),
+        ((), {'dim': 0}),
+        ((S, S, M), {'dim': 2}),
+        ((S, S), {})
+    )
+
+    def generator():
+        for input_shape, kwargs in cases:
+            yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
+        # Test for Broadcasting
+        yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+
+    return list(generator())
 def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -7175,6 +7195,13 @@ op_db: List[OpInfo] = [
               # FIXME: aminmax does not check for safe casting to output
               SkipInfo('TestCommon', 'test_out'),
           )),
+    OpInfo('nn.functional.cosine_similarity',
+           aten_name="cosine_similarity",
+           dtypes=floating_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_cosine_similarity),
     OpInfo('nn.functional.adaptive_avg_pool2d',
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
-- 
2.7.4
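
As an aside (not part of the committed patch): below is a minimal sketch of the broadcasting case that the extra SampleInput in sample_inputs_cosine_similarity exercises, using only public torch APIs. The shapes match the patch's broadcasting sample, but the use of double precision and the standalone gradcheck call are illustrative assumptions, not taken verbatim from the test suite.

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck

    # Shapes (1, 2, 3) and (2, 1, 3) broadcast to (2, 2, 3); cosine_similarity
    # then reduces over dim=-1, so the output has shape (2, 2).
    x = torch.randn(1, 2, 3, dtype=torch.double, requires_grad=True)
    y = torch.randn(2, 1, 3, dtype=torch.double, requires_grad=True)
    out = F.cosine_similarity(x, y, dim=-1)
    print(out.shape)  # torch.Size([2, 2])

    # Double-precision gradcheck of the same broadcasting case; this is roughly
    # the coverage the removed test_nn.py gradcheck calls provided, and which the
    # OpInfo-driven gradient tests now provide automatically.
    assert gradcheck(lambda a, b: F.cosine_similarity(a, b, dim=-1), (x, y))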