Export ones_like, zeros_like and full_like using ONNX ConstantLike op. (#14903)
authorSpandan Tiwari <sptiwari@microsoft.com>
Sun, 9 Dec 2018 06:46:03 +0000 (22:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 9 Dec 2018 06:49:02 +0000 (22:49 -0800)
Summary:
This PR does the following:
1) Updates the ONNX export for `torch.zeros_like` and `torch.full_like` ops to use ONNX op `ConstantLike`. This reduces reliance on the experimental op `ConstantFill`, which may be removed in the future (see https://github.com/onnx/onnx/pull/1434).
2) It also adds export support for `torch.ones_like`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14903

Differential Revision: D13383700

Pulled By: houseroad

fbshipit-source-id: 566d00a943e9497172fcd5a034b638a650ab13a2

test/onnx/expect/TestOperators.test_full_like.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_ones_like.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_zeros_like.expect [new file with mode: 0644]
test/onnx/test_operators.py
torch/onnx/symbolic.py

diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect
new file mode 100644 (file)
index 0000000..1932f54
--- /dev/null
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "ConstantLike"
+    attribute {
+      name: "dtype"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "value"
+      f: 2
+      type: FLOAT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect
new file mode 100644 (file)
index 0000000..96016f3
--- /dev/null
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "ConstantLike"
+    attribute {
+      name: "dtype"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "value"
+      f: 1
+      type: FLOAT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 6
+          }
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 6
+          }
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect
new file mode 100644 (file)
index 0000000..c21b4e9
--- /dev/null
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "ConstantLike"
+    attribute {
+      name: "dtype"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "value"
+      f: 0
+      type: FLOAT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
index 9a27011..d472e6a 100644 (file)
@@ -292,6 +292,10 @@ class TestOperators(TestCase):
         x = torch.randn(3, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.full(x.shape, 2), x)
 
+    def test_full_like(self):
+        x = torch.randn(3, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.full_like(x, 2), x)
+
     def test_max(self):
         x = torch.randn(3, 4, requires_grad=True)
         y = torch.randn(3, 4, requires_grad=True)
@@ -475,6 +479,13 @@ class TestOperators(TestCase):
         x = torch.randn(3, 4)
         self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)
 
+    def test_zeros_like(self):
+        x = torch.randn(5, 8, requires_grad=True)
+        self.assertONNX(lambda x: torch.zeros_like(x), x)
+
+    def test_ones_like(self):
+        x = torch.randn(6, 10, requires_grad=True)
+        self.assertONNX(lambda x: torch.ones_like(x), x)
 
 if __name__ == '__main__':
     no_onnx_dep_flag = '--no-onnx'
index aace3b6..303a183 100644 (file)
@@ -1023,8 +1023,9 @@ def zeros(g, sizes, dtype, layout, device):
     return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
 
 
-def zeros_like(g, input):
-    return g.op("Sub", input, input).setType(input.type().contiguous())
+@parse_args('v', 'i', 'v', 'v')
+def zeros_like(g, input, dtype, layout, device):
+    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
 
 
 @parse_args('v', 'i', 'v', 'v')
@@ -1032,6 +1033,11 @@ def ones(g, sizes, dtype, layout, device):
     return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
 
 
+@parse_args('v', 'i', 'v', 'v')
+def ones_like(g, input, dtype, layout, device):
+    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
+
+
 def full(g, sizes, value, dtype, layout, device):
     const_value = _maybe_get_const(value, 't')
     if _is_value(const_value):
@@ -1043,9 +1049,9 @@ def full(g, sizes, value, dtype, layout, device):
                     input_as_shape_i=1, value_f=const_value)
 
 
-def full_like(g, input, fill_value):
-    # TODO: a more efficient implementation (ConstantFill?)
-    return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
+@parse_args('v', 'f', 'i', 'v', 'v')
+def full_like(g, input, fill_value, dtype, layout, device):
+    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)
 
 
 @parse_args('v', 'v', 'v', 'v', 'i')