[mlir][sparse] Add sparse_tensor.sort_coo operator.
authorbixia1 <bixia@google.com>
Mon, 7 Nov 2022 16:18:53 +0000 (08:18 -0800)
committerbixia1 <bixia@google.com>
Mon, 7 Nov 2022 16:23:51 +0000 (08:23 -0800)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137442

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

index a22dcce..52a6aff 100644 (file)
@@ -518,6 +518,45 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
   let hasVerifier = 1;
 }
 
+def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
+    Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
+               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               UnitAttr:$stable)>  {
+  let summary = "Sorts the arrays in xs and ys lexicographically on the "
+                "integral values found in the xs list";
+  let description = [{
+    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
+    `xs` values and some `ys` values are put in the linear buffer `xy`. The
+    optional index attribute `nx` provides the number of `xs` values in `xy`.
+    When `ns` is not explicitly specified, its value is 1. The optional index
+    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
+    explicitly specified, its value is 0. This instruction supports the TACO
+    COO style storage format for better sorting performance.
+
+    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    buffers in `ys` should have a dimension not less than `n`. The behavior of
+    the operator is undefined if this condition is not met.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort_coo %n, %x { nx = 2 : index}
+      : memref<?xindex>
+    ```
+
+    ```mlir
+    sparse_tensor.sort %n, %xy jointly %y1 { nx = 2 : index, ny = 2 : index}
+      : memref<?xi64> jointly memref<?xf32>
+    ```
+  }];
+
+  let assemblyFormat = "(`stable` $stable^)? $n"
+                       "`,`$xy (`jointly` $ys^)? attr-dict"
+                       "`:` type($xy) (`jointly` type($ys)^)?";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Syntax Operations.
 //===----------------------------------------------------------------------===//
index 9d8cf37..693af03 100644 (file)
@@ -719,6 +719,42 @@ LogicalResult SortOp::verify() {
   return success();
 }
 
+LogicalResult SortCooOp::verify() {
+  auto cn = getN().getDefiningOp<arith::ConstantIndexOp>();
+  // We can't check the size of the buffers when n or buffer dimensions aren't
+  // compile-time constants.
+  if (!cn)
+    return success();
+
+  uint64_t n = cn.value();
+  uint64_t nx = 1;
+  if (auto nxAttr = getNxAttr()) {
+    nx = nxAttr.getInt();
+    if (nx < 1)
+      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
+  }
+  uint64_t ny = 0;
+  if (auto nyAttr = getNyAttr()) {
+    ny = nyAttr.getInt();
+  }
+
+  auto checkDim = [&](Value v, uint64_t min, const char *message) {
+    MemRefType tp = v.getType().cast<MemRefType>();
+    int64_t dim = tp.getShape()[0];
+    if (dim != ShapedType::kDynamicSize && dim < min) {
+      emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));
+    }
+  };
+
+  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+
+  for (Value opnd : getYs()) {
+    checkDim(opnd, n, "Expected dimension(y) >= n");
+  }
+
+  return success();
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
index 407f194..02fb97b 100644 (file)
@@ -622,6 +622,32 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a
 
 // -----
 
+func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
+  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort_coo %arg0, %arg1: memref<?xf32>
+  return
+}
+
+// -----
+
+func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
+  %i20 = arith.constant 20 : index
+  // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
+  sparse_tensor.sort_coo %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  return
+}
+
+// -----
+
+func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
+  %i20 = arith.constant 20 : index
+  // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
+  sparse_tensor.sort_coo %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  return
+}
+
+// -----
+
 #CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
 
 func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
index 7f850cc..bc664ae 100644 (file)
@@ -484,3 +484,18 @@ func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<
   sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
+
+// -----
+
+func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort_coo %arg0, %arg1 { nx=2 : index, ny=1 : index}: memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
+  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
+  return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
+}
+