[mlir][sparse] Add rewrite rule for the sort operator.
authorbixia1 <bixia@google.com>
Wed, 28 Sep 2022 19:59:00 +0000 (12:59 -0700)
committerbixia1 <bixia@google.com>
Thu, 29 Sep 2022 18:38:19 +0000 (11:38 -0700)
Add sparse-buffer-rewrite pass to rewrite sparse primitives on buffers to MLIR
implementation.

Add sparse rewrite rule for the sort operator.

Add FileCheck test and integration test.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134627

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp [new file with mode: 0644]
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir [new file with mode: 0644]
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir [new file with mode: 0644]

index 8565313..e6e65b7 100644 (file)
@@ -166,6 +166,9 @@ void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);
 std::unique_ptr<Pass> createDenseBufferizationPass(
     const bufferization::OneShotBufferizationOptions &options);
 
+void populateSparseBufferRewriting(RewritePatternSet &patterns);
+std::unique_ptr<Pass> createSparseBufferRewritePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
index 04b6641..086d1cb 100644 (file)
@@ -178,4 +178,19 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   ];
 }
 
+def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
+  let summary = "Rewrite sparse primitives on buffers to actual code";
+  let description = [{
+    A pass that rewrites sparse primitives on buffers to the MLIR implementation
+    of the primitives. For example, sparse_tensor.sort operator is implemented
+    in this pass.
+  }];
+  let constructor = "mlir::createSparseBufferRewritePass()";
+  let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
index 51ea50d..abecf46 100644 (file)
@@ -64,6 +64,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
         options.sparseTensorConversionOptions()));
   else
     pm.addPass(createSparseTensorCodegenPass());
