From a1507668807e6108c12ffecf3740cb339b15018d Mon Sep 17 00:00:00 2001
From: bixia1
Date: Wed, 8 Feb 2023 09:02:13 -0800
Subject: [PATCH] [mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143227
---
 .../Dialect/SparseTensor/Transforms/CodegenUtils.h |   5 +
 .../Transforms/SparseBufferRewriting.cpp           | 182 ++++++++++++++++-----
 .../Dialect/SparseTensor/buffer_rewriting.mlir     |  67 ++++++--
 .../SparseTensor/CPU/sparse_rewrite_sort.mlir      |  16 +-
 .../SparseTensor/CPU/sparse_rewrite_sort_coo.mlir  |   2 +-
 5 files changed, 216 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b07991e..12cfd3b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -275,6 +275,11 @@ inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
   return builder.create<arith::ConstantIndexOp>(loc, i);
 }
 
+/// Generates a constant of `i64` type.
+inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
+  return builder.create<arith::ConstantIntOp>(loc, i, 64);
+}
+
 /// Generates a constant of `i32` type.
 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
   return builder.create<arith::ConstantIntOp>(loc, i, 32);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 90ca39f..3e61570 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
 static constexpr const char kBinarySearchFuncNamePrefix[] =
     "_sparse_binary_search_";
-static constexpr const char kSortNonstableFuncNamePrefix[] =
-    "_sparse_sort_nonstable_";
+static constexpr const char kHybridQuickSortFuncNamePrefix[] =
+    "_sparse_hybrid_qsort_";
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
+static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
 using FuncGeneratorType = function_ref<void(
     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
-/// Creates a function to perform quick sort on the value in the range of
-/// index [lo, hi).
-//
-// The generate IR corresponds to this C like algorithm:
-// void quickSort(lo, hi, data) {
-//   if (lo < hi) {
-//     p = partition(low, high, data);
-//     quickSort(lo, p, data);
-//     quickSort(p + 1, hi, data);
-//   }
-// }
-static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
-                                    func::FuncOp func, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP) {
-  (void)nTrailingP;
-  OpBuilder::InsertionGuard insertionGuard(builder);
-  Block *entryBlock = func.addEntryBlock();
-  builder.setInsertionPointToStart(entryBlock);
-
+static void createQuickSort(OpBuilder &builder, ModuleOp module,
+                            func::FuncOp func, ValueRange args, uint64_t nx,
+                            uint64_t ny, bool isCoo, uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
   Value lo = args[loIdx];
   Value hi = args[hiIdx];
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
-
-  // The if-stmt true branch.
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
-      ny, isCoo, args, createPartitionFunc);
-  auto p = builder.create<func::CallOp>(
-      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+      ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+  auto p = builder.create<func::CallOp>(loc, partitionFunc,
+                                        TypeRange{IndexType::get(context)},
+                                        args.drop_back(nTrailingP));
 
   SmallVector<Value> lowOperands{lo, p.getResult(0)};
   lowOperands.append(args.begin() + xStartIdx, args.end());
   builder.create<func::CallOp>(loc, func, lowOperands);
 
   SmallVector<Value> highOperands{
       builder.create<arith::AddIOp>(loc, p.getResult(0),
                                     constantIndex(builder, loc, 1)),
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
       hi};
   highOperands.append(args.begin() + xStartIdx, args.end());
   builder.create<func::CallOp>(loc, func, highOperands);
-
-  // After the if-stmt.
-  builder.setInsertionPointAfter(ifOp);
-  builder.create<func::ReturnOp>(loc);
 }
 
 /// Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
+/// Creates a function to perform quick sort or a hybrid quick sort on the
+/// values in the range of index [lo, hi).
+//
+//
+// When nTrailingP == 0, the generated IR corresponds to this C-like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo + 1 < hi) {
+//     p = partition(low, high, data);
+//     quickSort(lo, p, data);
+//     quickSort(p + 1, hi, data);
+//   }
+// }
+//
+// When nTrailingP == 1, the generated IR corresponds to this C-like algorithm:
+// void hybridQuickSort(lo, hi, data, depthLimit) {
+//   if (lo + 1 < hi) {
+//     len = hi - lo;
+//     if (len <= limit) {
+//       insertionSort(lo, hi, data);
+//     } else {
+//       depthLimit --;
+//       if (depthLimit <= 0) {
+//         heapSort(lo, hi, data);
+//       } else {
+//         p = partition(low, high, data);
+//         quickSort(lo, p, data);
+//         quickSort(p + 1, hi, data);
+//       }
+//       depthLimit ++;
+//     }
+//   }
+// }
+//
+static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo, uint32_t nTrailingP) {
+  assert(nTrailingP == 1 || nTrailingP == 0);
+  bool isHybrid = (nTrailingP == 1);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value loCmp =
+      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // The if-stmt true branch.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value pDepthLimit;
+  Value savedDepthLimit;
+  scf::IfOp depthIf;
+
+  if (isHybrid) {
+    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
+    Value lenLimit = constantIndex(builder, loc, 30);
+    Value lenCond = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, len, lenLimit);
+    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+
+    // When len <= limit.
+    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
+    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
+        builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+        args.drop_back(nTrailingP), createSortStableFunc);
+    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
+                                 ValueRange(args.drop_back(nTrailingP)));
+
+    // When len > limit.
+    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
+    pDepthLimit = args.back();
+    savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
+    Value depthLimit = builder.create<arith::SubIOp>(
+        loc, savedDepthLimit, constantI64(builder, loc, 1));
+    builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    Value depthCond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+                                      depthLimit, constantI64(builder, loc, 0));
+    depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+
+    // When depth exceeds limit.
+    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
+    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
+        builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+        args.drop_back(nTrailingP), createHeapSortFunc);
+    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
+                                 ValueRange(args.drop_back(nTrailingP)));
+
+    // When depth doesn't exceed limit.
+    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
+  }
+
+  createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+
+  if (isHybrid) {
+    // Restore depthLimit.
+    builder.setInsertionPointAfter(depthIf);
+    builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+  }
+
+  // After the if-stmt.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
 /// Implements the rewriting for operator sort and sort_coo.
 template <typename OpTy>
 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
   FuncGeneratorType funcGenerator;
   uint32_t nTrailingP = 0;
   switch (op.getAlgorithm()) {
-  case SparseTensorSortKind::HybridQuickSort:
+  case SparseTensorSortKind::HybridQuickSort: {
+    funcName = kHybridQuickSortFuncNamePrefix;
+    funcGenerator = createQuickSortFunc;
+    nTrailingP = 1;
+    Value pDepthLimit = rewriter.create<memref::AllocaOp>(
+        loc, MemRefType::get({}, rewriter.getI64Type()));
+    operands.push_back(pDepthLimit);
+    // As a heuristic, set depthLimit = 2 * log2(n).
+    Value lo = operands[loIdx];
+    Value hi = operands[hiIdx];
+    Value len = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(),
+        rewriter.create<arith::SubIOp>(loc, hi, lo));
+    Value depthLimit = rewriter.create<arith::SubIOp>(
+        loc, constantI64(rewriter, loc, 64),
+        rewriter.create<math::CountLeadingZerosOp>(loc, len));
+    depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
+                                                constantI64(rewriter, loc, 1));
+    rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    break;
+  }
   case SparseTensorSortKind::QuickSort:
-    funcName = kSortNonstableFuncNamePrefix;
-    funcGenerator = createSortNonstableFunc;
+    funcName = kQuickSortFuncNamePrefix;
+    funcGenerator = createQuickSortFunc;
     break;
   case SparseTensorSortKind::InsertionSortStable:
     funcName = kSortStableFuncNamePrefix;
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 68c8366..dbe0c97 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -125,24 +125,25 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK: return %[[W:.*]]#2
 // CHECK: }
 
-// CHECK-LABEL: func.func private @_sparse_sort_nonstable_1_i8_f32_index(
+// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index(
 // CHECK-SAME: %[[L:arg0]]: index,
 // CHECK-SAME: %[[H:.*]]: index,
 // CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
 // CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
 // CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK: %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
+// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
 // CHECK: scf.if %[[COND]] {
 // CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK: }
 // CHECK: return
 // CHECK: }
 
-// CHECK-LABEL: func.func @sparse_sort_1d2v(
+// CHECK-LABEL: func.func @sparse_sort_1d2v_quick(
 // CHECK-SAME: %[[N:.*]]: index,
 // CHECK-SAME: %[[X0:.*]]: memref<10xi8>,
 // CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME: %[[Y1:.*]]: memref<10xindex>)
 // CHECK: %[[C0:.*]] = arith.constant 0
 // CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
 // CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK: call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK: call @_sparse_qsort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
 // CHECK: return %[[X0]], %[[Y0]], %[[Y1]]
 // CHECK: }
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
     -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
 }
 
@@ -167,9 +168,28 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref
 // CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d
-func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_quick
+func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
+// Only check the generated supporting functions now. We have integration tests
+// to verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
+func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
@@ -210,9 +230,28 @@ func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: mem
 // CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-LABEL: func.func @sparse_sort_coo
-func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_quick
+func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
+// -----
+
+// Only check the generated supporting functions now. We have integration tests
+// to verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
+func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 3c2d9cf..d3ef2fa 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -49,8 +49,9 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
 
     // Sort 0 elements.
+    // Quick sort.
     // CHECK: [10, 2, 0, 5, 1]
-    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+    sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [10, 2, 0, 5, 1]
@@ -60,10 +61,15 @@ module {
     // CHECK: [10, 2, 0, 5, 1]
     sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Hybrid sort.
+    // CHECK: [10, 2, 0, 5, 1]
+    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
+    // Quick sort.
     // CHECK: [0, 2, 5, 10, 1]
-    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+    sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [0, 2, 5, 10, 1]
@@ -77,6 +83,10 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Hybrid sort.
+    // CHECK: [0, 2, 5, 10, 1]
+    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Prepare more buffers of different dimensions.
     %x1s = memref.alloc() : memref<10xi32>
@@ -99,7 +109,7 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
       : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 46e1020..70119f8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -92,7 +92,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
      : memref<?xi32> jointly memref<?xi32>
     %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v : vector<5xi32>
-- 
2.7.4
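
Editor's note: the scheme generated by createQuickSortFunc above is classic introsort. The standalone C++ sketch below mirrors its control flow under stated assumptions; it is not code from the patch. The helper names (partition, hybridQuickSort, sortHybrid), the use of std::vector<int>, and the Lomuto partition details are illustrative choices. The constants do follow the patch: insertion sort for ranges of at most 30 elements, a depth budget of roughly 2 * log2(n) computed with a leading-zero count (the rewriter uses math.ctlz), and heap sort once that budget is spent. The recursion shape matches the pseudocode: after partition returns p, the pivot sits at its final index, and the calls cover [lo, p) and [p + 1, hi).

#include <algorithm>
#include <bit>
#include <cstdint>
#include <vector>

// Lomuto partition: moves the pivot to its final index p and returns p,
// so every element of [lo, p) is < pivot and every element of (p, hi) is >= pivot.
static int64_t partition(std::vector<int> &v, int64_t lo, int64_t hi) {
  int pivot = v[hi - 1];
  int64_t p = lo;
  for (int64_t i = lo; i + 1 < hi; ++i)
    if (v[i] < pivot)
      std::swap(v[i], v[p++]);
  std::swap(v[p], v[hi - 1]);
  return p;
}

// Hybrid quick sort over the half-open range [lo, hi), following the
// C-like pseudocode in the patch.
static void hybridQuickSort(std::vector<int> &v, int64_t lo, int64_t hi,
                            int64_t depthLimit) {
  if (lo + 1 >= hi)
    return; // 0 or 1 element: already sorted.
  if (hi - lo <= 30) {
    // Small range: insertion sort (the patch reuses its stable-sort helper).
    for (int64_t i = lo + 1; i < hi; ++i)
      for (int64_t j = i; j > lo && v[j] < v[j - 1]; --j)
        std::swap(v[j], v[j - 1]);
    return;
  }
  if (--depthLimit <= 0) {
    // Partitioning has degenerated: fall back to heap sort, which is
    // O(n log n) in the worst case.
    std::make_heap(v.begin() + lo, v.begin() + hi);
    std::sort_heap(v.begin() + lo, v.begin() + hi);
    return;
  }
  int64_t p = partition(v, lo, hi);
  hybridQuickSort(v, lo, p, depthLimit);
  hybridQuickSort(v, p + 1, hi, depthLimit);
}

// Entry point: depthLimit = 2 * log2(n), computed from a leading-zero count
// (std::bit_width(n) == 64 - clz(n) for a nonzero 64-bit n), just as the
// rewriter derives it with math.ctlz before storing it into the memref.
void sortHybrid(std::vector<int> &v) {
  if (v.size() > 1)
    hybridQuickSort(v, 0, static_cast<int64_t>(v.size()),
                    static_cast<int64_t>(2 * std::bit_width(v.size())));
}

One mechanical difference from the generated code: the patch threads the depth budget through a zero-dimensional memref<i64> shared by all recursive calls, so it must decrement before recursing and restore the saved value afterwards; passing depthLimit by value, as above, makes that save/restore unnecessary.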