Add support for exporting onnx split (#15092)
author BowenBao <semisqg@gmail.com>
Tue, 8 Jan 2019 00:06:34 +0000 (16:06 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 00:09:24 +0000 (16:09 -0800)
Summary:
* Updating split to output a dynamic tensor list broke the ONNX export: the
 split IR now lowers to two ops, 1. Dynamic[] <= Split(), and 2. out1, out2, out3
 <= prim::ListUnpack. This fix fuses these two consecutive ops into a single
 multi-output onnx::Split when exporting to ONNX.
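
For context, a minimal repro of the export path this fixes (module name and shapes
are illustrative, not part of the commit):

    import io
    import torch

    class Model(torch.nn.Module):
        def forward(self, x):
            # Traces to onnx::Split followed by prim::ListUnpack; the new
            # peephole pass fuses the pair into one multi-output onnx::Split.
            return torch.split(x, 2, dim=1)

    buf = io.BytesIO()
    torch.onnx.export(Model(), torch.randn(2, 6), buf)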
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15092

Reviewed By: dzhulgakov

Differential Revision: D13583832

Pulled By: houseroad

fbshipit-source-id: 3eb18c871e750921ad6d5cc179254bee9bcf4c99

aten/src/ATen/core/interned_strings.h
test/onnx/expect/TestOperators.test_split.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_split_with_sizes.expect [new file with mode: 0644]
test/onnx/test_operators.py
torch/csrc/jit/passes/onnx/peephole.cpp
torch/onnx/symbolic.py
torch/onnx/utils.py

diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 286653a..40f9b39 100644
@@ -115,6 +115,7 @@ namespace c10 {
   _(onnx, Less)                    \
   _(onnx, Not)                     \
   _(onnx, ATen)                    \
+  _(onnx, Split)                   \
   FORALL_ATTR_BASE_SYMBOLS(_)      \
   _(attr, Subgraph)                \
   _(attr, ReverseSubgraph)         \
@@ -140,7 +141,8 @@ namespace c10 {
   _(attr, a)                       \
   _(attr, b)                       \
   _(attr, beg)                     \
-  _(attr, idx)
+  _(attr, idx)                     \
+  _(attr, split)
 #else
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim)              \
diff --git a/test/onnx/expect/TestOperators.test_split.expect b/test/onnx/expect/TestOperators.test_split.expect
new file mode 100644
index 0000000..e9dbdd4
--- /dev/null
@@ -0,0 +1,92 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "tensor"
+    output: "1"
+    output: "2"
+    output: "3"
+    op_type: "Split"
+    attribute {
+      name: "axis"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "split"
+      ints: 2
+      ints: 2
+      ints: 2
+      type: INTS
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "tensor"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 6
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_split_with_sizes.expect b/test/onnx/expect/TestOperators.test_split_with_sizes.expect
new file mode 100644
index 0000000..fd72d1a
--- /dev/null
@@ -0,0 +1,92 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "tensor"
+    output: "1"
+    output: "2"
+    output: "3"
+    op_type: "Split"
+    attribute {
+      name: "axis"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "split"
+      ints: 2
+      ints: 1
+      ints: 3
+      type: INTS
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "tensor"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 6
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
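
Note: these expect files pin the exported graph to a single Split node with
axis=1 and the expected split sizes. If the exporter output changes, they can
presumably be regenerated by running test/onnx/test_operators.py with the
--accept flag (the usual PyTorch expect-test workflow; not verified here).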
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 5d5db46..8484f9e 100644
@@ -179,6 +179,14 @@ class TestOperators(TestCase):
         x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
         self.assertONNX(lambda x: x.chunk(2), x)
 
+    def test_split(self):
+        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
+        self.assertONNX(lambda x: torch.split(x, 2, 1), x)
+
+    def test_split_with_sizes(self):
+        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
+        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)
+
     def test_concat2(self):
         x = torch.randn(2, 3)
         y = torch.randn(2, 3)
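
A quick sketch of the shapes these two tests pin down (plain PyTorch, no export
involved):

    import torch

    x = torch.randn(2, 6)

    # Equal chunks of width 2 along dim 1 -> three (2, 2) tensors
    # (matches TestOperators.test_split.expect).
    assert [t.shape for t in torch.split(x, 2, 1)] == [torch.Size([2, 2])] * 3

    # Explicit sizes [2, 1, 3] -> (2, 2), (2, 1), (2, 3)
    # (matches TestOperators.test_split_with_sizes.expect).
    assert [t.shape for t in torch.split(x, [2, 1, 3], 1)] == \
        [torch.Size([2, 2]), torch.Size([2, 1]), torch.Size([2, 3])]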
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 6987b93..749626e 100644
@@ -560,6 +560,30 @@ static void eraseListConstruct(Block* block) {
   }
 }
 
+static void fuseSplitListUnpack(Block *b) {
+  for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    for (auto *child_block : it->blocks()) {
+      fuseSplitListUnpack(child_block);
+    }
+    if (it->kind() == prim::ListUnpack && it->input()->node()->kind() == onnx::Split) {
+      auto origSplitNode = it->input()->node();
+
+      Node * splitNode = b->owningGraph()->create(onnx::Split, it->outputs().size());
+      for (size_t i=0; i<splitNode->outputs().size(); ++i) {
+        splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
+      }
+      splitNode->copyAttributes(*origSplitNode);
+      splitNode->insertBefore(origSplitNode);
+      splitNode->addInput(origSplitNode->input());
+      it->replaceAllUsesWith(splitNode);
+      it->removeAllInputs();
+      origSplitNode->destroy();
+      it.destroyCurrent();
+      continue;
+    }
+  }
+}
+
 // This optimization does ONNX-specific peephole optimizations.
 //
 // At the moment, here are the optimizations it does:
@@ -592,6 +616,7 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   fuseTransposeIntoGemm(graph->block());
   speculateOps(graph->block());
   eraseListConstruct(graph->block());
+  fuseSplitListUnpack(graph->block());
 }
 
 } // namespace jit
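
The effect of fuseSplitListUnpack can be checked end to end; a sketch assuming
the onnx Python package is installed (class name is illustrative):

    import io
    import torch
    import onnx

    class Split(torch.nn.Module):
        def forward(self, x):
            return torch.split(x, 2, 1)

    buf = io.BytesIO()
    torch.onnx.export(Split(), torch.randn(2, 6), buf)

    model = onnx.ModelProto()
    model.ParseFromString(buf.getvalue())
    node = model.graph.node[0]
    # After fusion the graph holds a single multi-output Split, no ListUnpack.
    assert node.op_type == "Split" and len(node.output) == 3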
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 6bae829..bbdb761 100644
@@ -437,6 +437,21 @@ def prim_ConstantChunk(g, self, chunks, dim):
     return prim_ConstantSplit(g, self, split_size, dim)
 
 
+@parse_args('v', 'i', 'i')
+def split(g, self, split_size, dim):
+    size = self.type().sizes()[dim]
+    splits = [split_size] * (size // split_size)
+    leftover = size % split_size
+    if leftover:
+        splits.append(leftover)
+    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1)
+
+
+@parse_args('v', 'is', 'i')
+def split_with_sizes(g, self, split_sizes, dim):
+    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1)
+
+
 @parse_args('v', 'i', 'v')
 def select(g, self, dim, index):
     if dim > 1:
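
For clarity, the arithmetic in the new split symbolic, with illustrative values:

    # size=6, split_size=2 -> splits=[2, 2, 2]; size=7 would give [2, 2, 2, 1]
    size, split_size = 6, 2
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)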
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 0242ffa..a2e0218 100644
@@ -515,11 +515,11 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
                 else:
                     raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                         n.kindOf("value")))
-            elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct":
+            elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct" or op_name == "ListUnpack":
                 # Undefined/None is not an ONNX operator; keep it as prim::Undefined/
                 # prim::None and let the exporter handle finally eliminating these
 
-                # For ListConstruct, it will be erased in the ONNX peephole pass
+                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                 return None
             elif op_name == 'Loop' or op_name == 'If':
                 new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())