+  pm.addPass(createSparseBufferRewritePass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addPass(createDenseBufferizationPass(
       getBufferizationOptions(/*analysisOnly=*/false)));
index a1d93f3..8d8c84a 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenUtils.cpp
   DenseBufferizationPass.cpp
   Sparsification.cpp
+  SparseBufferRewriting.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
new file mode 100644 (file)
index 0000000..7b02d46
--- /dev/null
@@ -0,0 +1,382 @@
+//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting rules that are specific to sparse tensor
+// primitives with memref operands.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===---------------------------------------------------------------------===//
+// Helper methods for the actual rewriting rules.
+//===---------------------------------------------------------------------===//
+
+constexpr uint64_t loIdx = 0;
+constexpr uint64_t hiIdx = 1;
+constexpr uint64_t xStartIdx = 2;
+
+typedef function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>
+    FuncGeneratorType;
+
+/// Constructs a function name with this format to facilitate quick sort:
+///   <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
+static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
+                                         StringRef namePrefix, size_t dim,
+                                         ValueRange operands) {
+  nameOstream
+      << namePrefix << dim << "_"
+      << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+
+  for (Value v : operands.drop_front(xStartIdx + dim))
+    nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+}
+
+/// Looks up a function that is appropriate for the given operands being
+/// sorted, and creates such a function if it doesn't exist yet.
+static FlatSymbolRefAttr
+getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
+                         TypeRange resultTypes, StringRef namePrefix,
+                         size_t dim, ValueRange operands,
+                         FuncGeneratorType createFunc) {
+  SmallString<32> nameBuffer;
+  llvm::raw_svector_ostream nameOstream(nameBuffer);
+  getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
+
+  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, nameOstream.str());
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+  if (!func) {
+    // Create the function.
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPoint(insertPoint);
+    Location loc = insertPoint.getLoc();
+    func = builder.create<func::FuncOp>(
+        loc, nameOstream.str(),
+        FunctionType::get(context, operands.getTypes(), resultTypes));
+    func.setPrivate();
+    createFunc(builder, module, func, dim);
+  }
+
+  return result;
+}
+
+/// Creates a function for swapping the values in index i and j for all the
+/// buffers.
+//
+// The generate IR corresponds to this C like algorithm:
+//   if (i != j) {
+//     swap(x0[i], x0[j]);
+//     swap(x1[i], x1[j]);
+//     ...
+//     swap(xn[i], xn[j]);
+//     swap(y0[i], y0[j]);
+//     ...
+//     swap(yn[i], yn[j]);
+//   }
+static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
+                              func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value i = args[0];
+  Value j = args[1];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // If i!=j swap values in the buffers.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  for (auto arg : args.drop_front(xStartIdx)) {
+    Value vi = builder.create<memref::LoadOp>(loc, arg, i);
+    Value vj = builder.create<memref::LoadOp>(loc, arg, j);
+    builder.create<memref::StoreOp>(loc, vj, arg, i);
+    builder.create<memref::StoreOp>(loc, vi, arg, j);
+  }
+
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
+/// Generates an if-statement to compare x[i] and x[j].
+static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
+                                       Value i, Value j, Value x,
+                                       bool isLastDim) {
+  Value f = constantI1(builder, loc, false);
+  Value t = constantI1(builder, loc, true);
+  Value vi = builder.create<memref::LoadOp>(loc, x, i);
+  Value vj = builder.create<memref::LoadOp>(loc, x, j);
+
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+  // If (x[i] < x[j]).
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  builder.create<scf::YieldOp>(loc, t);
+
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  if (isLastDim == 1) {
+    // Finish checking all dimensions.
+    builder.create<scf::YieldOp>(loc, f);
+  } else {
+    cond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
+    scf::IfOp ifOp2 =
+        builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+    // Otherwise if (x[j] < x[i]).
+    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+    builder.create<scf::YieldOp>(loc, f);
+
+    // Otherwise check the remaining dimensions.
+    builder.setInsertionPointAfter(ifOp2);
+    builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
+    // Set up the insertion point for the nested if-stmt that checks the
+    // remaining dimensions.
+    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+  }
+
+  return ifOp;
+}
+
+/// Creates a function to compare the xs values in index i and j for all the
+/// dimensions. The function returns true iff xs[i] < xs[j].
+//
+// The generate IR corresponds to this C like algorithm:
+//   if (x0[i] < x0[j])
+//     return true;
+//   else if (x0[j] < x0[i])
+//     return false;
+//   else
+//     if (x1[i] < x1[j])
+//       return true;
+//     else if (x1[j] < x1[i]))
+//       and so on ...
+static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
+                               func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+
+  scf::IfOp topIfOp;
+  for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
+    scf::IfOp ifOp =
+        createLessThanCompare(builder, loc, args[0], args[1], item.value(),
+                              (item.index() == dim - 1));
+    if (item.index() == 0) {
+      topIfOp = ifOp;
+    } else {
+      OpBuilder::InsertionGuard insertionGuard(builder);
+      builder.setInsertionPointAfter(ifOp);
+      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+    }
+  }
+
+  builder.setInsertionPointAfter(topIfOp);
+  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+}
+
+/// Creates a function to perform quick sort partition on the values in the
+/// range of index [lo, hi), assuming lo < hi.
+//
+// The generated IR corresponds to this C like algorithm:
+// int partition(lo, hi, data) {
+//   pivot = data[hi - 1];
+//   i = (lo – 1)  // RHS of the pivot found so far.
+//   for (j = lo; j < hi - 1; j++){
+//     if (data[j] < pivot){
+//       i++;
+//       swap data[i] and data[j]
+//     }
+//   }
+//   i++
+//   swap data[i] and data[hi-1])
+//   return i
+// }
+static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value c1 = constantIndex(builder, loc, 1);
+  Value i = builder.create<arith::SubIOp>(loc, lo, c1);
+  Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
+  scf::ForOp forOp =
+      builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
+
+  // Start the for-stmt body.
+  builder.setInsertionPointToStart(forOp.getBody());
+  Value j = forOp.getInductionVar();
+  SmallVector<Value, 6> compareOperands{j, him1};
+  ValueRange xs = args.slice(xStartIdx, dim);
+  compareOperands.append(xs.begin(), xs.end());
+  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc =
+      getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+                               dim, compareOperands, createLessThanFunc);
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                         compareOperands)
+                   .getResult(0);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+
+  // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value i1 =
+      builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
+  SmallVector<Value, 6> swapOperands{i1, j};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  FlatSymbolRefAttr swapFunc =
+      getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
+                               dim, swapOperands, createMaySwapFunc);
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<scf::YieldOp>(loc, i1);
+
+  // The if-stmt false branch: yield i.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+
+  // After the if-stmt, yield the updated i value to end the for-stmt body.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+
+  // After the for-stmt: i++; swap(data[i], data[him1]); return i.
+  builder.setInsertionPointAfter(forOp);
+  i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
+  swapOperands[0] = i1;
+  swapOperands[1] = him1;
+  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+  builder.create<func::ReturnOp>(loc, i1);
+}
+
+/// Creates a function to perform quick sort on the value in the range of
+/// index [lo, hi).
+//
+// The generate IR corresponds to this C like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo < hi) {
+//        p = partition(low, high, data);
+//        quickSort(lo, p, data);
+//        quickSort(p + 1, hi, data);
+//   }
+// }
+static void createSortFunc(OpBuilder &builder, ModuleOp module,
+                           func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // The if-stmt true branch.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
+      builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
+      createPartitionFunc);
+  auto p = builder.create<func::CallOp>(
+      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+
+  SmallVector<Value, 6> lowOperands{lo, p.getResult(0)};
+  lowOperands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, lowOperands);
+
+  SmallVector<Value, 6> highOperands{
+      builder.create<arith::AddIOp>(loc, p.getResult(0),
+                                    constantIndex(builder, loc, 1)),
+      hi};
+  highOperands.append(args.begin() + xStartIdx, args.end());
+  builder.create<func::CallOp>(loc, func, highOperands);
+
+  // After the if-stmt.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
+//===---------------------------------------------------------------------===//
+// The actual sparse buffer rewriting rules.
+//===---------------------------------------------------------------------===//
+
+namespace {
+
+/// Sparse rewriting rule for the sort operator.
+struct SortRewriter : public OpRewritePattern<SortOp> {
+public:
+  using OpRewritePattern<SortOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SortOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
+
+    // Convert `values` to have dynamic shape and append them to `operands`.
+    auto addValues = [&](ValueRange values) {
+      for (Value v : values) {
+        auto mtp = v.getType().cast<MemRefType>();
+        if (!mtp.isDynamicDim(0)) {
+          auto new_mtp =
+              MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
+          v = rewriter.create<memref::CastOp>(loc, new_mtp, v);
+        }
+        operands.push_back(v);
+      }
+    };
+    ValueRange xs = op.getXs();
+    addValues(xs);
+    addValues(op.getYs());
+    auto insertPoint = op->getParentOfType<func::FuncOp>();
+    FlatSymbolRefAttr func = getMangledSortHelperFunc(
+        rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
+        operands, createSortFunc);
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+    return success();
+  }
+};
+
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
+  patterns.add<SortRewriter>(patterns.getContext());
+}
index 5af4e11..51bcc6a 100644 (file)
@@ -24,6 +24,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -198,6 +199,20 @@ struct SparseTensorCodegenPass
   }
 };
 
