From 1dd648f1c40c24a3d5a151581a8129652191fa86 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Fri, 20 Aug 2021 12:44:29 -0700
Subject: [PATCH] [ONNX] Support torch.dot and torch.nn.utils.spectral_norm
 (#62596) (#62765)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62765

Fixes #27723

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375181

Pulled By: msaroufim

fbshipit-source-id: 715f4745899757ec405877980cd20c826028eb2c

Co-authored-by: BowenBao
---
 test/onnx/test_pytorch_onnx_onnxruntime.py | 21 +++++++++++++++++++++
 torch/onnx/symbolic_opset9.py              |  4 ++++
 2 files changed, 25 insertions(+)

diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index fd10629..865b365 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5722,6 +5722,27 @@ class TestONNXRuntime(unittest.TestCase):
         y = torch.randint(10, (5, ))
         self.run_test(MatmulModel(), (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul with long (int64) inputs is supported since ONNX opset 9.
+    def test_dot(self):
+        class MatmulModel(torch.nn.Module):
+            def forward(self, input, other):
+                return torch.dot(input, other)
+
+        x = torch.randn(5, requires_grad=True)
+        y = torch.randn(5, requires_grad=True)
+        self.run_test(MatmulModel(), (x, y))
+
+        x = torch.randint(10, (5, ))
+        y = torch.randint(10, (5, ))
+        self.run_test(MatmulModel(), (x, y))
+
+    @disableScriptTest()  # SpectralNorm is not TorchScript compatible.
+    def test_spectral_norm(self):
+        m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
+
+        x = torch.randn(6, 2)
+        self.run_test(m, (x, ))
+
     def test_prelu(self):
         class PReluModel(torch.nn.Module):
             def __init__(self):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index ce59e15..70bb828 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -3138,6 +3138,10 @@ def mv(g, self, vec):
     return matmul(g, self, vec)
 
 
+def dot(g, self, other):
+    return matmul(g, self, other)
+
+
@parse_args('v', 'v')
 def fill(g, self, value):
     dtype = self.type().scalarType()
-- 
2.7.4
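
For context (not part of the patch): a minimal sketch of how the newly supported ops can be exercised through the exporter once this change is applied. The DotModel module, the in-memory BytesIO targets, and the comments below are illustrative assumptions, not taken from the patch.

import io

import torch


class DotModel(torch.nn.Module):
    def forward(self, input, other):
        # torch.dot on 1-D tensors lowers to ONNX MatMul via the new `dot` symbolic.
        return torch.dot(input, other)


# Export a 1-D dot product; before this patch the exporter had no symbolic
# for aten::dot and reported the operator as unsupported.
torch.onnx.export(DotModel(), (torch.randn(5), torch.randn(5)),
                  io.BytesIO(), opset_version=9)

# torch.nn.utils.spectral_norm reparametrizes the weight using ops the
# exporter already understands, so the wrapped module now exports as well.
m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
torch.onnx.export(m, (torch.randn(6, 2),), io.BytesIO(), opset_version=9)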