[mlir][sparse] Improve quick sort by using a loop to sort the bigger partition.
authorbixia1 <bixia@google.com>
Wed, 8 Mar 2023 14:56:19 +0000 (06:56 -0800)
committerbixia1 <bixia@google.com>
Sat, 11 Mar 2023 04:43:08 +0000 (20:43 -0800)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145440

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir

index 107d9ef..b8cf623 100644 (file)
@@ -918,9 +918,13 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
-static void createQuickSort(OpBuilder &builder, ModuleOp module,
-                            func::FuncOp func, ValueRange args, uint64_t nx,
-                            uint64_t ny, bool isCoo, uint32_t nTrailingP) {
+/// A helper for generating code to perform quick sort. It partitions [lo, hi),
+/// recursively calls quick sort to process the smaller partition and returns
+/// the bigger partition to be processed by the enclosed while-loop.
+static std::pair<Value, Value>
+createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
+                ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+                uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
   Value lo = args[loIdx];
@@ -928,20 +932,45 @@ static void createQuickSort(OpBuilder &builder, ModuleOp module,
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
-  auto p = builder.create<func::CallOp>(loc, partitionFunc,
-                                        TypeRange{IndexType::get(context)},
-                                        args.drop_back(nTrailingP));
-
-  SmallVector<Value> lowOperands{lo, p.getResult(0)};
-  lowOperands.append(args.begin() + xStartIdx, args.end());
-  builder.create<func::CallOp>(loc, func, lowOperands);
-
-  SmallVector<Value> highOperands{
-      builder.create<arith::AddIOp>(loc, p.getResult(0),
-                                    constantIndex(builder, loc, 1)),
-      hi};
-  highOperands.append(args.begin() + xStartIdx, args.end());
-  builder.create<func::CallOp>(loc, func, highOperands);
+  Value p = builder
+                .create<func::CallOp>(loc, partitionFunc,
+                                      TypeRange{IndexType::get(context)},
+                                      args.drop_back(nTrailingP))
+                .getResult(0);
+  Value pP1 =
+      builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
+  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
+  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
+  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+                                             lenLow, lenHigh);
+
+  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
+
+  Value c0 = constantIndex(builder, loc, 0);
+  auto mayRecursion = [&](Value low, Value high, Value len) {
+    Value cond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    SmallVector<Value> operands{low, high};
+    operands.append(args.begin() + xStartIdx, args.end());
+    builder.create<func::CallOp>(loc, func, operands);
+    builder.setInsertionPointAfter(ifOp);
+  };
+
+  // Recursively call quickSort to process the smaller partition and return
+  // the bigger partition to be processed by the enclosed while-loop.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  mayRecursion(lo, p, lenLow);
+  builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
+
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  mayRecursion(pP1, hi, lenHigh);
+  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
+
+  builder.setInsertionPointAfter(ifOp);
+  return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
 }
 
 /// Creates a function to perform insertion sort on the values in the range of
@@ -1036,16 +1065,21 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 //
 // When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
 // void quickSort(lo, hi, data) {
-//   if (lo + 1 < hi) {
+//   while (lo + 1 < hi) {
 //        p = partition(low, high, data);
-//        quickSort(lo, p, data);
-//        quickSort(p + 1, hi, data);
+//        if (len(lo, p) < len(p+1, hi)) {
+//          quickSort(lo, p, data);
+//          lo = p+1;
+//        } else {
+//          quickSort(p + 1, hi, data);
+//          hi = p;
+//        }
 //   }
 // }
 //
 // When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
 // void hybridQuickSort(lo, hi, data, depthLimit) {
-//   if (lo + 1 < hi) {
+//   while (lo + 1 < hi) {
 //     len = hi - lo;
 //     if (len <= limit) {
 //       insertionSort(lo, hi, data);
@@ -1055,10 +1089,14 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 //         heapSort(lo, hi, data);
 //       } else {
 //          p = partition(low, high, data);
-//          quickSort(lo, p, data);
-//          quickSort(p + 1, hi, data);
+//          if (len(lo, p) < len(p+1, hi)) {
+//            quickSort(lo, p, data, depthLimit);
+//            lo = p+1;
+//          } else {
+//            quickSort(p + 1, hi, data, depthLimit);
+//            hi = p;
+//          }
 //       }
-//       depthLimit ++;
 //     }
 //   }
 // }
@@ -1073,70 +1111,98 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(entryBlock);
 
   Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
+  SmallVector<Value> args;
+  args.append(entryBlock->getArguments().begin(),
+              entryBlock->getArguments().end());
   Value lo = args[loIdx];
   Value hi = args[hiIdx];
