[mlir][sparse] Add rewriting rules for sparse_tensor.sort_coo.
authorbixia1 <bixia@google.com>
Wed, 9 Nov 2022 17:07:06 +0000 (09:07 -0800)
committerbixia1 <bixia@google.com>
Mon, 14 Nov 2022 16:48:53 +0000 (08:48 -0800)
Refactor the rewriting of sparse_tensor.sort to support the implementation of
sparse_tensor.sort_coo.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137522

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir [new file with mode: 0644]

index 52a6aff..64facdc 100644 (file)
@@ -529,10 +529,10 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
     `xs` values and some `ys` values are put in the linear buffer `xy`. The
     optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `ns` is not explicitly specified, its value is 1. The optional index
+    When `nx` is not explicitly specified, its value is 1. The optional index
     attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports the TACO
-    COO style storage format for better sorting performance.
+    explicitly specified, its value is 0. This instruction supports a more
+    efficient way to store the COO definition in sparse tensor type.
 
     The buffer xy should have a dimension not less than n * (nx + ny) while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
index d0564ca..c556b0d 100644 (file)
@@ -43,32 +43,42 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
 
-using FuncGeneratorType =
-    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+                                            uint64_t, uint64_t, bool)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
-///   <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
+///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
+///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
-                                         StringRef namePrefix, size_t dim,
+                                         StringRef namePrefix, uint64_t nx,
+                                         uint64_t ny, bool isCoo,
                                          ValueRange operands) {
   nameOstream
-      << namePrefix << dim << "_"
+      << namePrefix << nx << "_"
       << operands[xStartIdx].getType().cast<MemRefType>().getElementType();
 
-  for (Value v : operands.drop_front(xStartIdx + dim))
+  if (isCoo)
+    nameOstream << "_coo_" << ny;
+
+  uint64_t yBufferOffset = isCoo ? 1 : nx;
+  for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
-/// sorted, and creates such a function if it doesn't exist yet.
+/// sorted, and creates such a function if it doesn't exist yet. The
+/// parameters `nx` and `ny` tell the number of x and y values provided
+/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
+/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
 static FlatSymbolRefAttr
 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                          TypeRange resultTypes, StringRef namePrefix,
-                         size_t dim, ValueRange operands,
-                         FuncGeneratorType createFunc) {
+                         uint64_t nx, uint64_t ny, bool isCoo,
+                         ValueRange operands, FuncGeneratorType createFunc) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
+  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+                               operands);
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
   MLIRContext *context = module.getContext();
@@ -84,12 +94,61 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, dim);
+    createFunc(builder, module, func, nx, ny, isCoo);
   }
 
   return result;
 }
 
+/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
+/// The code to process the value pairs is generated by `bodyBuilder`.
+static void forEachIJPairInXs(
+    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
+    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+  Value iOffset, jOffset;
+  if (isCoo) {
+    Value cstep = constantIndex(builder, loc, nx + ny);
+    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+  }
+  for (uint64_t k = 0; k < nx; k++) {
+    scf::IfOp ifOp;
+    Value i, j, buffer;
+    if (isCoo) {
+      Value ck = constantIndex(builder, loc, k);
+      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
+      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
+      buffer = args[xStartIdx];
+    } else {
+      i = args[0];
+      j = args[1];
+      buffer = args[xStartIdx + k];
+    }
+    bodyBuilder(k, i, j, buffer);
+  }
+}
+
+/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
+/// The code to process the value pairs is generated by `bodyBuilder`.
+static void forEachIJPairInAllBuffers(
+    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
+    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+  // Create code for the first (nx + ny) buffers. When isCoo==true, these
+  // logical buffers are all from the xy buffer of the sort_coo operator.
+  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+
+  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+
+  // Create code for the remaining buffers.
+  Value i = args[0];
+  Value j = args[1];
+  for (const auto &arg :
+       llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
+    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+  }
+}
+
 /// Creates a code block for swapping the values in index i and j for all the
 /// buffers.
 //
@@ -101,21 +160,23 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 //     swap(y0[i], y0[j]);
 //     ...
 //     swap(yn[i], yn[j]);
