[mlir][sparse] Improve sort operation by generating inlined code to compare values.
authorbixia1 <bixia@google.com>
Thu, 9 Mar 2023 17:38:09 +0000 (09:38 -0800)
committerbixia1 <bixia@google.com>
Tue, 14 Mar 2023 22:14:49 +0000 (15:14 -0700)
Previously, we generate function calls to compare values for sorting. It turns
out that the compiler doesn't inline those function calls. We now directly
generate inlined code. Also, modify the code for comparing values to use less
number of branches.

This improves all sort implementation in general. For arabic-2005.mtx CSR, the
improvement is around 25%.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D145442

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir

index b8cf623..321ef84 100644 (file)
@@ -34,8 +34,6 @@ static constexpr uint64_t loIdx = 0;
 static constexpr uint64_t hiIdx = 1;
 static constexpr uint64_t xStartIdx = 2;
 
-static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
-static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
 static constexpr const char kBinarySearchFuncNamePrefix[] =
     "_sparse_binary_search_";
@@ -181,27 +179,24 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
   forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
 }
 
-/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
-/// compare each pair is create via `compareBuilder`.
-static void createCompareFuncImplementation(
-    OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
-    uint64_t ny, bool isCoo,
-    function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
+/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
+/// each pair is create via `compareBuilder`.
+static Value createInlinedCompareImplementation(
+    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
+    bool isCoo,
+    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
         compareBuilder) {
-  OpBuilder::InsertionGuard insertionGuard(builder);
-
-  Block *entryBlock = func.addEntryBlock();
-  builder.setInsertionPointToStart(entryBlock);
-  Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
-
-  scf::IfOp topIfOp;
+  Value result;
   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
-    scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
-    if (k == 0) {
-      topIfOp = ifOp;
-    } else {
+    bool isFirstDim = (k == 0);
+    bool isLastDim = (k == nx - 1);
+    Value val =
+        compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
+    if (isFirstDim) {
+      result = val;
+    } else if (!isLastDim) {
       OpBuilder::InsertionGuard insertionGuard(builder);
+      auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
       builder.setInsertionPointAfter(ifOp);
       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
     }
@@ -209,38 +204,44 @@ static void createCompareFuncImplementation(
 
   forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
 
-  builder.setInsertionPointAfter(topIfOp);
-  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+  builder.setInsertionPointAfterValue(result);
+  return result;
 }
 
-/// Generates an if-statement to compare whether x[i] is equal to x[j].
-static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
-                                 Value j, Value x, bool isLastDim) {
-  Value f = constantI1(builder, loc, false);
-  Value t = constantI1(builder, loc, true);
+/// Generates code to compare whether x[i] is equal to x[j] and returns the
+/// result of the comparison.
+static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
+                             Value x, bool isFirstDim, bool isLastDim) {
   Value vi = builder.create<memref::LoadOp>(loc, x, i);
   Value vj = builder.create<memref::LoadOp>(loc, x, j);
 
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
-  scf::IfOp ifOp =
-      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
-
-  // x[1] != x[j]:
-  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, f);
+  Value res;
+  if (isLastDim) {
+    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
+    // For 1D, we create a compare without any control flow. Otherwise, we
+    // create YieldOp to return the result in the nested if-stmt.
+    if (!isFirstDim)
+      builder.create<scf::YieldOp>(loc, res);
+  } else {
+    Value ne =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
+                                               ne, /*else=*/true);
+    // If (x[i] != x[j]).
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value f = constantI1(builder, loc, false);
+    builder.create<scf::YieldOp>(loc, f);
 
-  // x[i] == x[j]:
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  if (isLastDim == 1) {
-    // Finish checking all dimensions.
-    builder.create<scf::YieldOp>(loc, t);
+    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
+    // checks the remaining dimensions.
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    res = ifOp.getResult(0);
   }
 
-  return ifOp;
+  return res;
 }
 
-/// Creates a function to compare whether xs[i] is equal to xs[j].
+/// Creates code to compare whether xs[i] is equal to xs[j].
 //
 // The generate IR corresponds to this C like algorithm:
 //   if (x0[i] != x0[j])
@@ -250,77 +251,68 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
 //       return false;
 //     else if (x2[2] != x2[j]))
 //       and so on ...
-static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP = 0) {
+static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
+                                    ValueRange args, uint64_t nx, uint64_t ny,
+                                    bool isCoo, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
-                                  createEqCompare);
+  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+                                            createEqCompare);
 }
 
