Replace use of ConstantLike with with ConstantOfShape (#16095)
authorSpandan Tiwari <sptiwari@microsoft.com>
Sun, 20 Jan 2019 03:50:20 +0000 (19:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 20 Jan 2019 03:52:54 +0000 (19:52 -0800)
Summary:
Submitting this PR as an update to existing PR (https://github.com/pytorch/pytorch/pull/15938) on houseroad 's request.

This PR replaces the use of ONNX op `ConstantLike` with `ConstantOfShape` in the ONNX exporter. In addition to removing the call sites in `symbolic.py`, it also replace the call site in `peephole.cpp`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16095

Differential Revision: D13745723

Pulled By: houseroad

fbshipit-source-id: e2a5f534f01adf199df9e27544f7afcfa540e1f0

aten/src/ATen/core/interned_strings.h
caffe2/onnx/backend.cc
test/onnx/expect/TestOperators.test_full_like.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_ones_like.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_zeros_like.expect [new file with mode: 0644]
torch/csrc/jit/passes/onnx/peephole.cpp
torch/onnx/symbolic.py

index 049f5ad..393a1d3 100644 (file)
@@ -118,6 +118,7 @@ namespace c10 {
   _(onnx, Not)                     \
   _(onnx, ATen)                    \
   _(onnx, Split)                   \
+  _(onnx, ConstantOfShape)         \
   FORALL_ATTR_BASE_SYMBOLS(_)      \
   _(attr, Subgraph)                \
   _(attr, ReverseSubgraph)         \
index abc4b68..8bcef71 100644 (file)
@@ -470,7 +470,16 @@ Caffe2Ops Caffe2Backend::CreateConstantOfShape(
   Caffe2Ops ret;
   auto* c2_op = ret.ops.Add();
   const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
-  BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
+  if (value) {
+    BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
+  } else {
+    c2_op->set_type("ConstantFill");
+    c2_op->add_input(onnx_node->node.input(0));
+    c2_op->add_output(onnx_node->node.output(0));
+    auto c2_input_as_shape = c2_op->add_arg();
+    c2_input_as_shape->set_name("input_as_shape");
+    c2_input_as_shape->set_i(1);
+  }
 
   return ret;
 }
diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect
new file mode 100644 (file)
index 0000000..f92b5ee
--- /dev/null
@@ -0,0 +1,59 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Shape"
+  }
+  node {
+    input: "1"
+    output: "2"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        data_type: 1
+        raw_data: "\000\000\000@"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect
new file mode 100644 (file)
index 0000000..6d01e40
--- /dev/null
@@ -0,0 +1,59 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Shape"
+  }
+  node {
+    input: "1"
+    output: "2"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        data_type: 1
+        raw_data: "\000\000\200?"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 6
+          }
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 6
+          }
+          dim {
+            dim_value: 10
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect
new file mode 100644 (file)
index 0000000..5cf869d
--- /dev/null
@@ -0,0 +1,59 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Shape"
+  }
+  node {
+    input: "1"
+    output: "2"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        data_type: 1
+        raw_data: "\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 5
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
index 749626e..c0e9527 100644 (file)
@@ -411,12 +411,11 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
   concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
   concated_dims->addInput(hidden_size->outputs()[0]);
 
-  Node* constant_fill = graph->create(onnx::ConstantFill, 1);
-  constant_fill->insertBefore(n);
-  constant_fill->i_(attr::input_as_shape, 1);
-  constant_fill->addInput(concated_dims->outputs()[0]);
+  Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
+  constant_of_shape->insertBefore(n);
+  constant_of_shape->addInput(concated_dims->outputs()[0]);
+  n->replaceInput(input_index, constant_of_shape->outputs()[0]);
 
-  n->replaceInput(input_index, constant_fill->outputs()[0]);
   if (initial_state->uses().size() == 0) {
     initial_state->node()->destroy();
   }
index 3007882..e8d3d2c 100644 (file)
@@ -1127,7 +1127,9 @@ def zeros(g, sizes, dtype, layout, device):
 
 @parse_args('v', 'i', 'v', 'v')
 def zeros_like(g, input, dtype, layout, device):
-    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
+    shape = g.op("Shape", input)
+    return g.op("ConstantOfShape", shape,
+                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
@@ -1137,7 +1139,9 @@ def ones(g, sizes, dtype, layout, device):
 
 @parse_args('v', 'i', 'v', 'v')
 def ones_like(g, input, dtype, layout, device):
-    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
+    shape = g.op("Shape", input)
+    return g.op("ConstantOfShape", shape,
+                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device):
@@ -1153,7 +1157,9 @@ def full(g, sizes, value, dtype, layout, device):
 
 @parse_args('v', 'f', 'i', 'v', 'v')
 def full_like(g, input, fill_value, dtype, layout, device):
-    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)
+    shape = g.op("Shape", input)
+    return g.op("ConstantOfShape", shape,
+                value_t=torch.tensor(fill_value, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'v', 'v', 'v', 'i')