self.checkGraphModule(m, (input_dict,))
+ def test_matmul_tracing(self):
+ const = torch.randn(3)
+
+ def matmul_f(x):
+ return x @ const
+
+ mod = symbolic_trace(matmul_f)
+ inp = torch.randn(3)
+ self.assertEqual(mod(inp), matmul_f(inp))
+
+ def rmatmul_f(x):
+ return const @ x
+
+ mod = symbolic_trace(rmatmul_f)
+ inp = torch.randn(3)
+ self.assertEqual(mod(inp), rmatmul_f(inp))
+
+
def test_disallow_override(self):
# Custom delegate to disallow in-place tensor operations
class NoMutableCallTracer(Tracer):