From 8096759519f2930dc9ba2ceddada1100fee34a0b Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 5 Oct 2021 06:58:53 +0000 Subject: [PATCH] [mlir][Linalg] NFC - Add support to specify that a tensor value is known to bufferize to writeable memory This change allows better interop with external clients of comprehensive bufferization functions but is otherwise NFC for the MLIR pass itself. Differential Revision: https://reviews.llvm.org/D111121 --- .../Linalg/Transforms/ComprehensiveBufferize.h | 13 ++++- .../Linalg/Transforms/ComprehensiveBufferize.cpp | 62 +++++++++++++++------- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h index e713fd9..82ada9f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -57,9 +57,9 @@ public: void insertNewBufferEquivalence(Value newValue, Value alias); /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writeable. This implies that the matching + /// buffer that is known to not be writable. This implies that the matching /// OpResult cannot be bufferized inplace. - bool aliasesNonWriteableBuffer(OpOperand &operand) const; + bool aliasesNonWritableBuffer(OpOperand &operand) const; /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. @@ -124,6 +124,12 @@ public: /// Apply `fun` to all the members of the equivalence class of `v`. void applyOnEquivalenceClass(Value v, function_ref fun) const; + /// Return true if the value is known to bufferize to writable memory. + bool bufferizesToWritableMemory(Value v) const; + + /// Specify that the value is known to bufferize to writable memory. + void setBufferizesToWritableMemory(Value v); + /// Print to `os`. void printAliases(raw_ostream &os) const; void printEquivalences(raw_ostream &os) const; @@ -210,6 +216,9 @@ private: OpOperand &aliasingWrite, const DominanceInfo &domInfo) const; + /// Set of tensors that are known to bufferize to writable memory. + llvm::DenseSet bufferizeToWritableMemory; + /// Auxiliary structure to store all the values a given value aliases with. /// These are the conservative cases that can further decompose into /// "equivalent" buffer relationships. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index 79fa998..6b5f624 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -726,15 +726,21 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, } /// Return true if the buffer to which `operand` would bufferize aliases a -/// buffer that is known to not be writeable. This implies that the matching +/// buffer that is known to not be writable. This implies that the matching /// OpResult cannot be bufferized inplace. -bool BufferizationAliasInfo::aliasesNonWriteableBuffer( +bool BufferizationAliasInfo::aliasesNonWritableBuffer( OpOperand &operand) const { - LDBG("----Start aliasesNonWriteableBuffer\n"); + LDBG("----Start aliasesNonWritableBuffer\n"); LDBG("-------for -> #" << operand.getOperandNumber() << ": " << printOperationInfo(operand.getOwner()) << '\n'); for (Value v : getAliases(operand.get())) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); + if (bufferizesToWritableMemory(v)) { + LDBG("-----------Value is known to be writeable -> skip: " + << printValueInfo(v) << '\n'); + continue; + } + if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg) @@ -747,15 +753,24 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer( if (Operation *op = v.getDefiningOp()) { if (isa(op) || !hasKnownBufferizationAliasingBehavior(op)) { - LDBG("-----------notWriteable\n"); + LDBG("-----------notWritable\n"); return true; } } } - LDBG("---->operand is writeable\n"); + LDBG("---->operand is writable\n"); return false; } +bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const { + return bufferizeToWritableMemory.count(v) > 0; +} + +/// Specify that the value is known to bufferize to writable memory. +void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) { + bufferizeToWritableMemory.insert(v); +} + /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const { @@ -2184,22 +2199,22 @@ bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, << printOperationInfo(extractSliceOp) << '\n'); // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writeable buffer. - bool wouldCreateAliasingWriteToNonWriteableBuffer = + // aliasing a write into a non-writable buffer. + bool wouldCreateAliasingWriteToNonWritableBuffer = aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0)); + aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0)); - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); + if (wouldCreateAliasingWriteToNonWritableBuffer) + LDBG("->the corresponding buffer is not writable\n"); else - LDBG("->bufferizes to writeable inplace buffer\n"); + LDBG("->bufferizes to writable inplace buffer\n"); // In any of extractSliceOp.result's aliases, can we find 2 such that we hit // an interfering write? OpResult r = extractSliceOp->getResult(0); OpOperand &s = extractSliceOp->getOpOperand(0); bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || + wouldCreateAliasingWriteToNonWritableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(r); @@ -2230,21 +2245,21 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result, << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writeable buffer to be a candidate. + // `result` must bufferize to a writable buffer to be a candidate. // This means the operand must not alias either: // 1. a function bbArg that is not inplaceable or // 2. a constant op. // to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWriteableBuffer = - aliasInfo.aliasesNonWriteableBuffer(operand); - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); + bool wouldCreateAliasingWriteToNonWritableBuffer = + aliasInfo.aliasesNonWritableBuffer(operand); + if (wouldCreateAliasingWriteToNonWritableBuffer) + LDBG("->the corresponding buffer is not writable\n"); else - LDBG("->bufferizes to writeable inplace buffer\n"); + LDBG("->bufferizes to writable inplace buffer\n"); assert(result == getInplaceableOpResult(operand)); bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || + wouldCreateAliasingWriteToNonWritableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) @@ -2312,6 +2327,15 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, ops.push_back(op); }); + // Set the function arguments marked with inplaceable to be known as + // bufferizing to a writeable memory. + for (BlockArgument bbArg : funcOp.getArguments()) { + BoolAttr inplaceAttr = funcOp.getArgAttrOfType( + bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); + if (inplaceAttr && inplaceAttr.getValue()) + aliasInfo.setBufferizesToWritableMemory(bbArg); + } + LogicalResult res = inPlaceAnalysis(ops, aliasInfo, domInfo); LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); -- 2.7.4