From d4b1016850170deba98c026d1c8494de821b91c7 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 9 Sep 2021 07:17:26 -0700 Subject: [PATCH] Filter out _disabled_torch_function_impl from handle_torch_function (#64689) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64689 This brings it in line with the C++ implementation. Fixes https://github.com/pytorch/pytorch/issues/64687 Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D30816215 Pulled By: ezyang fbshipit-source-id: ed36af6c35467ae678d9548197efd97c36d38dec --- test/test_overrides.py | 16 ++++++++++++++++ torch/overrides.py | 6 +++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/test/test_overrides.py b/test/test_overrides.py index a625237..4fc1477 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1001,5 +1001,21 @@ class TestRNN(TestCase): model(input) +class TestDisabledTorchFunction(TestCase): + # Regression test for gh-64687 + def test_parameter_does_not_prevent_dispatch(self): + class MyTensor(): + def __torch_function__(self, func, types, args=(), kwargs=None): + return "called" + + t1 = MyTensor() + t2 = torch.nn.Parameter(torch.rand(2, 2)) + self.assertEqual(torch.add(t2, t1), "called") + + inp = torch.rand(10, 10) + self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called") + self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called") + + if __name__ == '__main__': run_tests() diff --git a/torch/overrides.py b/torch/overrides.py index 3fd87b8..c574109 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1270,7 +1270,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: # We only collect arguments if they have a unique type, which ensures # reasonable performance even with a long list of possibly overloaded # arguments. - if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')): + # + # NB: Important to exclude _disabled_torch_function_impl, otherwise + # https://github.com/pytorch/pytorch/issues/64687 + if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and + arg_type.__torch_function__ != torch._C._disabled_torch_function_impl): # Create lists explicitly for the first type (usually the only one # done) to avoid setting up the iterator for overloaded_args. if overloaded_types: -- 2.7.4