From: Horace He Date: Wed, 8 Sep 2021 16:59:04 +0000 (-0700) Subject: Add `__matmul__` to the magic methods for FX tracing (#64512) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~375 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=35413a16f7433b345ba1dc9c460cbe9c2d975762;p=platform%2Fupstream%2Fpytorch.git Add `__matmul__` to the magic methods for FX tracing (#64512) Summary: Fixes https://github.com/pytorch/pytorch/issues/64483 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64512 Reviewed By: mrshenli Differential Revision: D30797265 Pulled By: Chillee fbshipit-source-id: 7630e048a960e0b27c4309d04d85301abe325189 --- diff --git a/test/test_fx.py b/test/test_fx.py index 2f8f82b..c509c3b 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -260,6 +260,24 @@ class TestFX(JitTestCase): self.checkGraphModule(m, (input_dict,)) + def test_matmul_tracing(self): + const = torch.randn(3) + + def matmul_f(x): + return x @ const + + mod = symbolic_trace(matmul_f) + inp = torch.randn(3) + self.assertEqual(mod(inp), matmul_f(inp)) + + def rmatmul_f(x): + return const @ x + + mod = symbolic_trace(rmatmul_f) + inp = torch.randn(3) + self.assertEqual(mod(inp), rmatmul_f(inp)) + + def test_disallow_override(self): # Custom delegate to disallow in-place tensor operations class NoMutableCallTracer(Tracer): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 65e93d0..48441a1 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1187,7 +1187,8 @@ reflectable_magic_methods = { 'and': '{} & {}', 'or': '{} | {}', 'xor': '{} ^ {}', - 'getitem': '{}[{}]' + 'getitem': '{}[{}]', + 'matmul': '{} @ {}', } magic_methods = dict({