[mlir][sparse] Add rewriting rule for the convert operator.
authorbixia1 <bixia@google.com>
Mon, 31 Oct 2022 04:55:25 +0000 (21:55 -0700)
committerbixia1 <bixia@google.com>
Tue, 1 Nov 2022 22:57:34 +0000 (15:57 -0700)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136301

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir

index 7c615b4..52f9fef 100644 (file)
@@ -33,6 +33,10 @@ namespace sparse_tensor {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
+/// Returns true iff the given type is a type for a COO tensor with the last
+/// dimension level type being unique.
+bool isUniqueCOOType(RankedTensorType tp);
+
 //
 // Dimension level types.
 //
index f17080c..133879b 100644 (file)
@@ -262,6 +262,24 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
+bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
+
+  if (!enc)
+    return false;
+
+  if (!isCompressedDim(tp, 0))
+    return false;
+
+  for (uint64_t i = 1, e = tp.getRank(); i < e; ++i)
+    if (!isSingletonDim(tp, i))
+      return false;
+
+  // This works for rank == 1 (unique the only compressed) and rank > 1 (unique
+  // on the last singleton).
+  return isUniqueDim(tp, tp.getRank() - 1);
+}
+
 uint64_t mlir::sparse_tensor::toOrigDim(const SparseTensorEncodingAttr &enc,
                                         uint64_t d) {
   if (enc) {
index 4399cae..a051946 100644 (file)
@@ -155,6 +155,18 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
   return RankedTensorType::get(src.getShape(), src.getElementType(), enc);
 }
 
+/// Collects the dynamic dimension sizes for `tp` with the assumption that
+/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
+/// sizes to dynSizes.
+static void getDynamicSizes(RankedTensorType tp,
+                            const SmallVectorImpl<Value> &sizes,
+                            SmallVectorImpl<Value> &dynSizes) {
+  for (const auto &d : enumerate(tp.getShape())) {
+    if (d.value() == ShapedType::kDynamicSize)
+      dynSizes.push_back(sizes[d.index()]);
+  }
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -461,6 +473,204 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   }
 };
 
