Add `__matmul__` to the magic methods for FX tracing (#64512)
authorHorace He <chilli@fb.com>
Wed, 8 Sep 2021 16:59:04 +0000 (09:59 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 17:03:48 +0000 (10:03 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/64483

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64512

Reviewed By: mrshenli

Differential Revision: D30797265

Pulled By: Chillee

fbshipit-source-id: 7630e048a960e0b27c4309d04d85301abe325189

test/test_fx.py
torch/fx/graph.py

index 2f8f82b..c509c3b 100644 (file)
@@ -260,6 +260,24 @@ class TestFX(JitTestCase):
 
         self.checkGraphModule(m, (input_dict,))
 
+    def test_matmul_tracing(self):
+        const = torch.randn(3)
+
+        def matmul_f(x):
+            return x @ const
+
+        mod = symbolic_trace(matmul_f)
+        inp = torch.randn(3)
+        self.assertEqual(mod(inp), matmul_f(inp))
+
+        def rmatmul_f(x):
+            return const @ x
+
+        mod = symbolic_trace(rmatmul_f)
+        inp = torch.randn(3)
+        self.assertEqual(mod(inp), rmatmul_f(inp))
+
+
     def test_disallow_override(self):
         # Custom delegate to disallow in-place tensor operations
         class NoMutableCallTracer(Tracer):
index 65e93d0..48441a1 100644 (file)
@@ -1187,7 +1187,8 @@ reflectable_magic_methods = {
     'and': '{} & {}',
     'or': '{} | {}',
     'xor': '{} ^ {}',
-    'getitem': '{}[{}]'
+    'getitem': '{}[{}]',
+    'matmul': '{} @ {}',
 }
 
 magic_methods = dict({