-/// Generates an if-statement to compare whether x[i] is less than x[j].
-static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
-                                       Value i, Value j, Value x,
-                                       bool isLastDim) {
-  Value f = constantI1(builder, loc, false);
-  Value t = constantI1(builder, loc, true);
+/// Generates code to compare whether x[i] is less than x[j] and returns the
+/// result of the comparison.
+static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
+                                   Value j, Value x, bool isFirstDim,
+                                   bool isLastDim) {
   Value vi = builder.create<memref::LoadOp>(loc, x, i);
   Value vj = builder.create<memref::LoadOp>(loc, x, j);
 
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
-  scf::IfOp ifOp =
-      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
-  // If (x[i] < x[j]).
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  builder.create<scf::YieldOp>(loc, t);
-
-  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  if (isLastDim == 1) {
-    // Finish checking all dimensions.
-    builder.create<scf::YieldOp>(loc, f);
+  Value res;
+  if (isLastDim) {
+    res = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+    // For 1D, we create a compare without any control flow. Otherwise, we
+    // create YieldOp to return the result in the nested if-stmt.
+    if (!isFirstDim)
+      builder.create<scf::YieldOp>(loc, res);
   } else {
-    cond =
-        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
-    scf::IfOp ifOp2 =
-        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
-    // Otherwise if (x[j] < x[i]).
-    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
-    builder.create<scf::YieldOp>(loc, f);
-
-    // Otherwise check the remaining dimensions.
-    builder.setInsertionPointAfter(ifOp2);
-    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
-    // Set up the insertion point for the nested if-stmt that checks the
-    // remaining dimensions.
-    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+    Value ne =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, vi, vj);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
+                                               ne, /*else=*/true);
+    // If (x[i] != x[j]).
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value lt =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+    builder.create<scf::YieldOp>(loc, lt);
+
+    // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that
+    // checks the remaining dimensions.
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    res = ifOp.getResult(0);
   }
 
-  return ifOp;
+  return res;
 }
 
-/// Creates a function to compare whether xs[i] is less than xs[j].
+/// Creates code to compare whether xs[i] is less than xs[j].
 //
 // The generate IR corresponds to this C like algorithm:
-//   if (x0[i] < x0[j])
-//     return true;
-//   else if (x0[j] < x0[i])
-//     return false;
+//   if (x0[i] != x0[j])
+//     return x0[i] < x0[j];
+//   else if (x1[j] != x1[i])
+//     return x1[i] < x1[j];
 //   else
-//     if (x1[i] < x1[j])
-//       return true;
-//     else if (x1[j] < x1[i]))
 //       and so on ...
-static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
-                               func::FuncOp func, uint64_t nx, uint64_t ny,
-                               bool isCoo, uint32_t nTrailingP = 0) {
+static Value createInlinedLessThan(OpBuilder &builder, Location loc,
+                                   ValueRange args, uint64_t nx, uint64_t ny,
+                                   bool isCoo, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
-                                  createLessThanCompare);
+  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+                                            createLessThanCompare);
 }
 
 /// Creates a function to use a binary search to find the insertion point for
@@ -379,15 +371,8 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
   uint64_t numXBuffers = isCoo ? 1 : nx;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
-  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
-      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createLessThanFunc, nTrailingP);
-  Value cond2 = builder
-                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
-                                          compareOperands)
-                    .getResult(0);
-
+  Value cond2 =
+      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
   // Update lo and hi for the WhileOp as follows:
   //   if (xs[p] < xs[mid]))
   //     hi = mid;