+/// Sparse rewriting rule for the convert operator.
+struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (encDst && encSrc) {
+      // Trivial tensor conversion is handled in codegen.
+      if (encSrc == encDst)
+        return failure();
+      return sparse2SparseRewrite(op, rewriter);
+    }
+    if (encSrc && !encDst)
+      return sparse2DenseRewrite(op, rewriter);
+    if (!encSrc && encDst)
+      return dense2SparseRewrite(op, rewriter);
+
+    // Dense-to-dense convert is a nop and handled by canonicalization.
+    return failure();
+  }
+
+private:
+  // Handles sparse constant to sparse tensor or dense tensor to sparse tensor
+  // conversion as follows:
+  //   t = new sparse COO tensor
+  //   fill t using src
+  //   dst = convert t
+  //
+  // To fill the COO tensor from a dense tensor:
+  //   for i1 in dim1
+  //    ..
+  //     for ik in dimk
+  //       val = a[i1,..,ik]
+  //       if val != 0
+  //         t->add(val, [i1,..,ik], [p1,..,pk])
+  //
+  // To fill the COO tensor from a sparse constant in COO format:
+  //   for i in range(NNZ)
+  //     val = values[i]
+  //     [i1,..,ik] = indices[i]
+  //     t->add(val, [i1,..,ik], [p1,..,pk])
+  LogicalResult dense2SparseRewrite(ConvertOp op,
+                                    PatternRewriter &rewriter) const {
+    Location loc = op.getLoc();
+    Value src = op.getSource();
+    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    SmallVector<Value, 4> sizes;
+    sizesFromSrc(rewriter, sizes, loc, src);
+    SmallVector<Value, 4> dynSizes;
+    getDynamicSizes(dstTp, sizes, dynSizes);
+
+    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
+    unsigned rank = dstTp.cast<ShapedType>().getRank();
+
+    genDenseTensorOrSparseConstantIterLoop(
+        rewriter, loc, src, rank,
+        [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) {
+          builder.create<InsertOp>(loc, val, cooBuffer, indices);
+        });
+
+    rewriter.setInsertionPointAfter(op);
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+    rewriter.create<DeallocTensorOp>(loc, cooBuffer);
+
+    return success();
+  }
+
+  // Handles sparse tensor to dense tensor conversion as follows:
+  //   dst = new dense tensor;
+  //   foreach elemment in src
+  //     dst[elemment.indices] = element.value
+  LogicalResult sparse2DenseRewrite(ConvertOp op,
+                                    PatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    Value src = op.getSource();
+    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+
+    SmallVector<Value, 4> sizes;
+    sizesForTensor(rewriter, sizes, loc, srcTp, src);
+    Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
+
+    rewriter.create<ForeachOp>(
+        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          builder.create<memref::StoreOp>(loc, args.back(), dst,
+                                          args.drop_back());
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
+    return success();
+  }
+
+  // Handles sparse tensor to sparse tensor conversion as follows:
+  //   if src is not COO
+  //       construct a COO to represent the src
+  //   sort the src COO
+  //   foreach elemment in the sorted src COO
+  //     insert element to dst
+  LogicalResult sparse2SparseRewrite(ConvertOp op,
+                                     PatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    Value src = op.getSource();
+    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+
+    SmallVector<Value, 4> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, src);
+    Value tmpCoo = Value();
+    if (!isUniqueCOOType(srcTp)) {
+      // Construct a COO tensor from the src tensor.
+      // TODO: there may be cases for which more efficiently without
+      // going through an intermediate COO, such as cases that only change
+      // the overhead types.
+      SmallVector<Value, 4> dynSrcSizes;
+      getDynamicSizes(srcTp, srcSizes, dynSrcSizes);
+      srcTp = getUnorderedCOOFromType(srcTp);
+      tmpCoo =
+          rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
+      rewriter.create<ForeachOp>(
+          loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+            SmallVector<Value, 4> indices;
+            for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+              uint64_t dim = toStoredDim(encSrc, i);
+              indices.push_back(args[dim]);
+            }
+            builder.create<InsertOp>(loc, args.back(), tmpCoo, indices);
+            builder.create<sparse_tensor::YieldOp>(loc);
+          });
+      src = tmpCoo;
+    }
+
+    // Sort the COO tensor so that its elements are ordered via increasing
+    // indices for the storage ordering of the dst tensor.
+    auto dynShape = {ShapedType::kDynamicSize};
+    auto indTp =
+        MemRefType::get(dynShape, getIndexOverheadType(rewriter, encSrc));
+    uint64_t rank = dstTp.getRank();
+    // Gather the indices-arrays in the dst tensor storage order.
+    SmallVector<Value, 4> xs(rank, Value());
+    for (int64_t i = 0; i < rank; i++) {
+      uint64_t orgDim = toOrigDim(encSrc, i);
+      xs[toStoredDim(encDst, orgDim)] = rewriter.create<ToIndicesOp>(
+          loc, indTp, src, rewriter.getIndexAttr(orgDim));
+    }
+
+    // Retrieve NNZ.
+    auto ptrTp =
+        MemRefType::get(dynShape, getPointerOverheadType(rewriter, encSrc));
+    Value p0 =
+        rewriter.create<ToIndicesOp>(loc, ptrTp, src, rewriter.getIndexAttr(0));
+    Value c1 = constantIndex(rewriter, loc, 1);
+    Value nnz = rewriter.create<memref::LoadOp>(loc, p0, c1);
+    nnz =
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
+
+    // Retrieve the values-array.
+    auto valTp = MemRefType::get(dynShape, srcTp.getElementType());
+    Value y = rewriter.create<ToValuesOp>(loc, valTp, src);
+
+    // Sort the COO tensor.
+    rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+
+    // For each element in the COO tensor, insert the element to the dst tensor.
+    SmallVector<Value, 4> dynDstSizes;
+    getDynamicSizes(dstTp, srcSizes, dynDstSizes);
+    Value dst =
+        rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
+    rewriter.create<ForeachOp>(
+        loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          SmallVector<Value, 4> indices;
+          for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+            uint64_t dim = toStoredDim(encDst, i);
+            indices.push_back(args[dim]);
+          }
+          builder.create<InsertOp>(loc, args.back(), dst, indices);
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+
+    // Release the temporary COO if it is created.
+    if (tmpCoo)
+      rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+
+    // Directly replace op with dst results in bufferization error message
+    // "sparse tensor allocation should not escape function".
+    // As such, we insert a trivial tensor convert which will be removed by
+    // codegen.
+    rewriter.setInsertionPointAfter(op);
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, dst);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for the foreach operator.
 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 public:
@@ -685,17 +895,19 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
                                          bool enableRT, bool enableForeach,
-                                         bool /*enableConvert*/) {
+                                         bool enableConvert) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
                ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
-
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
-  if (!enableRT)
+  if (!enableRT) {
     patterns.add<ConcatenateRewriter, NewRewriter, OutRewriter,
                  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
         patterns.getContext());
