[mlir][linalg][bufferize][NFC] Remove InSpaceSpec from bufferizesToMemoryWrite
authorMatthias Springer <springerm@google.com>
Wed, 13 Oct 2021 00:50:43 +0000 (09:50 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 13 Oct 2021 00:51:22 +0000 (09:51 +0900)
Move functionality into a separate function `isInplaceMemoryWrite`.

Differential Revision: https://reviews.llvm.org/D111040

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index 0255b3a..a0412a9 100644 (file)
@@ -654,11 +654,7 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
 }
 
 /// Return true if `opOperand` bufferizes to a memory write.
-/// If inPlaceSpec is different from InPlaceSpec::None, additionally require the
-/// write to match the inplace specification.
-static bool
-bufferizesToMemoryWrite(OpOperand &opOperand,
-                        InPlaceSpec inPlaceSpec = InPlaceSpec::None) {
+static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
   // These terminators are not writes.
   if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
     return false;
@@ -677,14 +673,9 @@ bufferizesToMemoryWrite(OpOperand &opOperand,
   if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return true;
   OpResult opResult = getAliasingOpResult(opOperand);
-  // Supported op without a matching result for opOperand (e.g. ReturnOp).
-  // This does not bufferize to a write.
-  if (!opResult)
-    return false;
-  // If we have a matching OpResult, this is a write.
-  // Additionally allow to restrict to only inPlace write, if so specified.
-  return inPlaceSpec == InPlaceSpec::None ||
-         getInPlace(opResult) == inPlaceSpec;
+  // Only supported op with a matching result for opOperand bufferize to a
+  // write. E.g., ReturnOp does not bufferize to a write.
+  return static_cast<bool>(opResult);
 }
 
 /// Returns the relationship between the operand and the its corresponding
@@ -701,6 +692,15 @@ static BufferRelation bufferRelation(OpOperand &operand) {
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
 
+/// Return true if opOperand has been decided to bufferize in-place.
+static bool isInplaceMemoryWrite(OpOperand &opOperand) {
+  // Ops that do not bufferize to a memory write, cannot be write in-place.
+  if (!bufferizesToMemoryWrite(opOperand))
+    return false;
+  OpResult opResult = getAliasingOpResult(opOperand);
+  return opResult && getInPlace(opResult) == InPlaceSpec::True;
+}
+
 BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
   rootOp->walk([&](Operation *op) {
     for (Value v : op->getResults())
@@ -785,7 +785,7 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
   LDBG("-------for : " << printValueInfo(value) << '\n');
   for (Value v : getAliases(value)) {
     for (auto &use : v.getUses()) {
-      if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
+      if (isInplaceMemoryWrite(use)) {
         LDBG("-----------wants to bufferize to inPlace write: "
              << printOperationInfo(use.getOwner()) << '\n');
         return true;
@@ -914,7 +914,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
     for (Value alias : getAliases(root)) {
       for (auto &use : alias.getUses()) {
         // Inplace write to a value that aliases root.
-        if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
+        if (isInplaceMemoryWrite(use)) {
           LDBG("------------bufferizesToMemoryWrite: "
                << use.getOwner()->getName().getStringRef() << "\n");
           res.insert(&use);
@@ -1135,6 +1135,7 @@ bool BufferizationAliasInfo::existsInterleavedValueClobber(
       SmallVector<OpOperand *> operands =
           getAliasingOpOperand(mit->v.cast<OpResult>());
       assert(operands.size() <= 1 && "more than 1 OpOperand not supported yet");
+      // TODO: Should we check for isInplaceMemoryWrite instead?
       if (operands.empty() || !bufferizesToMemoryWrite(*operands.front()))
         continue;
       LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)