From ba6c49cb9c516bb8c18984113f8c6c5f4c4556f6 Mon Sep 17 00:00:00 2001 From: zrphercule Date: Tue, 27 Nov 2018 13:49:21 -0800 Subject: [PATCH] Add test of ONNX_ATEN (#14259) Summary: In #14239 we fixed ONNX_ATEN. In order to make sure its correctness in the future, we should add related test case. We use torch.fmod() to test ONNX_ATEN. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14259 Differential Revision: D13204610 Pulled By: zrphercule fbshipit-source-id: e4660c346e5edd201f1458b7d74d7dfac49b94c7 --- .../expect/TestPytorchExportModes.test_onnx_aten.expect | 16 ++++++++++++++++ test/test_jit.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 test/expect/TestPytorchExportModes.test_onnx_aten.expect diff --git a/test/expect/TestPytorchExportModes.test_onnx_aten.expect b/test/expect/TestPytorchExportModes.test_onnx_aten.expect new file mode 100644 index 0000000..222fa42 --- /dev/null +++ b/test/expect/TestPytorchExportModes.test_onnx_aten.expect @@ -0,0 +1,16 @@ +ModelProto { + producer_name: "pytorch" + domain: "" + doc_string: "" + graph: + GraphProto { + name: "torch-jit-export" + inputs: [{name: "0", type:Tensor dims: 3 4},{name: "1", type:Tensor dims: 3 4}] + outputs: [{name: "2", type:Tensor dims: 3 4}] + initializers: [] + nodes: [ + Node {type: "ATen", inputs: [0,1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}]} + ] + } + opset_import: [OperatorSetIdProto { domain: }], +} diff --git a/test/test_jit.py b/test/test_jit.py index d5263de..d33fa11 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9280,6 +9280,23 @@ class TestPytorchExportModes(JitTestCase): operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) self.assertExpected(exported) + # torch.fmod is using to test ONNX_ATEN. + # If you plan to remove fmod from aten, or found this test failed. + # please contact @Rui. + @skipIfRocm + def test_onnx_aten(self): + class ModelWithAtenFmod(nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + f = io.BytesIO() + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + exported = torch.onnx.export_to_pretty_string( + ModelWithAtenFmod(), (x, y), f, + operator_export_type=OperatorExportTypes.ONNX_ATEN) + self.assertExpected(exported) + # known to be failing in tracer EXCLUDE_TRACED = { -- 2.7.4