-  Value loCmp =
+  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+  scf::WhileOp whileOp =
+      builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
+
+  // The before-region of the WhileOp.
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(before);
+  lo = before->getArgument(0);
+  hi = before->getArgument(1);
+  Value loP1 =
       builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+  Value needSort =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
+  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
 
-  // The if-stmt true branch.
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value pDepthLimit;
-  Value savedDepthLimit;
-  scf::IfOp depthIf;
+  // The after-region of the WhileOp.
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(after);
+  lo = after->getArgument(0);
+  hi = after->getArgument(1);
+  args[0] = lo;
+  args[1] = hi;
 
   if (isHybrid) {
     Value len = builder.create<arith::SubIOp>(loc, hi, lo);
     Value lenLimit = constantIndex(builder, loc, 30);
     Value lenCond = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ule, len, lenLimit);
-    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+    scf::IfOp lenIf =
+        builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
 
     // When len <= limit.
     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
         builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
-        args.drop_back(nTrailingP), createSortStableFunc);
+        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
-                                 ValueRange(args.drop_back(nTrailingP)));
+                                 ValueRange(args).drop_back(nTrailingP));
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
 
     // When len > limit.
     builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
-    pDepthLimit = args.back();
-    savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
-    Value depthLimit = builder.create<arith::SubIOp>(
-        loc, savedDepthLimit, constantI64(builder, loc, 1));
-    builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    Value depthLimit = args.back();
+    depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
+                                               constantI64(builder, loc, 1));
     Value depthCond =
         builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                       depthLimit, constantI64(builder, loc, 0));
-    depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+    scf::IfOp depthIf =
+        builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
 
     // When depth exceeds limit.
     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
         builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
-        args.drop_back(nTrailingP), createHeapSortFunc);
+        ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
-                                 ValueRange(args.drop_back(nTrailingP)));
+                                 ValueRange(args).drop_back(nTrailingP));
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
 
     // When depth doesn't exceed limit.
     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
-  }
-
-  createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    args.back() = depthLimit;
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
 
-  if (isHybrid) {
-    // Restore depthLimit.
     builder.setInsertionPointAfter(depthIf);
-    builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+    lo = depthIf.getResult(0);
+    hi = depthIf.getResult(1);
+    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+    builder.setInsertionPointAfter(lenIf);
+    lo = lenIf.getResult(0);
+    hi = lenIf.getResult(1);
+  } else {
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
   }
 
-  // After the if-stmt.
-  builder.setInsertionPointAfter(ifOp);
+  // New [lo, hi) for the next while-loop iteration.
+  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+  // After the while-loop.
+  builder.setInsertionPointAfter(whileOp);
   builder.create<func::ReturnOp>(loc);
 }
 
@@ -1171,9 +1237,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
     funcName = kHybridQuickSortFuncNamePrefix;
     funcGenerator = createQuickSortFunc;
     nTrailingP = 1;
-    Value pDepthLimit = rewriter.create<memref::AllocaOp>(
-        loc, MemRefType::get({}, rewriter.getI64Type()));
-    operands.push_back(pDepthLimit);
     // As a heuristics, set depthLimit = 2 * log2(n).
     Value lo = operands[loIdx];
     Value hi = operands[hiIdx];
@@ -1183,9 +1246,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
     Value depthLimit = rewriter.create<arith::SubIOp>(
         loc, constantI64(rewriter, loc, 64),
         rewriter.create<math::CountLeadingZerosOp>(loc, len));
-    depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
-                                                constantI64(rewriter, loc, 1));
-    rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    operands.push_back(depthLimit);
     break;
   }
   case SparseTensorSortKind::QuickSort:
index dbe0c97..68e5c9b 100644 (file)
@@ -132,13 +132,24 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK-SAME:                                                   %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME:                                                   %[[Y1:.*]]: memref<?xindex>) {
 // CHECK:           %[[C1:.*]] = arith.constant 1
-// CHECK:           %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
-// CHECK:           %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
-// CHECK:           scf.if %[[COND]] {
-// CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK:             func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:           scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]])
+// CHECK:             %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]]
+// CHECK:             %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]]
+// CHECK:             scf.condition(%[[COND]]) %[[L2]], %[[H2]]
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index)
+// CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK:             %[[LenL:.*]] = arith.subi %[[P]], %[[L3]]
+// CHECK:             %[[LenH:.*]] = arith.subi %[[H3]], %[[P]]
+// CHECK:             %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]]
+// CHECK:             %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]]
+// CHECK:             %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]]
+// CHECK:             scf.if %[[Cmp]]
+// CHECK:               func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             else
+// CHECK:               func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             scf.yield %[[L4]], %[[H4]]
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -187,7 +198,7 @@ func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: me
 // CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_hybrid
 func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
@@ -249,7 +260,7 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>