--- /dev/null
+//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewriting rules that are specific to sparse tensor
+// primitives with memref operands.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===---------------------------------------------------------------------===//
+// Helper methods for the actual rewriting rules.
+//===---------------------------------------------------------------------===//
+
+constexpr uint64_t loIdx = 0;
+constexpr uint64_t hiIdx = 1;
+constexpr uint64_t xStartIdx = 2;
+
+using FuncGeneratorType =
+    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>;
+
+/// Constructs a function name with this format to facilitate quick sort:
+/// <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
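+///
+/// For example, sorting a single i8 x-buffer jointly with f32 and index
+/// y-buffers under the "_sparse_sort_" prefix yields the name
+/// "_sparse_sort_1_i8_f32_index".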
+static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
+ StringRef namePrefix, size_t dim,
+ ValueRange operands) {
+ nameOstream
+ << namePrefix << dim << "_"
+ << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
+
+ for (Value v : operands.drop_front(xStartIdx + dim))
+ nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
+}
+
+/// Looks up a function that is appropriate for the given operands being
+/// sorted, and creates such a function if it doesn't exist yet.
+static FlatSymbolRefAttr
+getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
+ TypeRange resultTypes, StringRef namePrefix,
+ size_t dim, ValueRange operands,
+ FuncGeneratorType createFunc) {
+ SmallString<32> nameBuffer;
+ llvm::raw_svector_ostream nameOstream(nameBuffer);
+ getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
+
+ ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
+ MLIRContext *context = module.getContext();
+ auto result = SymbolRefAttr::get(context, nameOstream.str());
+ auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+ if (!func) {
+ // Create the function.
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPoint(insertPoint);
+ Location loc = insertPoint.getLoc();
+ func = builder.create<func::FuncOp>(
+ loc, nameOstream.str(),
+ FunctionType::get(context, operands.getTypes(), resultTypes));
+ func.setPrivate();
+ createFunc(builder, module, func, dim);
+ }
+
+ return result;
+}
+
+/// Creates a function for swapping the values at index i and j for all the
+/// buffers.
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (i != j) {
+// swap(x0[i], x0[j]);
+// swap(x1[i], x1[j]);
+// ...
+// swap(xn[i], xn[j]);
+// swap(y0[i], y0[j]);
+// ...
+// swap(yn[i], yn[j]);
+// }
+static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
+ func::FuncOp func, size_t dim) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+ Value i = args[0];
+ Value j = args[1];
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // If i != j, swap the values in all the buffers.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ for (auto arg : args.drop_front(xStartIdx)) {
+ Value vi = builder.create<memref::LoadOp>(loc, arg, i);
+ Value vj = builder.create<memref::LoadOp>(loc, arg, j);
+ builder.create<memref::StoreOp>(loc, vj, arg, i);
+ builder.create<memref::StoreOp>(loc, vi, arg, j);
+ }
+
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<func::ReturnOp>(loc);
+}
+
+/// Generates an if-statement to compare x[i] and x[j]. For all but the last
+/// dimension, the builder's insertion point is left at the start of the
+/// innermost else-branch, so that the caller can emit the comparison for the
+/// next dimension there.
+static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
+ Value i, Value j, Value x,
+ bool isLastDim) {
+ Value f = constantI1(builder, loc, false);
+ Value t = constantI1(builder, loc, true);
+ Value vi = builder.create<memref::LoadOp>(loc, x, i);
+ Value vj = builder.create<memref::LoadOp>(loc, x, j);
+
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vi, vj);
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+ // If (x[i] < x[j]).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc, t);
+
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  if (isLastDim) {
+ // Finish checking all dimensions.
+ builder.create<scf::YieldOp>(loc, f);
+ } else {
+ cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, vj, vi);
+ scf::IfOp ifOp2 =
+ builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
+ // Otherwise if (x[j] < x[i]).
+ builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc, f);
+
+ // Otherwise check the remaining dimensions.
+ builder.setInsertionPointAfter(ifOp2);
+ builder.create<scf::YieldOp>(loc, ifOp2.getResult(0));
+ // Set up the insertion point for the nested if-stmt that checks the
+ // remaining dimensions.
+ builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+ }
+
+ return ifOp;
+}
+
+/// Creates a function to compare the xs values at index i and j for all the
+/// dimensions. The function returns true iff xs[i] is lexicographically less
+/// than xs[j].
+//
+// The generated IR corresponds to this C-like algorithm:
+// if (x0[i] < x0[j])
+// return true;
+// else if (x0[j] < x0[i])
+// return false;
+// else
+// if (x1[i] < x1[j])
+// return true;
+//     else if (x1[j] < x1[i])
+// and so on ...
+static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
+ func::FuncOp func, size_t dim) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+
+ scf::IfOp topIfOp;
+ for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
+ scf::IfOp ifOp =
+ createLessThanCompare(builder, loc, args[0], args[1], item.value(),
+ (item.index() == dim - 1));
+ if (item.index() == 0) {
+ topIfOp = ifOp;
+ } else {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+ }
+ }
+
+ builder.setInsertionPointAfter(topIfOp);
+ builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+}
+
+/// Creates a function to perform quicksort partition on the values in the
+/// index range [lo, hi), assuming lo < hi.
+//
+// The generated IR corresponds to this C-like algorithm:
+// int partition(lo, hi, data) {
+//   pivot = data[hi - 1];
+//   i = (lo - 1);  // Last element found so far that is less than the pivot.
+//   for (j = lo; j < hi - 1; j++) {
+//     if (data[j] < pivot) {
+//       i++;
+//       swap(data[i], data[j]);
+//     }
+//   }
+//   i++;
+//   swap(data[i], data[hi - 1]);
+//   return i;
+// }
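+//
+// For example, partitioning data = [10, 2, 0, 5, 1] over the range [0, 5)
+// picks pivot 1, rearranges the buffer to [0, 1, 10, 5, 2], and returns
+// index 1.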
+static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, size_t dim) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ MLIRContext *context = module.getContext();
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+ Value lo = args[loIdx];
+ Value c1 = constantIndex(builder, loc, 1);
+ Value i = builder.create<arith::SubIOp>(loc, lo, c1);
+ Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
+ scf::ForOp forOp =
+ builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
+
+ // Start the for-stmt body.
+ builder.setInsertionPointToStart(forOp.getBody());
+ Value j = forOp.getInductionVar();
+ SmallVector<Value, 6> compareOperands{j, him1};
+ ValueRange xs = args.slice(xStartIdx, dim);
+ compareOperands.append(xs.begin(), xs.end());
+ Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
+ FlatSymbolRefAttr lessThanFunc =
+ getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+ dim, compareOperands, createLessThanFunc);
+ Value cond = builder
+ .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+ compareOperands)
+ .getResult(0);
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+
+ // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value i1 =
+ builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
+ SmallVector<Value, 6> swapOperands{i1, j};
+ swapOperands.append(args.begin() + xStartIdx, args.end());
+ FlatSymbolRefAttr swapFunc =
+ getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
+ dim, swapOperands, createMaySwapFunc);
+ builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+ builder.create<scf::YieldOp>(loc, i1);
+
+ // The if-stmt false branch: yield i.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+
+ // After the if-stmt, yield the updated i value to end the for-stmt body.
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+
+ // After the for-stmt: i++; swap(data[i], data[him1]); return i.
+ builder.setInsertionPointAfter(forOp);
+ i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
+ swapOperands[0] = i1;
+ swapOperands[1] = him1;
+ builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
+ builder.create<func::ReturnOp>(loc, i1);
+}
+
+/// Creates a function to perform quicksort on the values in the index range
+/// [lo, hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo < hi) {
+//     p = partition(lo, hi, data);
+// quickSort(lo, p, data);
+// quickSort(p + 1, hi, data);
+// }
+// }
+static void createSortFunc(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, size_t dim) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ MLIRContext *context = module.getContext();
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+ Value lo = args[loIdx];
+ Value hi = args[hiIdx];
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+ // The if-stmt true branch.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
+ builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
+ createPartitionFunc);
+ auto p = builder.create<func::CallOp>(
+ loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+
+ SmallVector<Value, 6> lowOperands{lo, p.getResult(0)};
+ lowOperands.append(args.begin() + xStartIdx, args.end());
+ builder.create<func::CallOp>(loc, func, lowOperands);
+
+ SmallVector<Value, 6> highOperands{
+ builder.create<arith::AddIOp>(loc, p.getResult(0),
+ constantIndex(builder, loc, 1)),
+ hi};
+ highOperands.append(args.begin() + xStartIdx, args.end());
+ builder.create<func::CallOp>(loc, func, highOperands);
+
+ // After the if-stmt.
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<func::ReturnOp>(loc);
+}
+
+//===---------------------------------------------------------------------===//
+// The actual sparse buffer rewriting rules.
+//===---------------------------------------------------------------------===//
+
+namespace {
+
+/// Sparse rewriting rule for the sort operator.
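+///
+/// For example (mirrored by the FileCheck tests, with illustrative SSA names),
+///   sparse_tensor.sort %n, %x jointly %y0, %y1
+///       : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+/// is replaced with a call such as
+///   call @_sparse_sort_1_i8_f32_index(%c0, %n, %dx, %y0, %dy1)
+/// where the statically shaped buffers are first cast to dynamic memrefs.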
+struct SortRewriter : public OpRewritePattern<SortOp> {
+public:
+ using OpRewritePattern<SortOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SortOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
+
+ // Convert `values` to have dynamic shape and append them to `operands`.
+ auto addValues = [&](ValueRange values) {
+ for (Value v : values) {
+ auto mtp = v.getType().cast<MemRefType>();
+ if (!mtp.isDynamicDim(0)) {
+        auto newMtp =
+            MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
+        v = rewriter.create<memref::CastOp>(loc, newMtp, v);
+ }
+ operands.push_back(v);
+ }
+ };
+ ValueRange xs = op.getXs();
+ addValues(xs);
+ addValues(op.getYs());
+ auto insertPoint = op->getParentOfType<func::FuncOp>();
+ FlatSymbolRefAttr func = getMangledSortHelperFunc(
+ rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
+ operands, createSortFunc);
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+ return success();
+ }
+};
+
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
+void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
+ patterns.add<SortRewriter>(patterns.getContext());
+}
--- /dev/null
+// RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s
+
+// CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
+// CHECK-SAME: %[[I:arg0]]: index,
+// CHECK-SAME: %[[J:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<?xi8>) -> i1 {
+// CHECK: %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK: %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK: %[[C:.*]] = arith.cmpi ult, %[[VI]], %[[VJ]]
+// CHECK: return %[[C]]
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_may_swap_1_i8_f32_index(
+// CHECK-SAME: %[[I:arg0]]: index,
+// CHECK-SAME: %[[J:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
+// CHECK: %[[C:.*]] = arith.cmpi ne, %[[I]], %[[J]]
+// CHECK: scf.if %[[C]] {
+// CHECK: %[[Vx0i:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK: %[[Vx0j:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK: memref.store %[[Vx0j]], %[[X0]]{{\[}}%[[I]]]
+// CHECK: memref.store %[[Vx0i]], %[[X0]]{{\[}}%[[J]]]
+// CHECK: %[[Vy0i:.*]] = memref.load %[[Y0]]{{\[}}%[[I]]]
+// CHECK: %[[Vy0j:.*]] = memref.load %[[Y0]]{{\[}}%[[J]]]
+// CHECK: memref.store %[[Vy0j]], %[[Y0]]{{\[}}%[[I]]]
+// CHECK: memref.store %[[Vy0i]], %[[Y0]]{{\[}}%[[J]]]
+// CHECK: %[[Vy1i:.*]] = memref.load %[[Y1]]{{\[}}%[[I]]]
+// CHECK: %[[Vy1j:.*]] = memref.load %[[Y1]]{{\[}}%[[J]]]
+// CHECK: memref.store %[[Vy1j]], %[[Y1]]{{\[}}%[[I]]]
+// CHECK: memref.store %[[Vy1i]], %[[Y1]]{{\[}}%[[J]]]
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index(
+// CHECK-SAME: %[[L:arg0]]: index,
+// CHECK-SAME: %[[H:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) -> index {
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[I:.*]] = arith.subi %[[L]], %[[C1]]
+// CHECK: %[[Hm1:.*]] = arith.subi %[[H]], %[[C1]]
+// CHECK: %[[I3:.*]] = scf.for %[[J:.*]] = %[[L]] to %[[Hm1]] step %[[C1]] iter_args(%[[I2:.*]] = %[[I]]) -> (index) {
+// CHECK: %[[COND:.*]] = func.call @_sparse_less_than_1_i8(%[[J]], %[[Hm1]], %[[X0]])
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (index) {
+// CHECK: %[[Ip1:.*]] = arith.addi %[[I2]], %[[C1]]
+// CHECK: func.call @_sparse_may_swap_1_i8_f32_index(%[[Ip1]], %[[J]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: scf.yield %[[Ip1]]
+// CHECK: } else {
+// CHECK: scf.yield %[[I2]]
+// CHECK: }
+// CHECK: scf.yield %[[IF:.*]]
+// CHECK: }
+// CHECK: %[[I3p1:.*]] = arith.addi %[[I3:.*]], %[[C1]] : index
+// CHECK: call @_sparse_may_swap_1_i8_f32_index(%[[I3p1]], %[[Hm1]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: return %[[I3p1]]
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-SAME: %[[L:arg0]]: index,
+// CHECK-SAME: %[[H:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
+// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK: scf.if %[[COND]] {
+// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK: func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func @sparse_sort_1d2v(
+// CHECK-SAME: %[[N:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[Y1:.*]]: memref<10xindex>) -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
+// CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
+// CHECK: call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK: return %[[X0]], %[[Y0]], %[[Y1]]
+// CHECK: }
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+ -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
+ sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+ return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
+}
+
+// Only check the generated supporting functions. We have an integration test
+// to verify the correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d
+func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+ sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+ return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
--- /dev/null
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+ // Stores 5 values to the memref buffer.
+ func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
+ %v3: i32, %v4: i32) -> () {
+ %i0 = arith.constant 0 : index
+ %i1 = arith.constant 1 : index
+ %i2 = arith.constant 2 : index
+ %i3 = arith.constant 3 : index
+ %i4 = arith.constant 4 : index
+ memref.store %v0, %b[%i0] : memref<?xi32>
+ memref.store %v1, %b[%i1] : memref<?xi32>
+ memref.store %v2, %b[%i2] : memref<?xi32>
+ memref.store %v3, %b[%i3] : memref<?xi32>
+ memref.store %v4, %b[%i4] : memref<?xi32>
+ return
+ }
+
+ // The main driver.
+ func.func @entry() {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %c3 = arith.constant 3 : i32
+ %c4 = arith.constant 4 : i32
+ %c5 = arith.constant 5 : i32
+ %c6 = arith.constant 6 : i32
+ %c7 = arith.constant 7 : i32
+ %c8 = arith.constant 8 : i32
+ %c9 = arith.constant 9 : i32
+ %c10 = arith.constant 10 : i32
+ %c100 = arith.constant 100 : i32
+
+ %i0 = arith.constant 0 : index
+ %i4 = arith.constant 4 : index
+ %i5 = arith.constant 5 : index
+
+ // Prepare a buffer.
+ %x0s = memref.alloc() : memref<5xi32>
+ %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
+ call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+ // Sort 0 elements.
+ // CHECK: ( 10, 2, 0, 5, 1 )
+ sparse_tensor.sort %i0, %x0 : memref<?xi32>
+ %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %x0v0 : vector<5xi32>
+
+    // Sort the first 4 elements, leaving the last element untouched.
+ // CHECK: ( 0, 2, 5, 10, 1 )
+ sparse_tensor.sort %i4, %x0 : memref<?xi32>
+ %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %x0v1 : vector<5xi32>
+
+    // Prepare more buffers of different sizes.
+ %x1s = memref.alloc() : memref<10xi32>
+ %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
+ %x2s = memref.alloc() : memref<6xi32>
+ %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
+ %y0s = memref.alloc() : memref<7xi32>
+ %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
+ call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+
+ // Sort "parallel arrays".
+ // CHECK: ( 0, 1, 2, 5, 10 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 4, 9, 4, 7, 2 )
+ // CHECK: ( 8, 7, 10, 9, 6 )
+ sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+ : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+ %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %x0v2 : vector<5xi32>
+ %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %x1v : vector<5xi32>
+ %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %x2v : vector<5xi32>
+ %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %y0v : vector<5xi32>
+
+ // Release the buffers.
+ memref.dealloc %x0 : memref<?xi32>
+ memref.dealloc %x1 : memref<?xi32>
+ memref.dealloc %x2 : memref<?xi32>
+ memref.dealloc %y0 : memref<?xi32>
+ return
+ }
+}