@@ -428,15 +413,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
     compareOperands.push_back(before->getArgument(0));
   }
   compareOperands.append(xs.begin(), xs.end());
-  MLIRContext *context = module.getContext();
-  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
-      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createLessThanFunc);
-  Value cond = builder
-                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
-                                         compareOperands)
-                   .getResult(0);
+  Value cond =
+      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   Block *after =
@@ -450,14 +428,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   builder.setInsertionPointAfter(whileOp);
   compareOperands[0] = i;
   compareOperands[1] = p;
-  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
-      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createEqCompareFunc);
   Value compareEq =
-      builder
-          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
-                                compareOperands)
-          .getResult(0);
+      createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
 
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
@@ -485,14 +457,10 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                          args.begin() + xStartIdx + numXBuffers);
   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
   SmallVector<Type, 1> cmpTypes{i1Type};
-  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
-      builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createLessThanFunc);
   Location loc = func.getLoc();
   // Compare data[mi] < data[lo].
   Value cond1 =
-      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-          .getResult(0);
+      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
   SmallVector<Type, 1> ifTypes{lo.getType()};
   scf::IfOp ifOp1 =
       builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
@@ -502,11 +470,9 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
   auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
     compareOperands[0] = c;
     compareOperands[1] = a;
-    // Compare data[c]] < data[a].
+    // Compare data[c] < data[b].
     Value cond2 =
-        builder
-            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-            .getResult(0);
+        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
     scf::IfOp ifOp2 =
         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
     builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
@@ -514,9 +480,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
     compareOperands[1] = b;
     // Compare data[c] < data[b].
     Value cond3 =
-        builder
-            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
-            .getResult(0);
+        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
     builder.create<scf::YieldOp>(
         loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
     builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
@@ -758,10 +722,6 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   uint64_t numXBuffers = isCoo ? 1 : nx;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
-  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
-      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createLessThanFunc);
 
   // Generate code to inspect the children of 'r' and return the larger child
   // as follows:
@@ -784,10 +744,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
     // Compare data[left] < data[right].
     compareOperands[0] = lChildIdx;
     compareOperands[1] = rChildIdx;
-    Value cond2 = builder
-                      .create<func::CallOp>(loc, lessThanFunc,
-                                            TypeRange{i1Type}, compareOperands)
-                      .getResult(0);
+    Value cond2 =
+        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
     scf::IfOp if2 =
         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
     builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -818,10 +776,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = before->getArgument(2);
   compareOperands[0] = start;
   compareOperands[1] = childIdx;
-  Value cond = builder
-                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
-                                         compareOperands)
-                   .getResult(0);
+  Value cond =
+      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   // The after-region of the WhileOp.
index 68e5c9b..84721f7 100644 (file)
@@ -75,54 +75,132 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 
 // -----
 
-// CHECK-LABEL:   func.func private @_sparse_less_than_1_i8(
-// CHECK-SAME:                                              %[[I:arg0]]: index,
-// CHECK-SAME:                                              %[[J:.*]]: index,
-// CHECK-SAME:                                              %[[X0:.*]]: memref<?xi8>) -> i1 {
-// CHECK:           %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
-// CHECK:           %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
-// CHECK:           %[[C:.*]] = arith.cmpi ult, %[[VI]], %[[VJ]]
-// CHECK:           return %[[C]]
-// CHECK:         }
-
-// CHECK-LABEL:   func.func private @_sparse_compare_eq_1_i8(
-// CHECK-SAME:                                               %[[I:arg0]]: index,
-// CHECK-SAME:                                               %[[J:.*]]: index,
-// CHECK-SAME:                                               %[[X0:.*]]: memref<?xi8>) -> i1 {
-// CHECK:           %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
-// CHECK:           %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
-// CHECK:           %[[C:.*]] = arith.cmpi eq, %[[VI]], %[[VJ]]
-// CHECK:           return %[[C]]
-// CHECK:         }
-
 // CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index(