-static void createSwap(OpBuilder &builder, Location loc, ValueRange args) {
-  Value i = args[0];
-  Value j = args[1];
-  for (auto arg : args.drop_front(xStartIdx)) {
-    Value vi = builder.create<memref::LoadOp>(loc, arg, i);
-    Value vj = builder.create<memref::LoadOp>(loc, arg, j);
-    builder.create<memref::StoreOp>(loc, vj, arg, i);
-    builder.create<memref::StoreOp>(loc, vi, arg, j);
-  }
+static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
+                       uint64_t nx, uint64_t ny, bool isCoo) {
+  auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
+    Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
+    Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
+    builder.create<memref::StoreOp>(loc, vj, buffer, i);
+    builder.create<memref::StoreOp>(loc, vi, buffer, j);
+  };
+
+  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
 }
 
 /// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
 /// compare each pair is create via `compareBuilder`.
 static void createCompareFuncImplementation(
-    OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim,
+    OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
+    uint64_t ny, bool isCoo,
     function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
         compareBuilder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -126,17 +187,18 @@ static void createCompareFuncImplementation(
   ValueRange args = entryBlock->getArguments();
 
   scf::IfOp topIfOp;
-  for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
-    scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1],
-                                    item.value(), (item.index() == dim - 1));
-    if (item.index() == 0) {
+  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
+    scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
+    if (k == 0) {
       topIfOp = ifOp;
     } else {
       OpBuilder::InsertionGuard insertionGuard(builder);
       builder.setInsertionPointAfter(ifOp);
       builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
     }
-  }
+  };
+
+  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
 
   builder.setInsertionPointAfter(topIfOp);
   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
@@ -180,8 +242,10 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
 //     else if (x2[2] != x2[j]))
 //       and so on ...
 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
-                                func::FuncOp func, size_t dim) {
-  createCompareFuncImplementation(builder, unused, func, dim, createEqCompare);
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo) {
+  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
+                                  createEqCompare);
 }
 
 /// Generates an if-statement to compare whether x[i] is less than x[j].
@@ -238,8 +302,9 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
 //     else if (x1[j] < x1[i]))
 //       and so on ...
 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
-                               func::FuncOp func, size_t dim) {
-  createCompareFuncImplementation(builder, unused, func, dim,
+                               func::FuncOp func, uint64_t nx, uint64_t ny,
+                               bool isCoo) {
+  createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                   createLessThanCompare);
 }
 
@@ -257,7 +322,8 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
 //   return lo;
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
-                                   func::FuncOp func, size_t dim) {
+                                   func::FuncOp func, uint64_t nx, uint64_t ny,
+                                   bool isCoo) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -292,12 +358,13 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
   // Compare xs[p] < xs[mid].
   SmallVector<Value, 6> compareOperands{p, mid};
+  uint64_t numXBuffers = isCoo ? 1 : nx;
   compareOperands.append(args.begin() + xStartIdx,
-                         args.begin() + xStartIdx + dim);
+                         args.begin() + xStartIdx + numXBuffers);
   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc =
-      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
-                               dim, compareOperands, createLessThanFunc);
+  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createLessThanFunc);
   Value cond2 = builder
                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                           compareOperands)
@@ -324,7 +391,8 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 /// xs[i] == xs[p].
 static std::pair<Value, Value>
 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-               ValueRange xs, Value i, Value p, size_t dim, int step) {
+               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
+               bool isCoo, int step) {
   Location loc = func.getLoc();
   scf::WhileOp whileOp =
       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -344,9 +412,9 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   compareOperands.append(xs.begin(), xs.end());
   MLIRContext *context = module.getContext();
   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc =
-      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
-                               dim, compareOperands, createLessThanFunc);
+  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createLessThanFunc);
   Value cond = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                          compareOperands)
@@ -365,8 +433,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   compareOperands[0] = i;
   compareOperands[1] = p;
   FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
-      builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands,
-      createEqCompareFunc);
+      builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createEqCompareFunc);
   Value compareEq =
       builder
           .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
@@ -405,7 +473,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 //   return p
 //   }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, size_t dim) {
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo) {
   OpBuilder::InsertionGuard insertionGuard(builder);
 
   Block *entryBlock = func.addEntryBlock();
@@ -442,11 +511,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   j = after->getArgument(1);
   p = after->getArgument(2);
 
-  auto [iresult, iCompareEq] = createScanLoop(
-      builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1);
+  uint64_t numXBuffers = isCoo ? 1 : nx;
+  auto [iresult, iCompareEq] =
+      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
+                     i, p, nx, ny, isCoo, 1);
   i = iresult;
-  auto [jresult, jCompareEq] = createScanLoop(
-      builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1);
+  auto [jresult, jCompareEq] =
+      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
+                     j, p, nx, ny, isCoo, -1);
   j = jresult;
 
   // If i < j:
@@ -455,7 +527,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   SmallVector<Value, 6> swapOperands{i, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands);
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
   // If the pivot is moved, update p with the new pivot.
   Value icond =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -515,7 +587,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 //   }
 // }
 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
-                                    func::FuncOp func, size_t dim) {
+                                    func::FuncOp func, uint64_t nx, uint64_t ny,
+                                    bool isCoo) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -532,8 +605,8 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
   // The if-stmt true branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
-      args, createPartitionFunc);
+      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
+      ny, isCoo, args, createPartitionFunc);
   auto p = builder.create<func::CallOp>(
       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
 
@@ -567,7 +640,8 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
 //   }
 // }
 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
-                                 func::FuncOp func, size_t dim) {
+                                 func::FuncOp func, uint64_t nx, uint64_t ny,
+                                 bool isCoo) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -587,20 +661,23 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 
   // Binary search to find the insertion point p.
   SmallVector<Value, 6> operands{lo, i};
-  operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
+  operands.append(args.begin() + xStartIdx, args.end());
   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
-      dim, operands, createBinarySearchFunc);
+      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
+      ny, isCoo, operands, createBinarySearchFunc);
   Value p = builder
                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                       operands)
                 .getResult(0);
 
   // Move the value at data[i] to a temporary location.
-  ValueRange data = args.drop_front(xStartIdx);
+  operands[0] = operands[1] = i;
   SmallVector<Value, 6> d;
-  for (Value v : data)
-    d.push_back(builder.create<memref::LoadOp>(loc, v, i));
+  forEachIJPairInAllBuffers(
+      builder, loc, operands, nx, ny, isCoo,
+      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
+        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
+      });
 
   // Start the inner for-stmt with induction variable j, for moving data[p..i)
   // to data[p+1..i+1).
@@ -610,21 +687,58 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(forOpJ.getBody());
   Value j = forOpJ.getInductionVar();
   Value imj = builder.create<arith::SubIOp>(loc, i, j);
-  Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
-  for (Value v : data) {
-    Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
-    builder.create<memref::StoreOp>(loc, t, v, imj);
-  }
+  operands[1] = imj;
+  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
+  forEachIJPairInAllBuffers(
+      builder, loc, operands, nx, ny, isCoo,
+      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
+        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
+        builder.create<memref::StoreOp>(loc, t, buffer, imj);
+      });
 
   // Store the value at data[i] to data[p].
   builder.setInsertionPointAfter(forOpJ);
-  for (auto it : llvm::zip(d, data))
-    builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
+  operands[0] = operands[1] = p;
+  forEachIJPairInAllBuffers(
+      builder, loc, operands, nx, ny, isCoo,
+      [&](uint64_t k, Value p, Value usused, Value buffer) {
+        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
+      });
 
   builder.setInsertionPointAfter(forOpI);
   builder.create<func::ReturnOp>(loc);
 }
 
+/// Implements the rewriting for operator sort and sort_coo.
+template <typename OpTy>
+LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
+                                    uint64_t ny, bool isCoo,
+                                    PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
+
+  // Convert `values` to have dynamic shape and append them to `operands`.
+  for (Value v : xys) {
+    auto mtp = v.getType().cast<MemRefType>();
+    if (!mtp.isDynamicDim(0)) {
+      auto newMtp =
+          MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
+      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
+    }
+    operands.push_back(v);
+  }
+  auto insertPoint = op->template getParentOfType<func::FuncOp>();
+  SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
+                                          : kSortNonstableFuncNamePrefix);
+  FuncGeneratorType funcGenerator =
+      op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+  FlatSymbolRefAttr func =
+      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
+                               ny, isCoo, operands, funcGenerator);
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse buffer rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -755,34 +869,33 @@ public:
 
   LogicalResult matchAndRewrite(SortOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
-
-    // Convert `values` to have dynamic shape and append them to `operands`.
-    auto addValues = [&](ValueRange values) {
-      for (Value v : values) {
-        auto mtp = v.getType().cast<MemRefType>();
-        if (!mtp.isDynamicDim(0)) {
-          auto newMtp =
-              MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
-          v = rewriter.create<memref::CastOp>(loc, newMtp, v);
-        }
-        operands.push_back(v);
-      }
-    };
-    ValueRange xs = op.getXs();
-    addValues(xs);
-    addValues(op.getYs());
-    auto insertPoint = op->getParentOfType<func::FuncOp>();
-    SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
-                                            : kSortNonstableFuncNamePrefix);
-    FuncGeneratorType funcGenerator =
-        op.getStable() ? createSortStableFunc : createSortNonstableFunc;
-    FlatSymbolRefAttr func =
-        getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
-                                 xs.size(), operands, funcGenerator);
-    rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
-    return success();
+    SmallVector<Value, 6> xys(op.getXs());
+    xys.append(op.getYs().begin(), op.getYs().end());
+    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
+                                 /*isCoo=*/false, rewriter);
+  }
+};
+
+/// Sparse rewriting rule for the sort_coo operator.
+struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
+public:
+  using OpRewritePattern<SortCooOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SortCooOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 6> xys;
+    xys.push_back(op.getXy());
+    xys.append(op.getYs().begin(), op.getYs().end());
+    uint64_t nx = 1;
+    if (auto nxAttr = op.getNxAttr())
+      nx = nxAttr.getInt();
+
+    uint64_t ny = 0;
+    if (auto nyAttr = op.getNyAttr())
+      ny = nyAttr.getInt();
+
+    return matchAndRewriteSortOp(op, xys, nx, ny,
+                                 /*isCoo=*/true, rewriter);
   }
 };
 
