[mlir][sparse] Implement the rewrite for sparse_tensor.push_back a value n times.
authorbixia1 <bixia@google.com>
Fri, 28 Oct 2022 17:26:54 +0000 (10:26 -0700)
committerbixia1 <bixia@google.com>
Mon, 31 Oct 2022 15:19:12 +0000 (08:19 -0700)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136654

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir

index 09593c2..421706e 100644 (file)
@@ -193,6 +193,7 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
   let constructor = "mlir::createSparseBufferRewritePass()";
   let dependentDialects = [
     "arith::ArithDialect",
+    "linalg::LinalgDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
index bc09772..929d4a4 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -498,13 +499,17 @@ public:
   using OpRewritePattern<PushBackOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(PushBackOp op,
                                 PatternRewriter &rewriter) const override {
-    // Rewrite push_back(buffer, value) to:
-    // if (size(buffer) >= capacity(buffer))
-    //    new_capacity = capacity(buffer)*2
+    // Rewrite push_back(buffer, value, n) to:
+    // new_size = size(buffer) + n
+    // if (new_size > capacity(buffer))
+    //    while new_size > new_capacity
+    //      new_capacity = new_capacity*2
     //    new_buffer = realloc(buffer, new_capacity)
     // buffer = new_buffer
-    // store(buffer, value)
-    // size(buffer)++
+    // subBuffer = subviewof(buffer)
+    // linalg.fill subBuffer value
+    //
+    // size(buffer) += n
     //
     // The capacity check is skipped when the attribute inbounds is presented.
     Location loc = op->getLoc();
@@ -516,18 +521,50 @@ public:
     Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
     Value value = op.getValue();
 
+    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
+    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
+    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+    bool nIsOne = (nValue && nValue.value() == 1);
+
     if (!op.getInbounds()) {
       Value cond = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::uge, size, capacity);
+          loc, arith::CmpIPredicate::ugt, newSize, capacity);
 
+      Value c2 = constantIndex(rewriter, loc, 2);
       auto bufferType =
           MemRefType::get({ShapedType::kDynamicSize}, value.getType());
       scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                   /*else=*/true);
       // True branch.
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      Value c2 = constantIndex(rewriter, loc, 2);
-      capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+      if (nIsOne) {
+        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+      } else {
+        // Use a do-while loop to calculate the new capacity as follows:
+        //   do { new_capacity *= 2 } while (size > new_capacity)
+        scf::WhileOp whileOp =
+            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);
+
+        // The before-region of the WhileOp.
+        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
+                                             {capacity.getType()}, {loc});
+        rewriter.setInsertionPointToEnd(before);
+
+        capacity =
+            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
+        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                              newSize, capacity);
+        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
+        // The after-region of the WhileOp.
+        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
+                                            {capacity.getType()}, {loc});
+        rewriter.setInsertionPointToEnd(after);
+        rewriter.create<scf::YieldOp>(loc, after->getArguments());
+
+        rewriter.setInsertionPointAfter(whileOp);
+        capacity = whileOp.getResult(0);
+      }
+
       Value newBuffer =
           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
       rewriter.create<scf::YieldOp>(loc, newBuffer);
@@ -542,13 +579,17 @@ public:
     }
 
     // Add the value to the end of the buffer.
-    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
-
-    // Increment the size of the buffer by 1.
-    Value c1 = constantIndex(rewriter, loc, 1);
-    size = rewriter.create<arith::AddIOp>(loc, size, c1);
-    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+    if (nIsOne) {
+      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+    } else {
+      Value subBuffer = rewriter.create<memref::SubViewOp>(
+          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
+          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
+      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
+    }
 
+    // Update the buffer size.
+    rewriter.create<memref::StoreOp>(loc, newSize, bufferSizes, idx);
     rewriter.replaceOp(op, buffer);
     return success();
   }
index 31c6ad5..114bfd8 100644 (file)
@@ -7,19 +7,19 @@
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
-//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-//       CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+//       CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
+//       CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index
+//       CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
 //       CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
-//       CHECK:  %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
-//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+//       CHECK:  %[[P2:.*]] = arith.muli %[[P1]], %[[C2]]
+//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
 //       CHECK:  scf.yield %[[M2]] : memref<?xf64>
 //       CHECK: } else {
 //       CHECK:  scf.yield %[[B]] : memref<?xf64>
 //       CHECK: }
-//       CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
-//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]]
+//       CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
 //       CHECK: return %[[M]] : memref<?xf64>
 func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
   %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
@@ -28,16 +28,52 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_n(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64,
+//  CHECK-SAME: %[[D:.*]]: index) -> memref<?xf64> {
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
+//       CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index
+//       CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
+//       CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+//       CHECK:   %[[P2:.*]] = scf.while (%[[I:.*]] = %[[P1]]) : (index) -> index {
+//       CHECK:     %[[P3:.*]] = arith.muli %[[I]], %[[C2]] : index
+//       CHECK:     %[[T2:.*]] = arith.cmpi ugt, %[[S2]], %[[P3]] : index
+//       CHECK:     scf.condition(%[[T2]]) %[[P3]] : index
+//       CHECK:   } do {
+//       CHECK:     ^bb0(%[[I2:.*]]: index):
+//       CHECK:     scf.yield %[[I2]] : index
+//       CHECK:   }  
+//       CHECK:  %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
+//       CHECK:  scf.yield %[[M2]] : memref<?xf64>
+//       CHECK: } else {
+//       CHECK:  scf.yield %[[B]] : memref<?xf64>
+//       CHECK: }
+//       CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1]
+//       CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]]
+//       CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 // CHECK-LABEL: func @sparse_push_back_inbound(
 //  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 //  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
 //  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-//       CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
-//       CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-//       CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+//       CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]]
+//       CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]]
+//       CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]]
 //       CHECK: return %[[B]] : memref<?xf64>
 func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
   %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64
index 90d1b37..6a88e29 100644 (file)
@@ -8,6 +8,7 @@ module {
   func.func @entry() {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
     %d0 = arith.constant 0.0 : f32
     %d1 = arith.constant 1.0 : f32
     %d2 = arith.constant 2.0 : f32
@@ -17,15 +18,19 @@ module {
 
     memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
     %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
-    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32
+    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32, index
 
-    // CHECK: ( 2 )
-    %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
-    vector.print %sizeValue : vector<1xindex>
+    // CHECK: 16
+    %capacity = memref.dim %buffer3, %c0 : memref<?xf32>
+    vector.print %capacity : index
 
-    // CHECK ( 2, 1 )
-    %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<2xf32>
-    vector.print %bufferValue : vector<2xf32>
+    // CHECK: ( 11 )
+    %size = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
+    vector.print %size : vector<1xindex>
+
+    // CHECK (  2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+    %values = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<11xf32>
+    vector.print %values : vector<11xf32>
 
     // Release the buffers.
     memref.dealloc %bufferSizes : memref<?xindex>