-// CHECK-SAME:                                                        %[[L:arg0]]: index,
-// CHECK-SAME:                                                        %[[H:.*]]: index,
-// CHECK-SAME:                                                        %[[X0:.*]]: memref<?xi8>,
-// CHECK-SAME:                                                        %[[Y0:.*]]: memref<?xf32>,
-// CHECK-SAME:                                                        %[[Y1:.*]]: memref<?xindex>) -> index {
-// CHECK:           %[[C1:.*]] = arith.constant 1
-// CHECK:           %[[VAL_6:.*]] = arith.constant -
-// CHECK:           %[[SUM:.*]] = arith.addi %[[L]], %[[H]]
-// CHECK:           %[[P:.*]] = arith.shrui %[[SUM]], %[[C1]]
-// CHECK:           %[[J:.*]] = arith.subi %[[H]], %[[C1]]
-// CHECK:           %[[W:.*]]:3 = scf.while (%[[Ib:.*]] = %[[L]], %[[Jb:.*]] = %[[J]], %[[pb:.*]] = %[[P]]) : (index, index, index) -> (index, index, index) {
-// CHECK:             %[[Cn:.*]] = arith.cmpi ult, %[[Ib]], %[[Jb]]
-// CHECK:             scf.condition(%[[Cn]]) %[[Ib]], %[[Jb]], %[[pb]]
+// CHECK-SAME:    %[[VAL_0:.*0]]: index,
+// CHECK-SAME:    %[[VAL_1:.*1]]: index,
+// CHECK-SAME:    %[[VAL_2:.*2]]: memref<?xi8>,
+// CHECK-SAME:    %[[VAL_3:.*3]]: memref<?xf32>,
+// CHECK-SAME:    %[[VAL_4:.*4]]: memref<?xindex>) -> index {
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant -1
+// CHECK:           %[[VAL_7:.*]] = arith.addi %[[VAL_0]], %[[VAL_1]]
+// CHECK:           %[[VAL_8:.*]] = arith.shrui %[[VAL_7]], %[[VAL_5]]
+// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_1]], %[[VAL_5]]
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]]
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]]
+// CHECK:           %[[VAL_13:.*]] = scf.if %[[VAL_12]] -> (index) {
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
+// CHECK:             %[[VAL_16:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_15]]
+// CHECK:             %[[VAL_17:.*]] = scf.if %[[VAL_16]] -> (index) {
+// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
+// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]]
+// CHECK:               %[[VAL_20:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_19]]
+// CHECK:               %[[VAL_21:.*]] = arith.select %[[VAL_20]], %[[VAL_8]], %[[VAL_9]]
+// CHECK:               scf.yield %[[VAL_21]]
+// CHECK:             } else {
+// CHECK:               scf.yield %[[VAL_0]]
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_22:.*]]
+// CHECK:           } else {
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]]
+// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_24]]
+// CHECK:             %[[VAL_26:.*]] = scf.if %[[VAL_25]] -> (index) {
+// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]]
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_0]]]
+// CHECK:               %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_28]]
+// CHECK:               %[[VAL_30:.*]] = arith.select %[[VAL_29]], %[[VAL_0]], %[[VAL_9]]
+// CHECK:               scf.yield %[[VAL_30]]
+// CHECK:             } else {
+// CHECK:               scf.yield %[[VAL_8]]
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_31:.*]]
+// CHECK:           }
+// CHECK:           %[[VAL_32:.*]] = arith.cmpi ne, %[[VAL_8]], %[[VAL_13:.*]]
+// CHECK:           scf.if %[[VAL_32]] {
+// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}
+// CHECK:             %[[VAL_35:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_8]]]
+// CHECK:             memref.store %[[VAL_35]], %[[VAL_2]]
+// CHECK:             memref.store %[[VAL_34]], %[[VAL_2]]{{\[}}%[[VAL_8]]]
+// CHECK:             %[[VAL_36:.*]] = memref.load %[[VAL_3]]
+// CHECK:             %[[VAL_37:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_8]]]
+// CHECK:             memref.store %[[VAL_37]], %[[VAL_3]]
+// CHECK:             memref.store %[[VAL_36]], %[[VAL_3]]{{\[}}%[[VAL_8]]]
+// CHECK:             %[[VAL_38:.*]] = memref.load %[[VAL_4]]
+// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_8]]]
+// CHECK:             memref.store %[[VAL_39]], %[[VAL_4]]
+// CHECK:             memref.store %[[VAL_38]], %[[VAL_4]]{{\[}}%[[VAL_8]]]
+// CHECK:           }
+// CHECK:           %[[VAL_40:.*]]:3 = scf.while (%[[VAL_41:.*]] = %[[VAL_0]], %[[VAL_42:.*]] = %[[VAL_9]], %[[VAL_43:.*]] = %[[VAL_8]])
+// CHECK:             %[[VAL_44:.*]] = arith.cmpi ult, %[[VAL_41]], %[[VAL_42]]
+// CHECK:             scf.condition(%[[VAL_44]]) %[[VAL_41]], %[[VAL_42]], %[[VAL_43]]
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[Ia:.*]]: index, %[[Ja:.*]]: index, %[[Pa:.*]]: index):
-// CHECK:             %[[I2:.*]] = scf.while
-// CHECK:             %[[Ieq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[I2:.*]], %[[Pa]], %[[X0]])
-// CHECK:             %[[J2:.*]] = scf.while
-// CHECK:             %[[Jeq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[J2:.*]], %[[Pa]], %[[X0]])
-// CHECK:             %[[Cn2:.*]] = arith.cmpi ult, %[[I2]], %[[J2]]
-// CHECK:             %[[If:.*]]:3 = scf.if %[[Cn2]] -> (index, index, index) {
+// CHECK:           ^bb0(%[[VAL_45:.*]]: index, %[[VAL_46:.*]]: index, %[[VAL_47:.*]]: index):
+// CHECK:             %[[VAL_48:.*]] = scf.while (%[[VAL_49:.*]] = %[[VAL_45]]) : (index) -> index {
+// CHECK:               %[[VAL_50:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_49]]]
+// CHECK:               %[[VAL_51:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]]
+// CHECK:               %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_51]]
+// CHECK:               scf.condition(%[[VAL_52]]) %[[VAL_49]]
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_53:.*]]: index):
+// CHECK:               %[[VAL_54:.*]] = arith.addi %[[VAL_53]], %[[VAL_5]]
+// CHECK:               scf.yield %[[VAL_54]]
+// CHECK:             }
+// CHECK:             %[[VAL_55:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_56:.*]]]
+// CHECK:             %[[VAL_57:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]]
+// CHECK:             %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_57]]
+// CHECK:             %[[VAL_59:.*]] = scf.while (%[[VAL_60:.*]] = %[[VAL_46]]) : (index) -> index {
+// CHECK:               %[[VAL_61:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]]
+// CHECK:               %[[VAL_62:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_60]]]
+// CHECK:               %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_61]], %[[VAL_62]]
+// CHECK:               scf.condition(%[[VAL_63]]) %[[VAL_60]]
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_64:.*]]: index):
+// CHECK:               %[[VAL_65:.*]] = arith.addi %[[VAL_64]], %[[VAL_6]]
+// CHECK:               scf.yield %[[VAL_65]]
+// CHECK:             }
+// CHECK:             %[[VAL_66:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_67:.*]]]
+// CHECK:             %[[VAL_68:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_47]]]
+// CHECK:             %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_68]]
+// CHECK:             %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_67]]
+// CHECK:             %[[VAL_71:.*]]:3 = scf.if %[[VAL_70]] -> (index, index, index) {
+// CHECK:               %[[VAL_72:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_56]]]
+// CHECK:               %[[VAL_73:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_67]]]
+// CHECK:               memref.store %[[VAL_73]], %[[VAL_2]]{{\[}}%[[VAL_56]]]
+// CHECK:               memref.store %[[VAL_72]], %[[VAL_2]]{{\[}}%[[VAL_67]]]
+// CHECK:               %[[VAL_74:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_56]]]
+// CHECK:               %[[VAL_75:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_67]]]
+// CHECK:               memref.store %[[VAL_75]], %[[VAL_3]]{{\[}}%[[VAL_56]]]
+// CHECK:               memref.store %[[VAL_74]], %[[VAL_3]]{{\[}}%[[VAL_67]]]
+// CHECK:               %[[VAL_76:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_56]]]
+// CHECK:               %[[VAL_77:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_67]]]
+// CHECK:               memref.store %[[VAL_77]], %[[VAL_4]]{{\[}}%[[VAL_56]]]
+// CHECK:               memref.store %[[VAL_76]], %[[VAL_4]]{{\[}}%[[VAL_67]]]
+// CHECK:               %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_47]]
+// CHECK:               %[[VAL_79:.*]] = scf.if %[[VAL_78]] -> (index) {
+// CHECK:                 scf.yield %[[VAL_67]]
+// CHECK:               } else {
+// CHECK:                 %[[VAL_80:.*]] = arith.cmpi eq, %[[VAL_67]], %[[VAL_47]]
+// CHECK:                 %[[VAL_81:.*]] = arith.select %[[VAL_80]], %[[VAL_56]], %[[VAL_47]]
+// CHECK:                 scf.yield %[[VAL_81]]
+// CHECK:               }
+// CHECK:               %[[VAL_82:.*]] = arith.andi %[[VAL_58]], %[[VAL_69]]
+// CHECK:               %[[VAL_83:.*]]:2 = scf.if %[[VAL_82]] -> (index, index) {
+// CHECK:                 %[[VAL_84:.*]] = arith.addi %[[VAL_56]], %[[VAL_5]]
+// CHECK:                 %[[VAL_85:.*]] = arith.subi %[[VAL_67]], %[[VAL_5]]
+// CHECK:                 scf.yield %[[VAL_84]], %[[VAL_85]]
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[VAL_56]], %[[VAL_67]]
+// CHECK:               }
+// CHECK:               scf.yield %[[VAL_86:.*]]#0, %[[VAL_86]]#1, %[[VAL_87:.*]]
 // CHECK:             } else {
