From: BowenBao Date: Tue, 8 Jan 2019 00:06:34 +0000 (-0800) Subject: Add support for exporting onnx split (#15092) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1988 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=24867a58aae1c532da8c50d1c18cbfe5d3894f93;p=platform%2Fupstream%2Fpytorch.git Add support for exporting onnx split (#15092) Summary: * With the update of split output to dynamic list it breaks the export to onnx. Now split ir becomes two ops: 1. Dynamic[] <= Split(), and 2. out1, out2, out3 <= Prim::ListUnpack. In this fix these two consecutive ops get fused when being exported to onnx. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15092 Reviewed By: dzhulgakov Differential Revision: D13583832 Pulled By: houseroad fbshipit-source-id: 3eb18c871e750921ad6d5cc179254bee9bcf4c99 --- diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 286653a..40f9b39 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -115,6 +115,7 @@ namespace c10 { _(onnx, Less) \ _(onnx, Not) \ _(onnx, ATen) \ + _(onnx, Split) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ _(attr, ReverseSubgraph) \ @@ -140,7 +141,8 @@ namespace c10 { _(attr, a) \ _(attr, b) \ _(attr, beg) \ - _(attr, idx) + _(attr, idx) \ + _(attr, split) #else #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ diff --git a/test/onnx/expect/TestOperators.test_split.expect b/test/onnx/expect/TestOperators.test_split.expect new file mode 100644 index 0000000..e9dbdd4 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_split.expect @@ -0,0 +1,92 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "tensor" + output: "1" + output: "2" + output: "3" + op_type: "Split" + attribute { + name: "axis" + i: 1 + type: INT + } + attribute { + name: "split" + ints: 2 + ints: 2 + ints: 2 + type: INTS + } + } + name: "torch-jit-export" + input { + name: "tensor" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 6 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_split_with_sizes.expect b/test/onnx/expect/TestOperators.test_split_with_sizes.expect new file mode 100644 index 0000000..fd72d1a --- /dev/null +++ b/test/onnx/expect/TestOperators.test_split_with_sizes.expect @@ -0,0 +1,92 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "tensor" + output: "1" + output: "2" + output: "3" + op_type: "Split" + attribute { + name: "axis" + i: 1 + type: INT + } + attribute { + name: "split" + ints: 2 + ints: 1 + ints: 3 + type: INTS + } + } + name: "torch-jit-export" + input { + name: "tensor" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 6 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 5d5db46..8484f9e 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -179,6 +179,14 @@ class TestOperators(TestCase): x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True) self.assertONNX(lambda x: x.chunk(2), x) + def test_split(self): + x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]) + self.assertONNX(lambda x: torch.split(x, 2, 1), x) + + def test_split_with_sizes(self): + x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]]) + self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x) + def test_concat2(self): x = torch.randn(2, 3) y = torch.randn(2, 3) diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 6987b93..749626e 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -560,6 +560,30 @@ static void eraseListConstruct(Block* block) { } } +static void fuseSplitListUnpack(Block *b) { + for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { + for (auto *child_block : it->blocks()) { + fuseSplitListUnpack(child_block); + } + if (it->kind() == prim::ListUnpack && it->input()->node()->kind() == onnx::Split) { + auto origSplitNode = it->input()->node(); + + Node * splitNode = b->owningGraph()->create(onnx::Split, it->outputs().size()); + for (size_t i=0; ioutputs().size(); ++i) { + splitNode->outputs()[i]->copyMetadata(it->outputs()[i]); + } + splitNode->copyAttributes(*origSplitNode); + splitNode->insertBefore(origSplitNode); + splitNode->addInput(origSplitNode->input()); + it->replaceAllUsesWith(splitNode); + it->removeAllInputs(); + origSplitNode->destroy(); + it.destroyCurrent(); + continue; + } + } +} + // This optimization does ONNX-specific peephole optimizations. // // At the moment, here are the optimizations it does: @@ -592,6 +616,7 @@ void PeepholeOptimizeONNX(std::shared_ptr& graph) { fuseTransposeIntoGemm(graph->block()); speculateOps(graph->block()); eraseListConstruct(graph->block()); + fuseSplitListUnpack(graph->block()); } } // namespace jit diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 6bae829..bbdb761 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -437,6 +437,21 @@ def prim_ConstantChunk(g, self, chunks, dim): return prim_ConstantSplit(g, self, split_size, dim) +@parse_args('v', 'i', 'i') +def split(g, self, split_size, dim): + size = self.type().sizes()[dim] + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1) + + +@parse_args('v', 'is', 'i') +def split_with_sizes(g, self, split_sizes, dim): + return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1) + + @parse_args('v', 'i', 'v') def select(g, self, dim, index): if dim > 1: diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0242ffa..a2e0218 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -515,11 +515,11 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor else: raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format( n.kindOf("value"))) - elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct": + elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct" or op_name == "ListUnpack": # Undefined/None is not an ONNX operator; keep it as prim::Undefined/ # prim::None and let the exporter handle finally eliminating these - # For ListConstruct, it will be erased in the ONNX peephole pass + # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass return None elif op_name == 'Loop' or op_name == 'If': new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())