From 3d44eeec0a31334e7642b4251da151733ec9fb69 Mon Sep 17 00:00:00 2001 From: zrphercule Date: Fri, 4 Jan 2019 16:11:23 -0800 Subject: [PATCH] Fix different types in rsub caused bug (#15707) Summary: Before this pr, rsub did not convert two elements into the same dtype, therefore "1 - x" may export to an onnx model that two elements of rsub having different dtype. By adding this symbolic patch this bug should be fixed. Related test cases also created. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15707 Differential Revision: D13583042 Pulled By: zrphercule fbshipit-source-id: 3a2de47a1a8d1ded1a0adfb911adbe6ac729cdef --- test/onnx/expect/TestOperators.test_rsub.expect | 4 ++-- test/onnx/test_pytorch_onnx_caffe2.py | 8 ++++++++ torch/onnx/symbolic.py | 2 ++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/onnx/expect/TestOperators.test_rsub.expect b/test/onnx/expect/TestOperators.test_rsub.expect index dca90da..504f5d7 100644 --- a/test/onnx/expect/TestOperators.test_rsub.expect +++ b/test/onnx/expect/TestOperators.test_rsub.expect @@ -8,8 +8,8 @@ graph { attribute { name: "value" t { - data_type: 7 - raw_data: "\001\000\000\000\000\000\000\000" + data_type: 11 + raw_data: "\000\000\000\000\000\000\360?" } type: TENSOR } diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index cc1950e..2ea4352 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -1001,6 +1001,14 @@ class TestCaffe2Backend(unittest.TestCase): model = nn.GroupNorm(3, 6) self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE) + def test_rsub(self): + class RsubModel(torch.nn.Module): + def forward(self, x): + return 1 - x + + x = torch.randn(1, 2) + self.run_model_test(RsubModel(), train=False, input=(x,), + batch_size=BATCH_SIZE, use_gpu=False) # a bit of metaprogramming to set up all the rnn tests diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index dca3b20..6bae829 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -229,6 +229,8 @@ def sub(g, self, other, alpha=None): def rsub(g, self, other, alpha=None): + other = _maybe_get_scalar(other) + other = _if_scalar_type_as(g, other, self) return sub(g, other, self, alpha=alpha) -- 2.7.4