From: bixia1 Date: Fri, 28 Oct 2022 17:26:54 +0000 (-0700) Subject: [mlir][sparse] Implement the rewrite for sparse_tensor.push_back a value n times. X-Git-Tag: upstream/17.0.6~28967 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d45be8873628ce39e76dba6f4533bf96aa9f1985;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Implement the rewrite for sparse_tensor.push_back a value n times. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D136654 --- diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 09593c2..421706e 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -193,6 +193,7 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> { let constructor = "mlir::createSparseBufferRewritePass()"; let dependentDialects = [ "arith::ArithDialect", + "linalg::LinalgDialect", "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index bc09772..929d4a4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -498,13 +499,17 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PushBackOp op, PatternRewriter &rewriter) const override { - // Rewrite push_back(buffer, value) to: - // if (size(buffer) >= capacity(buffer)) - // new_capacity = capacity(buffer)*2 + // Rewrite 
push_back(buffer, value, n) to: + // new_size = size(buffer) + n + // if (new_size > capacity(buffer)) + // while new_size > new_capacity + // new_capacity = new_capacity*2 // new_buffer = realloc(buffer, new_capacity) // buffer = new_buffer - // store(buffer, value) - // size(buffer)++ + // subBuffer = subviewof(buffer) + // linalg.fill subBuffer value + // + // size(buffer) += n // // The capacity check is skipped when the attribute inbounds is presented. Location loc = op->getLoc(); @@ -516,18 +521,50 @@ public: Value size = rewriter.create(loc, bufferSizes, idx); Value value = op.getValue(); + Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); + Value newSize = rewriter.create(loc, size, n); + auto nValue = dyn_cast_or_null(n.getDefiningOp()); + bool nIsOne = (nValue && nValue.value() == 1); + if (!op.getInbounds()) { Value cond = rewriter.create( - loc, arith::CmpIPredicate::uge, size, capacity); + loc, arith::CmpIPredicate::ugt, newSize, capacity); + Value c2 = constantIndex(rewriter, loc, 2); auto bufferType = MemRefType::get({ShapedType::kDynamicSize}, value.getType()); scf::IfOp ifOp = rewriter.create(loc, bufferType, cond, /*else=*/true); // True branch. rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - Value c2 = constantIndex(rewriter, loc, 2); - capacity = rewriter.create(loc, capacity, c2); + if (nIsOne) { + capacity = rewriter.create(loc, capacity, c2); + } else { + // Use a do-while loop to calculate the new capacity as follows: + // do { new_capacity *= 2 } while (new_size > new_capacity) + scf::WhileOp whileOp = + rewriter.create(loc, capacity.getType(), capacity); + + // The before-region of the WhileOp. 
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, + {capacity.getType()}, {loc}); + rewriter.setInsertionPointToEnd(before); + + capacity = + rewriter.create(loc, before->getArgument(0), c2); + cond = rewriter.create(loc, arith::CmpIPredicate::ugt, + newSize, capacity); + rewriter.create(loc, cond, ValueRange{capacity}); + // The after-region of the WhileOp. + Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, + {capacity.getType()}, {loc}); + rewriter.setInsertionPointToEnd(after); + rewriter.create(loc, after->getArguments()); + + rewriter.setInsertionPointAfter(whileOp); + capacity = whileOp.getResult(0); + } + Value newBuffer = rewriter.create(loc, bufferType, buffer, capacity); rewriter.create(loc, newBuffer); @@ -542,13 +579,17 @@ public: } // Add the value to the end of the buffer. - rewriter.create(loc, value, buffer, size); - - // Increment the size of the buffer by 1. - Value c1 = constantIndex(rewriter, loc, 1); - size = rewriter.create(loc, size, c1); - rewriter.create(loc, size, bufferSizes, idx); + if (nIsOne) { + rewriter.create(loc, value, buffer, size); + } else { + Value subBuffer = rewriter.create( + loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n}, + /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); + rewriter.create(loc, value, subBuffer); + } + // Update the buffer size. 
+ rewriter.create(loc, newSize, bufferSizes, idx); rewriter.replaceOp(op, buffer); return success(); } diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir index 31c6ad5..114bfd8 100644 --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -7,19 +7,19 @@ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]] -// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] -// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]] +// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] +// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] +// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index +// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) { -// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]] -// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]]) +// CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]] +// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]]) // CHECK: scf.yield %[[M2]] : memref // CHECK: } else { // CHECK: scf.yield %[[B]] : memref // CHECK: } -// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]] -// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] -// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]] +// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]] +// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] // CHECK: return %[[M]] : memref func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 @@ -28,16 +28,52 @@ func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: // ----- +// CHECK-LABEL: func @sparse_push_back_n( +// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[B:.*]]: memref, +// CHECK-SAME: %[[C:.*]]: 
f64, +// CHECK-SAME: %[[D:.*]]: index) -> memref { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] +// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] +// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index +// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] +// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) { +// CHECK: %[[P2:.*]] = scf.while (%[[I:.*]] = %[[P1]]) : (index) -> index { +// CHECK: %[[P3:.*]] = arith.muli %[[I]], %[[C2]] : index +// CHECK: %[[T2:.*]] = arith.cmpi ugt, %[[S2]], %[[P3]] : index +// CHECK: scf.condition(%[[T2]]) %[[P3]] : index +// CHECK: } do { +// CHECK: ^bb0(%[[I2:.*]]: index): +// CHECK: scf.yield %[[I2]] : index +// CHECK: } +// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]]) +// CHECK: scf.yield %[[M2]] : memref +// CHECK: } else { +// CHECK: scf.yield %[[B]] : memref +// CHECK: } +// CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1] +// CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]] +// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] +// CHECK: return %[[M]] : memref +func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref { + %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index + return %0 : memref +} + +// ----- + // CHECK-LABEL: func @sparse_push_back_inbound( // CHECK-SAME: %[[A:.*]]: memref, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64) -> memref { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] -// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]] -// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] -// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]] +// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] +// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] +// 
CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]] +// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] // CHECK: return %[[B]] : memref func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir index 90d1b37..6a88e29 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir @@ -8,6 +8,7 @@ module { func.func @entry() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %d0 = arith.constant 0.0 : f32 %d1 = arith.constant 1.0 : f32 %d2 = arith.constant 2.0 : f32 @@ -17,15 +18,19 @@ memref.store %c0, %bufferSizes[%c0] : memref %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref, memref, f32 - %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref, memref, f32 + %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref, memref, f32, index - // CHECK: ( 2 ) - %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref, vector<1xindex> - vector.print %sizeValue : vector<1xindex> + // CHECK: 16 + %capacity = memref.dim %buffer3, %c0 : memref + vector.print %capacity : index - // CHECK ( 2, 1 ) - %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref, vector<2xf32> - vector.print %bufferValue : vector<2xf32> + // CHECK: ( 11 ) + %size = vector.transfer_read %bufferSizes[%c0], %c0: memref, vector<1xindex> + vector.print %size : vector<1xindex> + + // CHECK: ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + %values = vector.transfer_read %buffer3[%c0], %d0: memref, vector<11xf32> + vector.print %values : 
vector<11xf32> // Release the buffers. memref.dealloc %bufferSizes : memref