[mlir][sparse] Extend sparse_tensor.sort with a enum attribute to specify a sorting...
authorbixia1 <bixia@google.com>
Fri, 27 Jan 2023 23:16:31 +0000 (15:16 -0800)
committerbixia1 <bixia@google.com>
Mon, 30 Jan 2023 02:34:08 +0000 (18:34 -0800)
Currently, all the non-stable sorting algorithms are implemented via the
straightforward quick sort. This will be fixed in the following PR.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142678

13 files changed:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

index 43c493c..f6fc8fd 100644 (file)
@@ -325,4 +325,39 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
 
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Sorting Algorithm Attribute.
+//===----------------------------------------------------------------------===//
+
+// TODO: Currently, we only provide four implementations, and expose the
+// implementations via attribute algorithm. In the future, if we will need
+// to support both stable and non-stable quick sort, we may add
+// quick_sort_nonstable enum to the attribute. Alternative, we may use two
+// attributes, (stable|nonstable, algorithm), to specify a sorting
+// implementation.
+//
+// --------------------------------------------------------------------------
+// |           | hybrid_qsort| insertion_sort | qsort       | heap_sort.    |
+// |non-stable | Impl        | X              |  Impl       | Impl          |
+// |stable     | X           | Impl           |  Not Impl   | X             |
+// --------------------------------------------------------------------------
+
+// The C++ enum for sparse tensor sort kind.
+def SparseTensorSortKindEnum
+    : I32EnumAttr<"SparseTensorSortKind", "sparse tensor sort algorithm", [
+        I32EnumAttrCase<"HybridQuickSort",    0, "hybrid_quick_sort">,
+        I32EnumAttrCase<"InsertionSortStable", 1, "insertion_sort_stable">,
+        I32EnumAttrCase<"QuickSort", 2, "quick_sort">,
+        I32EnumAttrCase<"HeapSort", 3, "heap_sort">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// Define the enum sparse tensor sort kind attribute.
+def SparseTensorSortKindAttr
+    : EnumAttr<SparseTensor_Dialect, SparseTensorSortKindEnum,
+               "SparseTensorSortAlgorithm"> {
+}
+
 #endif // SPARSETENSOR_ATTRDEFS
index 4690643..521df94 100644 (file)
@@ -581,10 +581,15 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     // TODO: May want to extend tablegen with
     // class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }
     // and then use NonemptyVariadic<...>:$xs here.
+    //
+    // TODO: Currently tablegen doesn't support the assembly syntax when
+    // `algorithm` is an optional enum attribute. We may want to use an optional
+    // enum attribute when this is fixed in tablegen.
+    //
     Arguments<(ins Index:$n,
                Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               UnitAttr:$stable)>  {
+               SparseTensorSortKindAttr:$algorithm)>  {
   string summary = "Sorts the arrays in xs and ys lexicographically on the "
                    "integral values found in the xs list";
   string description = [{
@@ -606,8 +611,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     is undefined if this condition is not met. The operator requires at least
     one buffer in `xs` while `ys` can be empty.
 
-    The `stable` attribute indicates whether a stable sorting algorithm should
-    be used to implement the operator.
+    The enum attribute `algorithm` indicates the sorting algorithm used to
+    implement the operator: hybrid_quick_sort, insertion_sort_stable,
+    quick_sort, or heap_sort.
 
     Note that this operation is "impure" in the sense that its behavior is
     solely defined by side-effects and not SSA values.
@@ -615,17 +621,17 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     Example:
 
     ```mlir
-    sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
+    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
       : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
     ```
 
     ```mlir
-    sparse_tensor.sort stable %n, %x1, %x2 jointly y1, %y2
+    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
+      { alg=1 : index}
       : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
     ```
   }];
-  let assemblyFormat = "(`stable` $stable^)? $n"
-                       "`,`$xs (`jointly` $ys^)? attr-dict"
+  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
                        "`:` type($xs) (`jointly` type($ys)^)?";
   let hasVerifier = 1;
 }
@@ -634,7 +640,7 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
                OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
-               UnitAttr:$stable)>  {
+               SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
@@ -653,17 +659,18 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
       : memref<?xindex>
     ```
 
     ```mlir
-    sparse_tensor.sort %n, %xy jointly %y1 { nx = 2 : index, ny = 2 : index}
+    sparse_tensor.sort hybrid_quick_sort %n, %xy jointly %y1
+      { nx = 2 : index, ny = 2 : index}
       : memref<?xi64> jointly memref<?xf32>
     ```
   }];
 
-  let assemblyFormat = "(`stable` $stable^)? $n"
+  let assemblyFormat = "$algorithm $n"
                        "`,`$xy (`jointly` $ys^)? attr-dict"
                        "`:` type($xy) (`jointly` type($ys)^)?";
   let hasVerifier = 1;
index 3fc760e..90dadf7 100644 (file)
@@ -814,11 +814,13 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
     }
     operands.push_back(v);
   }
+  bool isStable =
+      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
   auto insertPoint = op->template getParentOfType<func::FuncOp>();
-  SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
-                                          : kSortNonstableFuncNamePrefix);
+  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
+                                    : kSortNonstableFuncNamePrefix);
   FuncGeneratorType funcGenerator =
-      op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+      isStable ? createSortStableFunc : createSortNonstableFunc;
   FlatSymbolRefAttr func =
       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                                ny, isCoo, operands, funcGenerator);
index 2ce29e5..f96aeee 100644 (file)
@@ -819,7 +819,8 @@ public:
     // in the "added" array prior to applying the compression.
     unsigned rank = dstType.getShape().size();
     if (isOrderedDim(dstType, rank - 1))
-      rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
+      rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
+                              SparseTensorSortKind::HybridQuickSort);
     // While performing the insertions, we also need to reset the elements
     // of the values/filled-switch by only iterating over the set elements,
     // to ensure that the runtime complexity remains proportional to the
index 22ec479..1e2a7b0 100644 (file)
@@ -866,9 +866,9 @@ private:
             get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
                             /*withLayout=*/false);
         Value xs = rewriter.create<ToIndicesBufferOp>(loc, indTp, src);
-        rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y},
-                                   rewriter.getIndexAttr(rank),
-                                   rewriter.getIndexAttr(0));
+        rewriter.create<SortCooOp>(
+            loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(rank),
+            rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
       } else {
         // Gather the indices-arrays in the dst tensor storage order.
         SmallVector<Value> xs(rank, Value());
@@ -877,7 +877,8 @@ private:
           xs[toStoredDim(encDst, orgDim)] =
               genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
         }
-        rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+        rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
+                                SparseTensorSortKind::HybridQuickSort);
       }
     }
 
index 7e10ae1..b9d56f4 100644 (file)
@@ -55,7 +55,7 @@ func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (
 //       CHECK: return %[[M]], %[[S2]] : memref<?xf64>, index
 func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> (memref<?xf64>, index) {
   %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref<?xf64>, f64, index
-  return %0#0, %0#1 : memref<?xf64>, index  
+  return %0#0, %0#1 : memref<?xf64>, index
 }
 
 // -----
@@ -155,7 +155,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK:         }
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
    -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
 }
 
@@ -170,7 +170,7 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
 // CHECK-DAG:     func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d
 func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
 
@@ -184,7 +184,7 @@ func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?
 // CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_stable
 func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
 
@@ -199,7 +199,7 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
 // CHECK-DAG:     func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
@@ -213,7 +213,7 @@ func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref
 // CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
index 61c4324..1e06e65 100644 (file)
@@ -430,7 +430,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-//       CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
 //       CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 //       CHECK:   %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 //       CHECK:   %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -478,7 +478,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A12:.*]] = arith.constant 1 : index
 //       CHECK:     %[[A13:.*]] = arith.constant 0 : index
-//       CHECK:     sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+//       CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
 //       CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 //       CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
index 6bdd5cd..2646b2d 100644 (file)
@@ -195,7 +195,7 @@ func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
 //       CHECK-RWT:     %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 //       CHECK-RWT:     %[[V:.*]] = sparse_tensor.values %[[COO]]
 //       CHECK-RWT:     %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
-//       CHECK-RWT:     sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
+//       CHECK-RWT:     sparse_tensor.sort_coo hybrid_quick_sort %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
 //       CHECK-RWT:     %[[T3:.*]] = bufferization.alloc_tensor()
 //       CHECK-RWT:     %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
 //       CHECK-RWT:     ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
index 531a987..feb45e1 100644 (file)
@@ -696,7 +696,7 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 
 func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
 }
 
 // -----
@@ -704,7 +704,7 @@ func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
 func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
-  sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+  sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
   return
 }
 
@@ -712,7 +712,7 @@ func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
 
 func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
   // expected-error@+1 {{mismatch xs element types}}
-  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
   return
 }
 
@@ -720,7 +720,7 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a
 
 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo %arg0, %arg1: memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
   return
 }
 
@@ -729,7 +729,7 @@ func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
-  sparse_tensor.sort_coo %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
   return
 }
 
@@ -738,7 +738,7 @@ func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }
 
index 58375d6..1f48953 100644 (file)
@@ -504,10 +504,10 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 // CHECK-LABEL: func @sparse_sort_1d0v(
 //  CHECK-SAME: %[[A:.*]]: index,
 //  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-//       CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
 //       CHECK: return %[[B]]
 func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
@@ -518,10 +518,10 @@ func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xi
 //  CHECK-SAME: %[[B:.*]]: memref<20xindex>,
 //  CHECK-SAME: %[[C:.*]]: memref<10xindex>,
 //  CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-//       CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
 //       CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
-  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
   return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
 }
 
@@ -532,10 +532,10 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref
 //  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
 //  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
 //  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-//       CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
 //       CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
 
@@ -546,23 +546,34 @@ func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20
 //  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
 //  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
 //  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-//       CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+//       CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
 //       CHECK: return %[[B]], %[[C]], %[[D]]
 func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
 
 // -----
 
+// CHECK-LABEL: func @sparse_sort_coo(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+//       CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort_coo %arg0, %arg1 { nx=2 : index, ny=1 : index}: memref<?xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
 // -----
 
+// CHECK-LABEL: func @sparse_sort_coo_stable(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xi64>,
+//  CHECK-SAME: %[[C:.*]]: memref<?xf32>)
+//       CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+//       CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
-  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
index 90247c3..70a5fa1 100644 (file)
 // CHECK:               } {"Emitted from" = "linalg.generic"}
 // CHECK:               scf.yield %[[VAL_70:.*]] : index
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
+// CHECK:             sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
 // CHECK:             %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:               %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
 // CHECK:               %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
index 3dc9d53..1958548 100644 (file)
@@ -50,22 +50,22 @@ module {
 
     // Sort 0 elements.
     // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort %i0, %x0 : memref<?xi32>
+    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort stable %i0, %x0 : memref<?xi32>
+    sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
     // CHECK: [0,  2,  5, 10,  1]
-    sparse_tensor.sort %i4, %x0 : memref<?xi32>
+    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [0,  2,  5,  10,  1]
     call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort stable %i4, %x0 : memref<?xi32>
+    sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Prepare more buffers of different dimensions.
@@ -89,7 +89,7 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
       : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
@@ -108,7 +108,7 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort stable %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
       : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
index 7c27a07..b0ff0cf 100644 (file)
@@ -92,7 +92,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v : vector<5xi32>
@@ -120,7 +120,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v2 : vector<5xi32>