[mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp
authorAnlun Xu <anlunx@google.com>
Mon, 1 May 2023 01:16:11 +0000 (18:16 -0700)
committerAnlun Xu <anlunx@google.com>
Tue, 16 May 2023 21:56:33 +0000 (14:56 -0700)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149564

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir [new file with mode: 0644]
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir [new file with mode: 0644]

index ca27794..a16ab66 100644 (file)
@@ -386,6 +386,106 @@ public:
 };
 
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+public:
+  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSource();
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+        !dstTp.hasStaticDimShape())
+      return failure();
+
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    for (Dimension d : dstTp.getDimShape())
+      dstSizes.push_back(constantIndex(rewriter, loc, d));
+
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+    // Only need an unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
+            : getUnorderedCOOFromType(dstTp);
+    SmallVector<Value> dynSizes;
+    Value buffer = rewriter
+                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+                                              nnz, Attribute())
+                       .getResult();
+
+    // Convert src coordinates to dst coordinates by first collapsing it to 1D
+    // and then expand it to the match the rank of the destination tensor.
+    // Implemented as follows:
+    //   foreach srcCoords %srcTensor
+    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
+    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
+    //     insert expandedCoords, %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
+    const auto encSrc = srcTp.getEncoding();
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, buffer,
+        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+            ValueRange reduc) {
+          const Dimension srcRank = srcTp.getDimRank();
+          SmallVector<Value> srcDcvs;
+          srcDcvs.reserve(srcRank);
+          for (Dimension d = 0; d < srcRank; d++) {
+            // FIXME: `toStoredDim` is deprecated
+            Level lvl = toStoredDim(encSrc, d);
+            srcDcvs.push_back(srcLcvs[lvl]);
+          }
+
+          Value collapsed_size = constantIndex(builder, loc, 1);
+          for (Dimension d = 0; d < srcRank; d++)
+            collapsed_size =
+                builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]);
+          SmallVector<Value, 1> collapsedSizes = {collapsed_size};
+
+          ReassociationIndices collapse_indices;
+          for (Dimension i = 0; i < srcRank; i++)
+            collapse_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> collapse_reassociation = {
+              collapse_indices};
+          SmallVector<Value, 1> collapsedDcvs;
+          reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs,
+                     collapsedSizes, collapsedDcvs);
+
+          ReassociationIndices expand_indices;
+          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+            expand_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> expand_reassociation = {
+              expand_indices};
+          SmallVector<Value> dstDcvs;
+          reshapeCvs(builder, loc, expand_reassociation, collapsedSizes,
+                     collapsedDcvs, dstSizes, dstDcvs);
+
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
+
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
+    return success();
+  }
+};
+
+/// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 public:
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableForeach,
                                                bool enableConvert) {
   patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+      patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
new file mode 100644 (file)
index 0000000..369044c
--- /dev/null
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --cse --canonicalize  | FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// CHECK:         func.func @sparse_reshape(
+// CHECK-SAME:    %[[S:.*]]:
+// CHECK-DAG:     %[[C25:.*]] = arith.constant 25 : index
+// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK:         %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK:         %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
+// CHECK:           %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG:       %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG:       %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK:           %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK:           %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK:             %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK:             %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK:             %[[T:.*]] = arith.muli %[[SI0]], %[[C25]] : index
+// CHECK:             %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK:             %[[D:.*]] = arith.divui %[[DI]], %[[C10]] : index
+// CHECK:             %[[R:.*]] = arith.remui %[[DI]], %[[C10]] : index
+// CHECK:             %[[R1:.*]] = sparse_tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[D]], %[[R]]]
+// CHECK:              scf.yield %[[R1]]
+// CHECK:            }
+// CHECK:            scf.yield %[[RET_1]]
+// CHECK:         }
+// CHECK:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK:        return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
+func.func @sparse_reshape(%arg0: tensor<4x25xf64, #SparseMatrix>) -> tensor<10x10xf64, #SparseMatrix> {
+  %shape = arith.constant dense <[ 10, 10 ]> : tensor<2xi32>
+  %0 = tensor.reshape %arg0(%shape) :
+    (tensor<4x25xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<10x10xf64, #SparseMatrix>
+  return %0 : tensor<10x10xf64, #SparseMatrix>
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
new file mode 100644 (file)
index 0000000..4945294
--- /dev/null
@@ -0,0 +1,87 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#Sparse3dTensor = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed", "compressed"]
+}>
+
+module {
+
+  func.func @reshape0(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> {
+    %shape = arith.constant dense <[ 2, 6 ]> : tensor<2xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<2x6xf64, #SparseMatrix>
+    return %0 : tensor<2x6xf64, #SparseMatrix>
+  }
+
+  func.func @reshape1(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+    %shape = arith.constant dense <[ 12 ]> : tensor<1xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<1xi32>) -> tensor<12xf64, #SparseVector>
+    return %0 : tensor<12xf64, #SparseVector>
+  }
+
+  func.func @reshape2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> {
+    %shape = arith.constant dense <[ 2, 3, 2 ]> : tensor<3xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<3xi32>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+    return %0 : tensor<2x3x2xf64, #Sparse3dTensor>
+  }
+
+
+  func.func @entry() {
+    %m = arith.constant dense <[ [ 1.1,  0.0,  1.3,  0.0 ],
+                                 [ 2.1,  0.0,  2.3,  0.0 ],
+                                 [ 3.1,  0.0,  3.3,  0.0 ]]> : tensor<3x4xf64>
+    %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix>
+
+    %reshaped0 = call @reshape0(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix>
+    %reshaped1 = call @reshape1(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
+    %reshaped2 = call @reshape2(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+
+    %c0 = arith.constant 0 : index
+    %df = arith.constant -1.0 : f64
+
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b0 = sparse_tensor.values %reshaped0: tensor<2x6xf64, #SparseMatrix> to memref<?xf64>
+    %v0 = vector.transfer_read %b0[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v0 : vector<12xf64>
+
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b1 = sparse_tensor.values %reshaped1: tensor<12xf64, #SparseVector> to memref<?xf64>
+    %v1 = vector.transfer_read %b1[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v1 : vector<12xf64>
+
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b2 = sparse_tensor.values %reshaped2: tensor<2x3x2xf64, #Sparse3dTensor> to memref<?xf64>
+    %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v2: vector<12xf64>
+
+    bufferization.dealloc_tensor %sm : tensor<3x4xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %reshaped0 : tensor<2x6xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %reshaped1 : tensor<12xf64, #SparseVector>
+    bufferization.dealloc_tensor %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor>
+
+    return
+  }
+
+}
\ No newline at end of file