+    if (enableConvert)
+      patterns.add<ConvertRewriter>(patterns.getContext());
+  }
 }
index 13772e8..d67e11b 100644 (file)
@@ -1,4 +1,6 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]
@@ -100,6 +102,37 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOF64(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL:   func.func @sparse_convert_2d(
+//  CHECK-RWT-SAME:   %[[A:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
+//  CHECK-RWT-DAG:    %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-RWT-DAG:    %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-RWT-DAG:    %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-RWT-DAG:    %[[C4:.*]] = arith.constant 4 : index
+//  CHECK-RWT-DAG:    %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//      CHECK-RWT:    %[[COO:.*]] = bufferization.alloc_tensor()
+//      CHECK-RWT:    scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+//      CHECK-RWT:      scf.for %[[FJ:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+//      CHECK-RWT:        %[[V:.*]] = tensor.extract %[[A]]{{\[}}%[[FI]], %[[FJ]]] : tensor<2x4xf64>
+//      CHECK-RWT:        %[[NZ:.*]] = arith.cmpf une, %[[V]], %[[F0]] : f64
+//      CHECK-RWT:        scf.if %[[NZ]] {
+//      CHECK-RWT:          %{{.*}} = sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[FI]], %[[FJ]]]
+//      CHECK-RWT:        }
+//      CHECK-RWT:      }
+//      CHECK-RWT:    }
+//      CHECK-RWT:    %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
+//      CHECK-RWT:    %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
+//      CHECK-RWT:    %[[NNZ:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xindex>
+//      CHECK-RWT:    %[[V2:.*]] = sparse_tensor.values %[[COO]]
+//      CHECK-RWT:    sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V2]]
+//      CHECK-RWT:    %[[DST:.*]] = bufferization.alloc_tensor()
+//      CHECK-RWT:    sparse_tensor.foreach in %[[COO]]
+//      CHECK-RWT:    ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+//      CHECK-RWT:      sparse_tensor.insert %[[FV]] into %[[DST]]{{\[}}%[[FI0]], %[[FI1]]]
+//      CHECK-RWT:    }
+//      CHECK-RWT:    %[[R:.*]] = sparse_tensor.convert %[[DST]]
+//      CHECK-RWT:    bufferization.dealloc_tensor %[[COO]]
+//      CHECK-RWT:    return %[[R]] : tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
   return %0 : tensor<2x4xf64, #CSR>
@@ -132,6 +165,35 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOF32(%[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL:  func.func @sparse_constant()
+//   CHECK-RWT-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-RWT-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-RWT-DAG:  %[[SI:.*]] = arith.constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
+//   CHECK-RWT-DAG:  %[[SV:.*]] = arith.constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
+//   CHECK-RWT-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//       CHECK-RWT:  %[[COO:.*]] = bufferization.alloc_tensor()
+//       CHECK-RWT:  scf.for %[[FI:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+//       CHECK-RWT:    %[[I0r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C0]]] : tensor<2x2xi64>
+//       CHECK-RWT:    %[[I0:.*]] = arith.index_cast %[[I0r]] : i64 to index
+//       CHECK-RWT:    %[[I1r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C1]]] : tensor<2x2xi64>
+//       CHECK-RWT:    %[[I1:.*]] = arith.index_cast %[[I1r]] : i64 to index
+//       CHECK-RWT:    %[[V:.*]] = tensor.extract %[[SV]]{{\[}}%[[FI]]] : tensor<2xf32>
+//       CHECK-RWT:    sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[I0]], %[[I1]]]
+//       CHECK-RWT:  }
+//       CHECK-RWT:  %[[TI0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
+//       CHECK-RWT:  %[[TI1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
+//       CHECK-RWT:  %[[NNZ:.*]] = memref.load %[[TI0]]{{\[}}%[[C1]]] : memref<?xindex>
+//       CHECK-RWT:  %[[TV:.*]] = sparse_tensor.values %[[COO]]
+//       CHECK-RWT:  sparse_tensor.sort %[[NNZ]], %[[TI0]], %[[TI1]] jointly %[[TV]]
+//       CHECK-RWT:  %[[DST:.*]] = bufferization.alloc_tensor()
+//       CHECK-RWT:  sparse_tensor.foreach in %[[COO]]
+//       CHECK-RWT:  ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32):
+//       CHECK-RWT:    sparse_tensor.insert %[[F2V]] into %[[DST]]{{\[}}%[[F2I0]], %[[F2I1]]]
+//       CHECK-RWT:  }
+//       CHECK-RWT:  %[[R:.*]] = sparse_tensor.convert %[[DST]]
+//       CHECK-RWT:  bufferization.dealloc_tensor %[[COO]]
+//       CHECK-RWT:  return %[[R]] : tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
 func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
index ee7499a..8980c42 100644 (file)
@@ -1,5 +1,8 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
+
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]
 }>
@@ -128,6 +131,18 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
 //       CHECK: }
 //       CHECK: %[[T:.*]] = bufferization.to_tensor %[[M]] : memref<2x4xf64>
 //       CHECK: return %[[T]] : tensor<2x4xf64>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d(
+//  CHECK-RWT-SAME: %[[A:.*]]: tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<2x4xf64> {
+//       CHECK-RWT: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK-RWT: %[[B:.*]] = memref.alloc() : memref<2x4xf64>
+//       CHECK-RWT: linalg.fill ins(%[[F0]] : f64) outs(%[[B]]
+//       CHECK-RWT: sparse_tensor.foreach in %[[A]]
+//       CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+//       CHECK-RWT:   memref.store %[[FV]], %[[B]]{{\[}}%[[FI0]], %[[FI1]]]
+//       CHECK-RWT: }
+//       CHECK-RWT: %[[T:.*]] = bufferization.to_tensor %[[B]]
+//       CHECK-RWT: return %[[T]] : tensor<2x4xf64>
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
@@ -260,6 +275,22 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 //       CHECK: }
 //       CHECK: %[[T:.*]] = bufferization.to_tensor %[[M]] : memref<?x?xf64>
 //       CHECK: return %[[T]] : tensor<?x?xf64>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_2d_dyn2(
+//  CHECK-RWT-SAME: %[[A:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<?x?xf64> {
+//   CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-RWT-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK-RWT: %[[D0:.*]] = tensor.dim %[[A]], %[[C0]]
+//       CHECK-RWT: %[[D1:.*]] = tensor.dim %[[A]], %[[C1]]
+//       CHECK-RWT: %[[B:.*]] = memref.alloc(%[[D0]], %[[D1]])
+//       CHECK-RWT: linalg.fill ins(%[[F0]] : f64) outs(%[[B]]
+//       CHECK-RWT: sparse_tensor.foreach in %[[A]]
+//       CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
+//       CHECK-RWT:   memref.store %[[FV]], %[[B]]{{\[}}%[[FI0]], %[[FI1]]]
+//       CHECK-RWT: }
+//       CHECK-RWT: %[[T:.*]] = bufferization.to_tensor %[[B]]
+//       CHECK-RWT: return %[[T]] : tensor<?x?xf64>
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
index cd5575b..92f9e46 100644 (file)
@@ -6,6 +6,9 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
 // RUN:    --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
 
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
+
 #SparseVector64 = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"],
   pointerBitWidth = 64,
@@ -79,6 +82,24 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 //   CHECK-AUTO-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
 //       CHECK-AUTO: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
 //       CHECK-AUTO: return %[[T]] : !llvm.ptr<i8>
+
+// CHECK-RWT-LABEL: func.func @sparse_convert(
+//  CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
+//  CHECK-RWT-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-RWT-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-RWT:  %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
+//      CHECK-RWT:  %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index}
+//      CHECK-RWT:  %[[NNZr:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xi64>
+//      CHECK-RWT:  %[[NNZ:.*]] = arith.index_cast %[[NNZr]] : i64 to index
+//      CHECK-RWT:  %[[V:.*]] = sparse_tensor.values %[[A]]
+//      CHECK-RWT:  sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
+//      CHECK-RWT:  %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
+//      CHECK-RWT:  sparse_tensor.foreach in %[[A]]
+//      CHECK-RWT:  ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32):
+//      CHECK-RWT:    sparse_tensor.insert %[[FV2]] into %[[DST]]{{\[}}%[[FI2]]]
+//      CHECK-RWT:  }
+//      CHECK-RWT:  %[[R:.*]] = sparse_tensor.convert %[[DST]]
+//      CHECK-RWT:  return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
 func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
   return %0 : tensor<?xf32, #SparseVector32>
index 3d2a5e2..79b616d 100644 (file)
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed"]
index 8ab23e1..7280c6f 100644 (file)
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --sparsification | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --sparsification | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
index adc1f6d..c162bac 100644 (file)
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
 // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
-// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --cse --canonicalize  | FileCheck %s --check-prefix=CHECK-RWT
+// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --cse --canonicalize  | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
 #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>