Minor fix for onnx ConstantOfShape export (#18199)
authorBowenBao <semisqg@gmail.com>
Tue, 26 Mar 2019 20:00:29 +0000 (13:00 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 20:23:16 +0000 (13:23 -0700)
Summary:
Set value as tensor of 1 element instead of scalar, according to ONNX spec.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18199

Reviewed By: dzhulgakov

Differential Revision: D14542588

Pulled By: houseroad

fbshipit-source-id: 70dc978d870ebe6ef37c519ba4a20061c3f07372

test/onnx/expect/TestOperators.test_full.expect
test/onnx/expect/TestOperators.test_full_like.expect
test/onnx/expect/TestOperators.test_ones_like.expect
test/onnx/expect/TestOperators.test_zeros_like.expect
torch/onnx/symbolic.py

index 5737a1c..c19f6c1 100644 (file)
@@ -96,6 +96,7 @@ graph {
     attribute {
       name: "value"
       t {
+        dims: 1
         data_type: 1
         raw_data: "\000\000\000@"
       }
index f6edcde..fda196e 100644 (file)
@@ -14,6 +14,7 @@ graph {
     attribute {
       name: "value"
       t {
+        dims: 1
         data_type: 1
         raw_data: "\000\000\000@"
       }
index 70394a8..715ca05 100644 (file)
@@ -14,6 +14,7 @@ graph {
     attribute {
       name: "value"
       t {
+        dims: 1
         data_type: 1
         raw_data: "\000\000\200?"
       }
index 9ab783a..5a0dbc3 100644 (file)
@@ -14,6 +14,7 @@ graph {
     attribute {
       name: "value"
       t {
+        dims: 1
         data_type: 1
         raw_data: "\000\000\000\000"
       }
index 878f8a3..288f362 100644 (file)
@@ -1281,27 +1281,27 @@ scalar_type_to_onnx = [
 def zeros(g, sizes, dtype, layout, device):
     # NOTE: no way to set device and layout in ONNX, so we ignore it
     return g.op("ConstantOfShape", sizes,
-                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
 def zeros_like(g, input, dtype, layout, device):
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
 def ones(g, sizes, dtype, layout, device):
     return g.op("ConstantOfShape", sizes,
-                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
 def ones_like(g, input, dtype, layout, device):
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device):
@@ -1312,14 +1312,14 @@ def full(g, sizes, value, dtype, layout, device):
     else:
         dtype = _get_const(dtype, 'i', 'dtype')
         return g.op("ConstantOfShape", sizes,
-                    value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype]))
+                    value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'f', 'i', 'v', 'v')
 def full_like(g, input, fill_value, dtype, layout, device):
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor(fill_value, dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'v', 'v', 'v', 'i')