+struct SparseBufferRewritePass
+    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
+
+  SparseBufferRewritePass() = default;
+  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateSparseBufferRewriting(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -241,3 +256,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
+  return std::make_unique<SparseBufferRewritePass>();
+}
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
new file mode 100644 (file)
index 0000000..e40064b
--- /dev/null
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s --sparse-buffer-rewrite  --canonicalize --cse | FileCheck %s
+
+// CHECK-LABEL:   func.func private @_sparse_less_than_1_i8(
+// CHECK-SAME:                                              %[[I:arg0]]: index,
+// CHECK-SAME:                                              %[[J:.*]]: index,
+// CHECK-SAME:                                              %[[X0:.*]]: memref<?xi8>) -> i1 {
+// CHECK:           %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK:           %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK:           %[[C:.*]] = arith.cmpi ult, %[[VI]], %[[VJ]]
+// CHECK:           return %[[C]]
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_sparse_may_swap_1_i8_f32_index(
+// CHECK-SAME:                                                       %[[I:arg0]]: index,
+// CHECK-SAME:                                                       %[[J:.*]]: index,
+// CHECK-SAME:                                                       %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:                                                       %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:                                                       %[[Y1:.*]]: memref<?xindex>) {
+// CHECK:           %[[C:.*]] = arith.cmpi ne, %[[I]], %[[J]]
+// CHECK:           scf.if %[[C]] {
+// CHECK:             %[[Vx0i:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK:             %[[Vx0j:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK:             memref.store %[[Vx0j]], %[[X0]]{{\[}}%[[I]]]
+// CHECK:             memref.store %[[Vx0i]], %[[X0]]{{\[}}%[[J]]]
+// CHECK:             %[[Vy0i:.*]] = memref.load %[[Y0]]{{\[}}%[[I]]]
+// CHECK:             %[[Vy0j:.*]] = memref.load %[[Y0]]{{\[}}%[[J]]]
+// CHECK:             memref.store %[[Vy0j]], %[[Y0]]{{\[}}%[[I]]]
+// CHECK:             memref.store %[[Vy0i]], %[[Y0]]{{\[}}%[[J]]]
+// CHECK:             %[[Vy1i:.*]] = memref.load %[[Y1]]{{\[}}%[[I]]]
+// CHECK:             %[[Vy1j:.*]] = memref.load %[[Y1]]{{\[}}%[[J]]]
+// CHECK:             memref.store %[[Vy1j]], %[[Y1]]{{\[}}%[[I]]]
+// CHECK:             memref.store %[[Vy1i]], %[[Y1]]{{\[}}%[[J]]]
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index(
+// CHECK-SAME:                                                        %[[L:arg0]]: index,
+// CHECK-SAME:                                                        %[[H:.*]]: index,
+// CHECK-SAME:                                                        %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:                                                        %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:                                                        %[[Y1:.*]]: memref<?xindex>) -> index {
+// CHECK:           %[[C1:.*]] = arith.constant 1
+// CHECK:           %[[I:.*]] = arith.subi %[[L]], %[[C1]]
+// CHECK:           %[[Hm1:.*]] = arith.subi %[[H]], %[[C1]]
+// CHECK:           %[[I3:.*]] = scf.for %[[J:.*]] = %[[L]] to %[[Hm1]] step %[[C1]] iter_args(%[[I2:.*]] = %[[I]]) -> (index) {
+// CHECK:             %[[COND:.*]] = func.call @_sparse_less_than_1_i8(%[[J]], %[[Hm1]], %[[X0]])
+// CHECK:             %[[IF:.*]] = scf.if %[[COND]] -> (index) {
+// CHECK:               %[[Ip1:.*]] = arith.addi %[[I2]], %[[C1]]
+// CHECK:               func.call @_sparse_may_swap_1_i8_f32_index(%[[Ip1]], %[[J]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:               scf.yield %[[Ip1]]
+// CHECK:             } else {
+// CHECK:               scf.yield %[[I2]]
+// CHECK:             }
+// CHECK:             scf.yield %[[IF:.*]]
+// CHECK:           }
+// CHECK:           %[[I3p1:.*]] = arith.addi %[[I3:.*]], %[[C1]] : index
+// CHECK:           call @_sparse_may_swap_1_i8_f32_index(%[[I3p1]], %[[Hm1]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:           return %[[I3p1]]
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-SAME:                                                   %[[L:arg0]]: index,
+// CHECK-SAME:                                                   %[[H:.*]]: index,
+// CHECK-SAME:                                                   %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME:                                                   %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:                                                   %[[Y1:.*]]: memref<?xindex>) {
+// CHECK:           %[[C1:.*]] = arith.constant 1
+// CHECK:           %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK:           scf.if %[[COND]] {
+// CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK:             func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @sparse_sort_1d2v(
+// CHECK-SAME:                                %[[N:.*]]: index,
+// CHECK-SAME:                                %[[X0:.*]]: memref<10xi8>,
+// CHECK-SAME:                                %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME:                                %[[Y1:.*]]: memref<10xindex>) -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0
+// CHECK:           %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
+// CHECK:           %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
+// CHECK:           call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK:           return %[[X0]], %[[Y0]], %[[Y1]]
+// CHECK:         }
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+   -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
+}
+
+// Only check the generated supporting function now. We have integration test
+// to verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL:   func.func @sparse_sort_3d
+func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
new file mode 100644 (file)
index 0000000..4db44ad
--- /dev/null
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  // Stores 5 values to the memref buffer.
+  func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
+    %v3: i32, %v4: i32) -> () {
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    memref.store %v0, %b[%i0] : memref<?xi32>
+    memref.store %v1, %b[%i1] : memref<?xi32>
+    memref.store %v2, %b[%i2] : memref<?xi32>
+    memref.store %v3, %b[%i3] : memref<?xi32>
+    memref.store %v4, %b[%i4] : memref<?xi32>
+    return
+  }
+
+  // The main driver.
+  func.func @entry() {
+    %c0 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %c3 = arith.constant 3 : i32
+    %c4 = arith.constant 4 : i32
+    %c5 = arith.constant 5 : i32
+    %c6 = arith.constant 6 : i32
+    %c7 = arith.constant 7 : i32
+    %c8 = arith.constant 8 : i32
+    %c9 = arith.constant 9 : i32
+    %c10 = arith.constant 10 : i32
+    %c100 = arith.constant 100 : i32
+
+    %i0 = arith.constant 0 : index
+    %i4 = arith.constant 4 : index
+    %i5 = arith.constant 5 : index
+
+    // Prepare a buffer.
+    %x0s = memref.alloc() : memref<5xi32>
+    %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+    // Sort 0 elements.
+    // CHECK: ( 10, 2, 0, 5, 1 )
+    sparse_tensor.sort %i0, %x0 : memref<?xi32>
+    %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v0 : vector<5xi32>
+
+    // Sort the first 4 elements, with the last valid value untouched.
+    // CHECK: ( 0, 2, 5, 10, 1 )
+    sparse_tensor.sort %i4, %x0 : memref<?xi32>
+    %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v1 : vector<5xi32>
+
+    // Prepare more buffers of different dimensions.
+    %x1s = memref.alloc() : memref<10xi32>
+    %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
+    %x2s = memref.alloc() : memref<6xi32>
+    %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
+    %y0s = memref.alloc() : memref<7xi32>
+    %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+    // Sort "parallel arrays".
+    // CHECK: ( 0, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 4, 9, 4, 7, 2 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
+    %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x1v : vector<5xi32>
+    %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x2v : vector<5xi32>
+    %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y0v : vector<5xi32>
+
+    // Release the buffers.
+    memref.dealloc %x0 : memref<?xi32>
+    memref.dealloc %x1 : memref<?xi32>
+    memref.dealloc %x2 : memref<?xi32>
+    memref.dealloc %y0 : memref<?xi32>
+    return
+  }
+}