Filter out _disabled_torch_function_impl from handle_torch_function (#64689)
authorEdward Yang <ezyang@fb.com>
Thu, 9 Sep 2021 14:17:26 +0000 (07:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 14:29:09 +0000 (07:29 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64689

This brings it in line with the C++ implementation.

Fixes https://github.com/pytorch/pytorch/issues/64687

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30816215

Pulled By: ezyang

fbshipit-source-id: ed36af6c35467ae678d9548197efd97c36d38dec

test/test_overrides.py
torch/overrides.py

index a625237..4fc1477 100644 (file)
@@ -1001,5 +1001,21 @@ class TestRNN(TestCase):
         model(input)
 
 
+class TestDisabledTorchFunction(TestCase):
+    # Regression test for gh-64687
+    def test_parameter_does_not_prevent_dispatch(self):
+        class MyTensor():
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                return "called"
+
+        t1 = MyTensor()
+        t2 = torch.nn.Parameter(torch.rand(2, 2))
+        self.assertEqual(torch.add(t2, t1), "called")
+
+        inp = torch.rand(10, 10)
+        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
+        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
+
+
 if __name__ == '__main__':
     run_tests()
index 3fd87b8..c574109 100644 (file)
@@ -1270,7 +1270,11 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
         # We only collect arguments if they have a unique type, which ensures
         # reasonable performance even with a long list of possibly overloaded
         # arguments.
-        if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')):
+        #
+        # NB: Important to exclude _disabled_torch_function_impl, otherwise
+        # https://github.com/pytorch/pytorch/issues/64687
+        if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
+                arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
             # Create lists explicitly for the first type (usually the only one
             # done) to avoid setting up the iterator for overloaded_args.
             if overloaded_types: