[mlir][sparse] add util for ToCoordinatesBuffer for COO AoS
authorAart Bik <ajcbik@google.com>
Thu, 11 May 2023 17:05:16 +0000 (10:05 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 11 May 2023 17:43:31 +0000 (10:43 -0700)
Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D150382

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

index 54a6786231e90d17657828c55c61a1e9795e2623..6fd55c7799306b3167b9eed82317d9573f8cba18 100644 (file)
@@ -679,6 +679,14 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
                                          builder.getIndexAttr(lvl));
 }
 
+Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
+                                            Value tensor) {
+  const auto srcTp = getSparseTensorType(tensor);
+  const Type crdTp = srcTp.getEncoding().getCrdType();
+  const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
+  return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
+}
+
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
   RankedTensorType srcTp = getRankedTensorType(tensor);
index 28acf1aed7de206326e25d4569fc45f8b5e09ae3..e04475ea2e8f1affeb6a6d81ad9256643285689f 100644 (file)
@@ -364,6 +364,9 @@ Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
 Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
                        Level lvl, Level cooStart);
 
+/// Infers the result type and generates `ToCoordinatesBufferOp`.
+Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
+
 /// Infers the result type and generates `ToValuesOp`.
 Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 
index 4aba829d7dbdd2b5a413c5fcc8b96c1c69b7c20e..2a4bbb06eb507daed755cfeaf91ef00549f3987e 100644 (file)
@@ -895,9 +895,7 @@ private:
       // coordinates for the storage ordering of the dst tensor.  Use SortCoo
       // if the COO tensor has the same ordering as the dst tensor.
       if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
-        MemRefType coordsTp =
-            get1DMemRefType(encSrc.getCrdType(), /*withLayout=*/false);
-        Value xs = rewriter.create<ToCoordinatesBufferOp>(loc, coordsTp, src);
+        Value xs = genToCoordinatesBuffer(rewriter, loc, src);
         rewriter.create<SortCooOp>(
             loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
             rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);