[mlir][linalg][bufferize][NFC] Reduce code duplication around bufferizableInPlaceAnalysis
authorMatthias Springer <springerm@google.com>
Wed, 13 Oct 2021 00:08:20 +0000 (09:08 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 13 Oct 2021 00:08:58 +0000 (09:08 +0900)
Differential Revision: https://reviews.llvm.org/D111380

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index 7acff26..ead888f 100644 (file)
@@ -2237,6 +2237,37 @@ static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
+/// Determine if `operand` can be bufferized in-place with `result`. If so, set
+/// InPlaceSpec::True on the result. Otherwise, set InPlaceSpec::False on the
+/// result.
+static LogicalResult
+bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
+                                BufferizationAliasInfo &aliasInfo,
+                                const DominanceInfo &domInfo) {
+  assert(getAliasingOpOperand(result) == &operand &&
+         "operand and result do not match");
+
+  int64_t resultNumber = result.getResultNumber();
+  (void)resultNumber;
+  LDBG('\n');
+  LDBG("Inplace analysis for <- #" << resultNumber << " -> #"
+                                   << operand.getOperandNumber() << " in "
+                                   << printValueInfo(result) << '\n');
+
+  bool foundInterference =
+      aliasInfo.wouldCreateWriteToNonWritableBuffer(result) ||
+      aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo);
+
+  if (foundInterference)
+    aliasInfo.bufferizeOutOfPlace(result);
+  else
+    aliasInfo.bufferizeInPlace(result, operand);
+
+  LDBG("Done inplace analysis for result #" << resultNumber << '\n');
+
+  return success();
+}
+
 ///
 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace.
 /// ===========================================================
@@ -2255,27 +2286,9 @@ static LogicalResult
 bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
                             BufferizationAliasInfo &aliasInfo,
                             const DominanceInfo &domInfo) {
-  LDBG('\n');
-  LDBG("Inplace analysis for extract_slice: "
-       << printOperationInfo(extractSliceOp) << '\n');
-
-  OpResult r = extractSliceOp->getResult(0);
-  OpOperand &s = extractSliceOp->getOpOperand(0);
-  bool foundInterference =
-      /* If `extractSliceOp` were to be bufferized inplace, it cannot end up
-         aliasing a write into a non-writable buffer.*/
-      aliasInfo.wouldCreateWriteToNonWritableBuffer(r) ||
-      /* In any of extractSliceOp.result's aliases, can we find 2 such that we
-         hit an interfering write? */
-      aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo);
-  if (foundInterference)
-    aliasInfo.bufferizeOutOfPlace(r);
-  else
-    aliasInfo.bufferizeInPlace(r, s);
-
-  LDBG("Done inplace analysis for extract_slice\n");
-
-  return success();
+  return bufferizableInPlaceAnalysisImpl(extractSliceOp->getOpOperand(0),
+                                         extractSliceOp->getOpResult(0),
+                                         aliasInfo, domInfo);
 }
 
 /// Determine if `operand` can be bufferized in-place with one of the op's
@@ -2288,33 +2301,7 @@ bufferizableInPlaceAnalysis(OpOperand &operand,
   OpResult result = getInplaceableOpResult(operand);
   if (!result)
     return success();
-
-  Operation *op = result.getDefiningOp();
-  assert(result && !isa<ExtractSliceOp>(op) &&
-         "expected OpResult not coming from a ExtractSliceOp");
-  (void)op;
-
-  int64_t resultNumber = result.getResultNumber();
-  (void)resultNumber;
-  LDBG('\n');
-  LDBG("Inplace analysis for <- #" << resultNumber << " -> #"
-                                   << operand.getOperandNumber() << " in "
-                                   << printValueInfo(result) << '\n');
-
-  bool foundInterference =
-      aliasInfo.wouldCreateWriteToNonWritableBuffer(result) ||
-      aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo);
-
-  if (foundInterference)
-    aliasInfo.bufferizeOutOfPlace(result);
-  else
-    // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support
-    // more cases on a per-need basis.
-    aliasInfo.bufferizeInPlace(result, operand);
-
-  LDBG("Done inplace analysis for result #" << resultNumber << '\n');
-
-  return success();
+  return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo);
 }
 
 /// Analyze the `ops` to determine which OpResults are inplaceable: