From 248e113e9f6e583ed93e52de621a89d098c6d79e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 11 May 2022 13:55:58 +0200 Subject: [PATCH] [mlir][bufferize][NFC] Move helper functions to BufferizationOptions Move helper functions for creating allocs/deallocs/memcpys to BufferizationOptions. Differential Revision: https://reviews.llvm.org/D125375 --- .../Bufferization/IR/BufferizableOpInterface.h | 22 +++++++----- .../Bufferization/IR/BufferizableOpInterface.cpp | 40 ++++++++++------------ .../Transforms/FuncBufferizableOpInterfaceImpl.cpp | 2 +- .../SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 4 +-- .../Transforms/BufferizableOpInterfaceImpl.cpp | 8 ++--- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 7d80e47..421db92 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -161,6 +161,19 @@ struct BufferizationOptions { Optional deallocationFn; Optional memCpyFn; + /// Create a memref allocation with the given type and dynamic extents. + FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape) const; + + /// Creates a memref deallocation. The given memref buffer must have been + /// allocated using `createAlloc`. + LogicalResult createDealloc(OpBuilder &b, Location loc, + Value allocatedBuffer) const; + + /// Creates a memcpy between two given buffers. + LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, + Value to) const; + /// Specifies whether not bufferizable ops are allowed in the input. If so, /// bufferization.to_memref and bufferization.to_tensor ops are inserted at /// the boundaries. @@ -514,15 +527,6 @@ BaseMemRefType getMemRefType(TensorType tensorType, MemRefLayoutAttrInterface layout = {}, Attribute memorySpace = {}); -/// Creates a memref deallocation. The given memref buffer must have been -/// allocated using `createAlloc`. -LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options); - -/// Creates a memcpy between two given buffers. -LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, - const BufferizationOptions &options); - /// Try to hoist all new buffer allocations until the next hoisting barrier. LogicalResult hoistBufferAllocations(Operation *op, const BufferizationOptions &options); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 8a3bc5c..29d983b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -327,8 +327,7 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, // The copy happens right before the op that is bufferized. rewriter.setInsertionPoint(op); } - if (failed( - createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options))) + if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) return failure(); return resultBuffer; @@ -418,26 +417,24 @@ bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const { //===----------------------------------------------------------------------===// /// Create a memref allocation with the given type and dynamic extents. -static FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, - const BufferizationOptions &options) { - if (options.allocationFn) - return (*options.allocationFn)(b, loc, type, dynShape, - options.bufferAlignment); +FailureOr BufferizationOptions::createAlloc(OpBuilder &b, Location loc, + MemRefType type, + ValueRange dynShape) const { + if (allocationFn) + return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); // Default bufferallocation via AllocOp. Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment)); + loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); return allocated; } /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. -LogicalResult -bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options) { - if (options.deallocationFn) - return (*options.deallocationFn)(b, loc, allocatedBuffer); +LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, + Value allocatedBuffer) const { + if (deallocationFn) + return (*deallocationFn)(b, loc, allocatedBuffer); // Default buffer deallocation via DeallocOp. b.create(loc, allocatedBuffer); @@ -523,11 +520,10 @@ FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, } /// Create a memory copy between two memref buffers. -LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc, - Value from, Value to, - const BufferizationOptions &options) { - if (options.memCpyFn) - return (*options.memCpyFn)(b, loc, from, to); +LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, + Value from, Value to) const { + if (memCpyFn) + return (*memCpyFn)(b, loc, from, to); b.create(loc, from, to); return success(); @@ -557,8 +553,8 @@ bufferization::createAllocDeallocOps(Operation *op, Block *block = allocaOp->getBlock(); rewriter.setInsertionPoint(allocaOp); FailureOr alloc = - createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), - allocaOp.dynamicSizes(), options); + options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), + allocaOp.dynamicSizes()); if (failed(alloc)) return WalkResult::interrupt(); rewriter.replaceOp(allocaOp, *alloc); @@ -571,7 +567,7 @@ bufferization::createAllocDeallocOps(Operation *op, // Create dealloc. rewriter.setInsertionPoint(block->getTerminator()); - if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options))) + if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc))) return WalkResult::interrupt(); return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 25d3df2..dda4614 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -495,7 +495,7 @@ struct FuncOpInterface // Note: This copy will fold away. It must be inserted here to ensure // that `returnVal` still has at least one use and does not fold away. if (failed( - createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg))) return funcOp->emitError("could not generate copy for bbArg"); continue; } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 39af7d3..337b0aa 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -363,8 +363,8 @@ static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, // TODO: We should rollback, but for now just assume that this always // succeeds. assert(yieldedAlloc.hasValue() && "could not create alloc"); - LogicalResult copyStatus = bufferization::createMemCpy( - rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions()); + LogicalResult copyStatus = state.getOptions().createMemCpy( + rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc); (void)copyStatus; assert(succeeded(copyStatus) && "could not create memcpy"); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index efd2de7..3ec52e9 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -320,8 +320,8 @@ struct ExtractSliceOpInterface if (!inplace) { // Do not copy if the copied data is never read. if (state.getAnalysisState().isValueRead(extractSliceOp.result())) - if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, - alloc, state.getOptions()))) + if (failed(state.getOptions().createMemCpy( + rewriter, extractSliceOp.getLoc(), subView, alloc))) return failure(); subView = alloc; } @@ -718,8 +718,8 @@ struct InsertSliceOpInterface // tensor.extract_slice, the copy operation will eventually fold away. auto srcMemref = state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); - if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref, - subView, state.getOptions()))) + if (failed(srcMemref) || failed(state.getOptions().createMemCpy( + rewriter, loc, *srcMemref, subView))) return failure(); replaceOpWithBufferizedValues(rewriter, op, *dstMemref); -- 2.7.4