From cf24d49dc81b06e8efff15bd77f332840180867c Mon Sep 17 00:00:00 2001 From: bixia1 Date: Mon, 7 Nov 2022 08:18:53 -0800 Subject: [PATCH] [mlir][sparse] Add sparse_tensor.sort_coo operator. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137442 --- .../Dialect/SparseTensor/IR/SparseTensorOps.td | 39 ++++++++++++++++++++++ .../SparseTensor/IR/SparseTensorDialect.cpp | 36 ++++++++++++++++++++ mlir/test/Dialect/SparseTensor/invalid.mlir | 26 +++++++++++++++ mlir/test/Dialect/SparseTensor/roundtrip.mlir | 15 +++++++++ 4 files changed, 116 insertions(+) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a22dcce..52a6aff 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -518,6 +518,45 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, let hasVerifier = 1; } +def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">, + Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy, + Variadic>:$ys, + OptionalAttr:$nx, OptionalAttr:$ny, + UnitAttr:$stable)> { + let summary = "Sorts the arrays in xs and ys lexicographically on the " + "integral values found in the xs list"; + let description = [{ + Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the + `xs` values and some `ys` values are put in the linear buffer `xy`. The + optional index attribute `nx` provides the number of `xs` values in `xy`. + When `ns` is not explicitly specified, its value is 1. The optional index + attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not + explicitly specified, its value is 0. This instruction supports the TACO + COO style storage format for better sorting performance. + + The buffer xy should have a dimension not less than n * (nx + ny) while the + buffers in `ys` should have a dimension not less than `n`. The behavior of + the operator is undefined if this condition is not met. + + Example: + + ```mlir + sparse_tensor.sort_coo %n, %x { nx = 2 : index} + : memref + ``` + + ```mlir + sparse_tensor.sort %n, %xy jointly %y1 { nx = 2 : index, ny = 2 : index} + : memref jointly memref + ``` + }]; + + let assemblyFormat = "(`stable` $stable^)? $n" + "`,`$xy (`jointly` $ys^)? attr-dict" + "`:` type($xy) (`jointly` type($ys)^)?"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Sparse Tensor Syntax Operations. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 9d8cf37..693af03 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -719,6 +719,42 @@ LogicalResult SortOp::verify() { return success(); } +LogicalResult SortCooOp::verify() { + auto cn = getN().getDefiningOp(); + // We can't check the size of the buffers when n or buffer dimensions aren't + // compile-time constants. + if (!cn) + return success(); + + uint64_t n = cn.value(); + uint64_t nx = 1; + if (auto nxAttr = getNxAttr()) { + nx = nxAttr.getInt(); + if (nx < 1) + emitError(llvm::formatv("Expected nx > 1, got {0}", nx)); + } + uint64_t ny = 0; + if (auto nyAttr = getNyAttr()) { + ny = nyAttr.getInt(); + } + + auto checkDim = [&](Value v, uint64_t min, const char *message) { + MemRefType tp = v.getType().cast(); + int64_t dim = tp.getShape()[0]; + if (dim != ShapedType::kDynamicSize && dim < min) { + emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min)); + } + }; + + checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)"); + + for (Value opnd : getYs()) { + checkDim(opnd, n, "Expected dimension(y) >= n"); + } + + return success(); +} + LogicalResult YieldOp::verify() { // Check for compatible parent. auto *parentOp = (*this)->getParentOp(); diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 407f194..02fb97bc 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -622,6 +622,32 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a // ----- +func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref) { + // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}} + sparse_tensor.sort_coo %arg0, %arg1: memref + return +} + +// ----- + +func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) { + %i20 = arith.constant 20 : index + // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}} + sparse_tensor.sort_coo %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex> + return +} + +// ----- + +func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) { + %i20 = arith.constant 20 : index + // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}} + sparse_tensor.sort_coo %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32> + return +} + +// ----- + #CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> { diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index 7f850cc..bc664ae 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -484,3 +484,18 @@ func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref< sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64> return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64> } + +// ----- + +func.func @sparse_sort_coo(%arg0: index, %arg1: memref) -> (memref) { + sparse_tensor.sort_coo %arg0, %arg1 { nx=2 : index, ny=1 : index}: memref + return %arg1 : memref +} + +// ----- + +func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref, %arg2: memref) -> (memref, memref) { + sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref jointly memref + return %arg1, %arg2 : memref, memref +} + -- 2.7.4