-// CHECK:               scf.yield %[[I2]], %[[J2]], %[[Pa]]
+// CHECK:               scf.yield %[[VAL_56]], %[[VAL_67]], %[[VAL_47]]
 // CHECK:             }
-// CHECK:             scf.yield %[[If:.*]]#0, %[[If]]#1, %[[If]]#2
+// CHECK:             scf.yield %[[VAL_88:.*]]#0, %[[VAL_88]]#1, %[[VAL_88]]#2
 // CHECK:           }
-// CHECK:           return %[[W:.*]]#2
+// CHECK:           return %[[VAL_89:.*]]#2
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_sparse_qsort_1_i8_f32_index(
@@ -176,8 +254,6 @@ func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: mem
 // Only check the generated supporting function now. We have integration test
 // to verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
-// CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
 // CHECK-DAG:     func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_quick
@@ -191,12 +267,10 @@ func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: me
 // Only check the generated supporting function now. We have integration test
 // to verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
 // CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
 // CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
 // CHECK-DAG:     func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_hybrid
@@ -210,7 +284,6 @@ func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: m
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
 // CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_stable
@@ -224,7 +297,6 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
 // CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d_heap
@@ -238,8 +310,6 @@ func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: mem
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
-// CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
 // CHECK-DAG:     func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_quick
@@ -253,12 +323,10 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
 // CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
 // CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
 // CHECK-DAG:     func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_hybrid
@@ -272,7 +340,6 @@ func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
 // CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_stable
@@ -286,7 +353,6 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
 // CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_heap