[mlir][sparse] Move the implementation of sparse_tensor.push_back to the buffer rewriter.
author: bixia1 <bixia@google.com>
Thu, 29 Sep 2022 21:09:04 +0000 (14:09 -0700)
committer: bixia1 <bixia@google.com>
Thu, 29 Sep 2022 22:06:00 +0000 (15:06 -0700)
Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D134777

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir [moved from mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir with 100% similarity]

index 137f82f..7a1c6fa 100644 (file)
@@ -337,6 +337,60 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
 
 namespace {
 
+/// Sparse rewriting rule for the push_back operator.
+struct PushBackRewriter : OpRewritePattern<PushBackOp> {
+public:
+  using OpRewritePattern<PushBackOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PushBackOp op,
+                                PatternRewriter &rewriter) const override {
+    // Rewrite push_back(buffer, value) to:
+    // if (size(buffer) >= capacity(buffer))
+    //    new_capacity = capacity(buffer)*2
+    //    new_buffer = realloc(buffer, new_capacity)
+    // buffer = new_buffer
+    // store(buffer, value)
+    // size(buffer)++
+    Location loc = op->getLoc();
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value buffer = op.getInBuffer();
+    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+    Value bufferSizes = op.getBufferSizes();
+    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+                                                size, capacity);
+    Value value = op.getValue();
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                /*else=*/true);
+    // True branch.
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value c2 = constantIndex(rewriter, loc, 2);
+    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+    Value newBuffer =
+        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+    rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+    // False branch.
+    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    // Add the value to the end of the buffer.
+    rewriter.setInsertionPointAfter(ifOp);
+    buffer = ifOp.getResult(0);
+    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+    // Increment the size of the buffer by 1.
+    Value c1 = constantIndex(rewriter, loc, 1);
+    size = rewriter.create<arith::AddIOp>(loc, size, c1);
+    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for the sort operator.
 struct SortRewriter : public OpRewritePattern<SortOp> {
 public:
@@ -378,5 +432,5 @@ public:
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
-  patterns.add<SortRewriter>(patterns.getContext());
+  patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
 }
index d8e5d20..e40c806 100644 (file)
@@ -564,61 +564,6 @@ public:
   }
 };
 
-/// Sparse codegen rule for the push_back operator.
-class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Lower push_back(buffer, value) to:
-    // if (size(buffer) >= capacity(buffer))
-    //    new_capacity = capacity(buffer)*2
-    //    new_buffer = realloc(buffer, new_capacity)
-    // buffer = new_buffer
-    // store(buffer, value)
-    // size(buffer)++
-    Location loc = op->getLoc();
-    Value c0 = constantIndex(rewriter, loc, 0);
-    Value buffer = adaptor.getInBuffer();
-    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
-    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
-    Value bufferSizes = adaptor.getBufferSizes();
-    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
-    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
-                                                size, capacity);
-    Value value = adaptor.getValue();
-    auto bufferType =
-        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
-    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
-                                                /*else=*/true);
-    // True branch.
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    Value c2 = constantIndex(rewriter, loc, 2);
-    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
-    Value newBuffer =
-        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
-    rewriter.create<scf::YieldOp>(loc, newBuffer);
-
-    // False branch.
-    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    rewriter.create<scf::YieldOp>(loc, buffer);
-
-    // Add the value to the end of the buffer.
-    rewriter.setInsertionPointAfter(ifOp);
-    buffer = ifOp.getResult(0);
-    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
-
-    // Increment the size of the buffer by 1.
-    Value c1 = constantIndex(rewriter, loc, 1);
-    size = rewriter.create<arith::AddIOp>(loc, size, c1);
-    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
-
-    rewriter.replaceOp(op, buffer);
-    return success();
-  }
-};
-
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SourceOp, typename Base>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -703,7 +648,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparsePushBackConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
index 51bcc6a..b208dfe 100644 (file)
@@ -160,6 +160,7 @@ struct SparseTensorCodegenPass
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
     target.addLegalOp<SortOp>();
+    target.addLegalOp<PushBackOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
index e40064b..5aef2be 100644 (file)
@@ -1,5 +1,31 @@
 // RUN: mlir-opt %s --sparse-buffer-rewrite  --canonicalize --cse | FileCheck %s
 
+// CHECK-LABEL: func @sparse_push_back(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+//       CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+//       CHECK:  %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+//       CHECK:  scf.yield %[[M2]] : memref<?xf64>
+//       CHECK: } else {
+//       CHECK:  scf.yield %[[B]] : memref<?xf64>
+//       CHECK: }
+//       CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
 // CHECK-LABEL:   func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME:                                              %[[I:arg0]]: index,
 // CHECK-SAME:                                              %[[J:.*]]: index,
index 1263d3e..5bc7535 100644 (file)
@@ -425,29 +425,3 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
   return
 }
-
-// CHECK-LABEL: func @sparse_push_back(
-//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
-//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
-//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-//       CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
-//       CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
-//       CHECK:  %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
-//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
-//       CHECK:  scf.yield %[[M2]] : memref<?xf64>
-//       CHECK: } else {
-//       CHECK:  scf.yield %[[B]] : memref<?xf64>
-//       CHECK: }
-//       CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
-//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
-//       CHECK: return %[[M]] : memref<?xf64>
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
-  return %0 : memref<?xf64>
-}