[mlir][linalg][bufferize][NFC] Simplify getAliasingOpResult()
authorMatthias Springer <springerm@google.com>
Thu, 7 Oct 2021 13:39:52 +0000 (22:39 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 7 Oct 2021 13:41:21 +0000 (22:41 +0900)
The signature of this function was confusing. Check for hasKnownBufferizationAliasingBehavior separately when needed.

Differential Revision: https://reviews.llvm.org/D110916

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index 8d4ef27..0c4ded7 100644 (file)
@@ -583,39 +583,34 @@ static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
       });
 }
 
+/// If the an ExtractSliceOp is bufferized in-place, the source operand will
+/// alias with the result.
+static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) {
+  if (op.source() == opOperand.get())
+    return op->getResult(0);
+  return OpResult();
+}
+
 /// Determine which OpResult will alias with `opOperand` if the op is bufferized
 /// in place. This is a superset of `getInplaceableOpResult`.
-/// Return None if the owner of `opOperand` does not have known
-/// bufferization aliasing behavior, which indicates that the op must allocate
-/// all of its tensor results.
 /// TODO: in the future this may need to evolve towards a list of OpResult.
-static Optional<OpResult> getAliasingOpResult(OpOperand &opOperand) {
-  if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
-    return None;
+static OpResult getAliasingOpResult(OpOperand &opOperand) {
   return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-      // These terminators legitimately have no result.
-      .Case<ReturnOp, linalg::InitTensorOp, linalg::YieldOp, scf::YieldOp>(
-          [&](auto op) { return OpResult(); })
-      // DimOp has no tensor result.
-      .Case<tensor::DimOp>([&](auto op) { return None; })
-      // ConstantOp is never inplaceable.
-      .Case([&](ConstantOp op) { return op->getResult(0); })
       // ExtractSliceOp is different: its result is not inplaceable on op.source
       // but when bufferized inplace, the result is an aliasing subregion of
       // op.source.
-      .Case([&](ExtractSliceOp op) { return op->getResult(0); })
-      // All other ops, including scf::ForOp, return the result of
-      // `getInplaceableOpResult`.
+      .Case(
+          [&](ExtractSliceOp op) { return getAliasingOpResult(op, opOperand); })
+      // All other ops, return the result of `getInplaceableOpResult`.
       .Default(
           [&](Operation *op) { return getInplaceableOpResult(opOperand); });
 }
 
 /// Return true if `opOperand` bufferizes to a memory read.
 static bool bufferizesToMemoryRead(OpOperand &opOperand) {
-  Optional<OpResult> maybeOpResult = getAliasingOpResult(opOperand);
   // Unknown op that returns a tensor. The inplace analysis does not support
   // it. Conservatively return true.
-  if (!maybeOpResult)
+  if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return true;
   // ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses
   // may.
@@ -672,19 +667,19 @@ bufferizesToMemoryWrite(OpOperand &opOperand,
   // conservative.
   if (auto callOp = dyn_cast<CallOpInterface>(opOperand.getOwner()))
     return true;
-  Optional<OpResult> maybeOpResult = getAliasingOpResult(opOperand);
   // Unknown op that returns a tensor. The inplace analysis does not support
   // it. Conservatively return true.
-  if (!maybeOpResult)
+  if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return true;
+  OpResult opResult = getAliasingOpResult(opOperand);
   // Supported op without a matching result for opOperand (e.g. ReturnOp).
   // This does not bufferize to a write.
-  if (!*maybeOpResult)
+  if (!opResult)
     return false;
   // If we have a matching OpResult, this is a write.
   // Additionally allow to restrict to only inPlace write, if so specified.
   return inPlaceSpec == InPlaceSpec::None ||
-         getInPlace(*maybeOpResult) == inPlaceSpec;
+         getInPlace(opResult) == inPlaceSpec;
 }
 
 //===----------------------------------------------------------------------===//