[mlir][sparse] Add rewriting rules for the sparse_tensor.new operator.
author     bixia1 <bixia@google.com>
           Tue, 11 Oct 2022 20:47:45 +0000 (13:47 -0700)
committer  bixia1 <bixia@google.com>
           Tue, 11 Oct 2022 22:23:11 +0000 (15:23 -0700)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135692
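
In rough terms, with the runtime library disabled this patch rewrites a
sparse_tensor.new op such as the following (value names illustrative; the
authoritative IR is in the new test added below):

    %t = sparse_tensor.new %p : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>

The op is expanded into code that creates a sparse tensor reader, fills an
unordered COO buffer element by element inside an scf.for loop, and finally
converts the COO buffer to the requested destination format with
sparse_tensor.convert.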

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir [new file with mode: 0644]

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 9dfcaec..1d02b37 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -514,6 +515,97 @@ public:
   }
 };
 
+/// Sparse rewriting rule for the new operator.
+struct NewRewriter : public OpRewritePattern<NewOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(NewOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    if (!encDst) {
+      return failure();
+    }
+
+    // Create a sparse tensor reader.
+    Value fileName = op.getSource();
+    Type opaqueTp = getOpaquePointerType(rewriter);
+    Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
+                                  {opaqueTp}, {fileName}, EmitCInterface::Off)
+                       .getResult(0);
+
+    // Allocate a buffer for storing dimension sizes and indices.
+    Type indexTp = rewriter.getIndexType();
+    auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+    uint64_t rank = dstTp.getRank();
+    Value dimSizes = rewriter.create<memref::AllocOp>(
+        loc, memTp, ValueRange{constantIndex(rewriter, loc, rank)});
+
+    // If the result tensor has dynamic dimensions, get the dynamic sizes from
+    // the sparse tensor reader.
+    SmallVector<Value, 4> dynSizesArray;
+    if (!dstTp.hasStaticShape()) {
+      // The call has no results; it fills the dimSizes buffer in place.
+      createFuncCall(rewriter, loc, "getSparseTensorReaderDimSizes", {},
+                     {reader, dimSizes}, EmitCInterface::On);
+      ArrayRef<int64_t> dstShape = dstTp.getShape();
+      for (auto &d : llvm::enumerate(dstShape)) {
+        if (d.value() == ShapedType::kDynamicSize) {
+          dynSizesArray.push_back(rewriter.create<memref::LoadOp>(
+              loc, dimSizes, constantIndex(rewriter, loc, d.index())));
+        }
+      }
+    }
+
+    // Implement the NewOp as follows:
+    //   %tmp = bufferization.alloc_tensor : an unordered COO with identity
+    //                                       storage ordering
+    //   for i = 0 to nnz
+    //     get the next element from the input file
+    //     insert the element into %tmp
+    //   %t = sparse_tensor.ConvertOp %tmp
+    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
+
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value c1 = constantIndex(rewriter, loc, 1);
+    Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
+                               {indexTp}, {reader}, EmitCInterface::Off)
+                    .getResult(0);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    Type eltTp = dstTp.getElementType();
+    SmallString<18> getNextFuncName{"getSparseTensorReaderNext",
+                                    primaryTypeFunctionSuffix(eltTp)};
+    Value indices = dimSizes; // Reuse the dimSizes buffer to store indices.
+    Value value = createFuncCall(rewriter, loc, getNextFuncName, {eltTp},
+                                 {reader, indices}, EmitCInterface::On)
+                      .getResult(0);
+    SmallVector<Value, 4> indicesArray;
+    for (uint64_t i = 0; i < rank; i++) {
+      indicesArray.push_back(rewriter.create<memref::LoadOp>(
+          loc, indices, constantIndex(rewriter, loc, i)));
+    }
+    rewriter.create<InsertOp>(loc, value, cooBuffer, indicesArray);
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Release the indices buffer and the sparse tensor reader.
+    rewriter.create<memref::DeallocOp>(loc, indices);
+    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
+                   EmitCInterface::Off);
+
+    Value newOp = rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+
+    // Release the unordered COO tensor buffer.
+    rewriter.setInsertionPointAfterValue(newOp);
+    rewriter.create<DeallocTensorOp>(loc, cooBuffer);
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -527,7 +619,7 @@ void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
       patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
   if (!enableRT)
-    patterns.add<ConcatenateRewriter,
+    patterns.add<ConcatenateRewriter, NewRewriter,
                  Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                  Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
         patterns.getContext());
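
Concretely, for a 2-D result with dynamic shape and f32 elements, the pattern
above emits reader calls around an scf.for loop, following the scheme in the
"Implement the NewOp as follows" comment. A condensed, hand-written sketch
(SSA names are illustrative; the authoritative output is checked by the new
FileCheck test below):

    %reader   = call @createSparseTensorReader(%filename)
    %dimSizes = memref.alloc(%c2) : memref<?xindex>
    call @getSparseTensorReaderDimSizes(%reader, %dimSizes)
    %d0 = memref.load %dimSizes[%c0] : memref<?xindex>
    %d1 = memref.load %dimSizes[%c1] : memref<?xindex>
    %coo = bufferization.alloc_tensor(%d0, %d1)     // unordered COO buffer
    %nnz = call @getSparseTensorReaderNNZ(%reader)
    scf.for %i = %c0 to %nnz step %c1 {
      %v  = func.call @getSparseTensorReaderNextF32(%reader, %dimSizes)
      %i0 = memref.load %dimSizes[%c0] : memref<?xindex>  // dimSizes reused for indices
      %i1 = memref.load %dimSizes[%c1] : memref<?xindex>
      sparse_tensor.insert %v into %coo[%i0, %i1]
    }
    memref.dealloc %dimSizes
    call @delSparseTensorReader(%reader)
    %dst = sparse_tensor.convert %coo               // COO -> destination format
    bufferization.dealloc_tensor %coo
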
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
new file mode 100644
index 0000000..d77f8ed
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -sparse-tensor-rewrite=enable-runtime-library=false  | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
+// CHECK-LABEL:   func.func @sparse_new(
+// CHECK-SAME:    %[[A:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[R:.*]] = call @createSparseTensorReader(%[[A]])
+// CHECK:         %[[DS:.*]] = memref.alloc(%[[C2]]) : memref<?xindex>
+// CHECK:         call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK:         %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK:         %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK:         %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
+// CHECK:         %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
+// CHECK:           %[[V:.*]] = func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]])
+// CHECK:           %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK:           %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK:           sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK:         }
+// CHECK:         memref.dealloc %[[DS]]
+// CHECK:         call @delSparseTensorReader(%[[R]])
+// CHECK:         %[[R:.*]] = sparse_tensor.convert %[[T]]
+// CHECK:         bufferization.dealloc_tensor %[[T]]
+// CHECK:         return %[[R]]
+// CHECK:         }
+func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
+  return %0 : tensor<?x?xf32, #CSR>
+}