@@ -796,5 +909,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                          bool enableBufferInitialization) {
   patterns.add<PushBackRewriter>(patterns.getContext(),
                                  enableBufferInitialization);
-  patterns.add<SortRewriter>(patterns.getContext());
+  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
 }
index f74eb5f..c153dcd 100644 (file)
@@ -173,6 +173,7 @@ struct SparseTensorCodegenPass
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
     target.addLegalOp<SortOp>();
+    target.addLegalOp<SortCooOp>();
     target.addLegalOp<PushBackOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
index f563452..18140de 100644 (file)
@@ -194,3 +194,33 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
   sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL:   func.func @sparse_sort_coo
+func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL:   func.func @sparse_sort_coo_stable
+func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
new file mode 100644 (file)
index 0000000..2efd2e4
--- /dev/null
@@ -0,0 +1,134 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  // Stores 5 values to the memref buffer.
+  func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
+    %v3: i32, %v4: i32) -> () {
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    memref.store %v0, %b[%i0] : memref<?xi32>
+    memref.store %v1, %b[%i1] : memref<?xi32>
+    memref.store %v2, %b[%i2] : memref<?xi32>
+    memref.store %v3, %b[%i3] : memref<?xi32>
+    memref.store %v4, %b[%i4] : memref<?xi32>
+    return
+  }
+
+  // Stores 5 values to the memref buffer.
+  func.func @storeValuesToStrided(%b: memref<?xi32, strided<[4], offset: ?>>, %v0: i32, %v1: i32, %v2: i32,
+    %v3: i32, %v4: i32) -> () {
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    memref.store %v0, %b[%i0] : memref<?xi32, strided<[4], offset: ?>>
+    memref.store %v1, %b[%i1] : memref<?xi32, strided<[4], offset: ?>>
+    memref.store %v2, %b[%i2] : memref<?xi32, strided<[4], offset: ?>>
+    memref.store %v3, %b[%i3] : memref<?xi32, strided<[4], offset: ?>>
+    memref.store %v4, %b[%i4] : memref<?xi32, strided<[4], offset: ?>>
+    return
+  }
+
+  // The main driver.
+  func.func @entry() {
+    %c0 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %c3 = arith.constant 3 : i32
+    %c4 = arith.constant 4 : i32
+    %c5 = arith.constant 5 : i32
+    %c6 = arith.constant 6 : i32
+    %c7 = arith.constant 7 : i32
+    %c8 = arith.constant 8 : i32
+    %c9 = arith.constant 9 : i32
+    %c10 = arith.constant 10 : i32
+    %c100 = arith.constant 100 : i32
+
+    %i0 = arith.constant 0 : index
+    %i1 = arith.constant 1 : index
+    %i2 = arith.constant 2 : index
+    %i3 = arith.constant 3 : index
+    %i4 = arith.constant 4 : index
+    %i5 = arith.constant 5 : index
+
+    // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1.
+    %xys = memref.alloc() : memref<20xi32>
+    %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
+    %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x1 = memref.subview %xy[%i1][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x2 = memref.subview %xy[%i2][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %y0 = memref.subview %xy[%i3][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %y1s = memref.alloc() : memref<7xi32>
+    %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
+
+    // Sort "parallel arrays".
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    // CHECK: ( 4, 7, 7, 9, 5 )
+    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+      : memref<?xi32> jointly memref<?xi32>
+    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v : vector<5xi32>
+    %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x1v : vector<5xi32>
+    %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x2v : vector<5xi32>
+    %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %y0v : vector<5xi32>
+    %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y1v : vector<5xi32>
+    // Stable sort.
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    // CHECK: ( 4, 7, 7, 9, 5 )
+    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+      : memref<?xi32> jointly memref<?xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
+    %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x1v2 : vector<5xi32>
+    %x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x2v2 : vector<5xi32>
+    %y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %y0v2 : vector<5xi32>
+    %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y1v2 : vector<5xi32>
+
+    // Release the buffers.
+    memref.dealloc %xy : memref<?xi32>
+    memref.dealloc %y1 : memref<?xi32>
+    return
+  }
+}