[ONNX] Suppport torch.dot and torch.nn.utils.spectral_norm (#62596) (#62765)
authorBowenBao <bowbao@microsoft.com>
Fri, 20 Aug 2021 19:44:29 +0000 (12:44 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 19:46:56 +0000 (12:46 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62765

Fixes #27723

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375181

Pulled By: msaroufim

fbshipit-source-id: 715f4745899757ec405877980cd20c826028eb2c

Co-authored-by: BowenBao <bowbao@microsoft.com>
test/onnx/test_pytorch_onnx_onnxruntime.py
torch/onnx/symbolic_opset9.py

index fd10629..865b365 100644 (file)
@@ -5722,6 +5722,27 @@ class TestONNXRuntime(unittest.TestCase):
         y = torch.randint(10, (5, ))
         self.run_test(MatmulModel(), (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)  # MatMul long inputs is added in ONNX opset 9.
+    def test_dot(self):
+        class MatmulModel(torch.nn.Module):
+            def forward(self, input, other):
+                return torch.dot(input, other)
+
+        x = torch.randn(5, requires_grad=True)
+        y = torch.randn(5, requires_grad=True)
+        self.run_test(MatmulModel(), (x, y))
+
+        x = torch.randint(10, (5, ))
+        y = torch.randint(10, (5, ))
+        self.run_test(MatmulModel(), (x, y))
+
+    @disableScriptTest()  # SpectralNorm not TorchScript compatible.
+    def test_spectral_norm(self):
+        m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4))
+
+        x = torch.randn(6, 2)
+        self.run_test(m, (x, ))
+
     def test_prelu(self):
         class PReluModel(torch.nn.Module):
             def __init__(self):
index ce59e15..70bb828 100644 (file)
@@ -3138,6 +3138,10 @@ def mv(g, self, vec):
     return matmul(g, self, vec)
 
 
+def dot(g, self, other):
+    return matmul(g, self, other)
+
+
 @parse_args('v', 'v')
 def fill(g, self, value):
     dtype = self.type().scalarType()