model(input)
+class TestDisabledTorchFunction(TestCase):
+ # Regression test for gh-64687
+ def test_parameter_does_not_prevent_dispatch(self):
+ class MyTensor():
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ return "called"
+
+ t1 = MyTensor()
+ t2 = torch.nn.Parameter(torch.rand(2, 2))
+ self.assertEqual(torch.add(t2, t1), "called")
+
+ inp = torch.rand(10, 10)
+ self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
+ self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")
+
+
if __name__ == '__main__':
run_tests()
# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
- if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')):
+ #
+ # NB: Important to exclude _disabled_torch_function_impl, otherwise
+ # https://github.com/pytorch/pytorch/issues/64687
+ if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
+ arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types: