_(onnx, Less) \
_(onnx, Not) \
_(onnx, ATen) \
+ _(onnx, Split) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
_(attr, a) \
_(attr, b) \
_(attr, beg) \
- _(attr, idx)
+ _(attr, idx) \
+ _(attr, split)
#else
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
--- /dev/null
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "tensor"
+ output: "1"
+ output: "2"
+ output: "3"
+ op_type: "Split"
+ attribute {
+ name: "axis"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "split"
+ ints: 2
+ ints: 2
+ ints: 2
+ type: INTS
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "tensor"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 6
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "3"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
--- /dev/null
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "tensor"
+ output: "1"
+ output: "2"
+ output: "3"
+ op_type: "Split"
+ attribute {
+ name: "axis"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "split"
+ ints: 2
+ ints: 1
+ ints: 3
+ type: INTS
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "tensor"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 6
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 1
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "3"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
self.assertONNX(lambda x: x.chunk(2), x)
+ def test_split(self):
+ # Export torch.split with an integer split size (2) along dim 1 on a
+ # (2, 6) tensor; presumably assertONNX compares the exported graph
+ # against a recorded Split expect file — confirm harness behavior.
+ x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
+ self.assertONNX(lambda x: torch.split(x, 2, 1), x)
+
+ def test_split_with_sizes(self):
+ # Export torch.split with an explicit per-chunk size list [2, 1, 3]
+ # along dim 1 on a (2, 6) tensor; sizes sum to the dim extent, so the
+ # exporter should emit one Split node with split = [2, 1, 3].
+ x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
+ self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)
+
def test_concat2(self):
x = torch.randn(2, 3)
y = torch.randn(2, 3)
}
}
+// Recursively rewrites every (onnx::Split -> prim::ListUnpack) pair into a
+// single onnx::Split node whose output count equals the number of values the
+// ListUnpack produced. The symbolic for torch.split emits Split with a single
+// (list-typed) output plus a ListUnpack; ONNX has no list type, so this pass
+// collapses the pair into the multi-output Split form ONNX expects.
+static void fuseSplitListUnpack(Block *b) {
+ for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+ // Recurse into nested blocks (e.g. Loop/If bodies) first.
+ for (auto *child_block : it->blocks()) {
+ fuseSplitListUnpack(child_block);
+ }
+ if (it->kind() == prim::ListUnpack && it->input()->node()->kind() == onnx::Split) {
+ auto origSplitNode = it->input()->node();
+
+ // New Split with one output per unpacked value, carrying over the
+ // metadata (type info) of each ListUnpack output.
+ Node * splitNode = b->owningGraph()->create(onnx::Split, it->outputs().size());
+ for (size_t i=0; i<splitNode->outputs().size(); ++i) {
+ splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
+ }
+ // Preserve the original axis/split attributes and input tensor.
+ splitNode->copyAttributes(*origSplitNode);
+ splitNode->insertBefore(origSplitNode);
+ splitNode->addInput(origSplitNode->input());
+ it->replaceAllUsesWith(splitNode);
+ it->removeAllInputs();
+ // NOTE(review): this assumes the original Split's output feeds only
+ // this ListUnpack — if it had other uses, destroy() would invalidate
+ // them; confirm the symbolic guarantees a 1:1 pairing.
+ origSplitNode->destroy();
+ // destroyCurrent() removes the ListUnpack while keeping the iterator
+ // valid for the ++it in the loop header.
+ it.destroyCurrent();
+ continue;
+ }
+ }
+}
+
// This optimization does ONNX-specific peephole optimizations.
//
// At the moment, here are the optimizations it does:
fuseTransposeIntoGemm(graph->block());
speculateOps(graph->block());
eraseListConstruct(graph->block());
+ fuseSplitListUnpack(graph->block());
}
} // namespace jit
return prim_ConstantSplit(g, self, split_size, dim)
+@parse_args('v', 'i', 'i')
+def split(g, self, split_size, dim):
+ # Symbolic for torch.split(tensor, int, dim): expand the scalar chunk size
+ # into an explicit per-chunk size list, since ONNX Split takes the full
+ # list via its `split` attribute.
+ size = self.type().sizes()[dim]
+ splits = [split_size] * (size // split_size)
+ # A non-divisible dim leaves a smaller trailing chunk, matching
+ # torch.split semantics.
+ leftover = size % split_size
+ if leftover:
+ splits.append(leftover)
+ # outputs=1: emit a single (list-valued) output here; the downstream
+ # prim::ListUnpack is fused into a multi-output Split by the ONNX
+ # peephole pass (fuseSplitListUnpack).
+ return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1)
+
+
+@parse_args('v', 'is', 'i')
+def split_with_sizes(g, self, split_sizes, dim):
+ # Symbolic for torch.split(tensor, list_of_ints, dim): the caller already
+ # provides explicit chunk sizes, so forward them directly to ONNX Split.
+ # outputs=1 for the same reason as split(): the ListUnpack fusion pass
+ # converts this to a multi-output node.
+ return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1)
+
+
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
if dim > 1:
else:
raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
n.kindOf("value")))
- elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct":
+ elif op_name == "Undefined" or op_name == "None" or op_name == "ListConstruct" or op_name == "ListUnpack":
# Undefined/None is not an ONNX operator; keep it as prim::Undefined/
# prim::None and let the exporter handle finally eliminating these
- # For ListConstruct, it will be erased in the ONNX peephole pass
+ # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
return None
elif op_name == 'Loop' or op_name == 'If':
new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())