From 654e59fcac4a9d4bf0b48306e1d7f7be5b7e40b1 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 26 Mar 2019 13:00:29 -0700 Subject: [PATCH] Minor fix for onnx ConstantOfShape export (#18199) Summary: Set value as tensor of 1 element instead of scalar, according to ONNX spec. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18199 Reviewed By: dzhulgakov Differential Revision: D14542588 Pulled By: houseroad fbshipit-source-id: 70dc978d870ebe6ef37c519ba4a20061c3f07372 --- test/onnx/expect/TestOperators.test_full.expect | 1 + test/onnx/expect/TestOperators.test_full_like.expect | 1 + test/onnx/expect/TestOperators.test_ones_like.expect | 1 + test/onnx/expect/TestOperators.test_zeros_like.expect | 1 + torch/onnx/symbolic.py | 12 ++++++------ 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect index 5737a1c..c19f6c1 100644 --- a/test/onnx/expect/TestOperators.test_full.expect +++ b/test/onnx/expect/TestOperators.test_full.expect @@ -96,6 +96,7 @@ graph { attribute { name: "value" t { + dims: 1 data_type: 1 raw_data: "\000\000\000@" } diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect index f6edcde..fda196e 100644 --- a/test/onnx/expect/TestOperators.test_full_like.expect +++ b/test/onnx/expect/TestOperators.test_full_like.expect @@ -14,6 +14,7 @@ graph { attribute { name: "value" t { + dims: 1 data_type: 1 raw_data: "\000\000\000@" } diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect index 70394a8..715ca05 100644 --- a/test/onnx/expect/TestOperators.test_ones_like.expect +++ b/test/onnx/expect/TestOperators.test_ones_like.expect @@ -14,6 +14,7 @@ graph { attribute { name: "value" t { + dims: 1 data_type: 1 raw_data: "\000\000\200?" } diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect index 9ab783a..5a0dbc3 100644 --- a/test/onnx/expect/TestOperators.test_zeros_like.expect +++ b/test/onnx/expect/TestOperators.test_zeros_like.expect @@ -14,6 +14,7 @@ graph { attribute { name: "value" t { + dims: 1 data_type: 1 raw_data: "\000\000\000\000" } diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 878f8a3..288f362 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1281,27 +1281,27 @@ scalar_type_to_onnx = [ def zeros(g, sizes, dtype, layout, device): # NOTE: no way to set device and layout in ONNX, so we ignore it return g.op("ConstantOfShape", sizes, - value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype])) @parse_args('v', 'i', 'v', 'v') def zeros_like(g, input, dtype, layout, device): shape = g.op("Shape", input) return g.op("ConstantOfShape", shape, - value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype])) @parse_args('v', 'i', 'v', 'v') def ones(g, sizes, dtype, layout, device): return g.op("ConstantOfShape", sizes, - value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype])) @parse_args('v', 'i', 'v', 'v') def ones_like(g, input, dtype, layout, device): shape = g.op("Shape", input) return g.op("ConstantOfShape", shape, - value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype])) def full(g, sizes, value, dtype, layout, device): @@ -1312,14 +1312,14 @@ def full(g, sizes, value, dtype, layout, device): else: dtype = _get_const(dtype, 'i', 'dtype') return g.op("ConstantOfShape", sizes, - value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype])) @parse_args('v', 'f', 'i', 'v', 'v') def full_like(g, input, fill_value, dtype, layout, device): shape = g.op("Shape", input) return g.op("ConstantOfShape", shape, - value_t=torch.tensor(fill_value, dtype=scalar_type_to_pytorch_type[dtype])) + value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype])) @parse_args('v', 'v', 'v', 'v', 'i') -- 2.7.4