[mlir][Linalg] NFC - Add support to specify that a tensor value is known to bufferize...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 5 Oct 2021 06:58:53 +0000 (06:58 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 5 Oct 2021 08:37:34 +0000 (08:37 +0000)
This change allows better interop with external clients of comprehensive bufferization functions
but is otherwise NFC for the MLIR pass itself.

Differential Revision: https://reviews.llvm.org/D111121

mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index e713fd9..82ada9f 100644 (file)
@@ -57,9 +57,9 @@ public:
   void insertNewBufferEquivalence(Value newValue, Value alias);
 
   /// Return true if the buffer to which `operand` would bufferize aliases a
-  /// buffer that is known to not be writeable. This implies that the matching
+  /// buffer that is known to not be writable. This implies that the matching
   /// OpResult cannot be bufferized inplace.
-  bool aliasesNonWriteableBuffer(OpOperand &operand) const;
+  bool aliasesNonWritableBuffer(OpOperand &operand) const;
 
   /// Return true if the buffer to which `operand` would bufferize is equivalent
   /// to some buffer write.
@@ -124,6 +124,12 @@ public:
   /// Apply `fun` to all the members of the equivalence class of `v`.
   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
+  /// Return true if the value is known to bufferize to writable memory.
+  bool bufferizesToWritableMemory(Value v) const;
+
+  /// Specify that the value is known to bufferize to writable memory.
+  void setBufferizesToWritableMemory(Value v);
+
   /// Print to `os`.
   void printAliases(raw_ostream &os) const;
   void printEquivalences(raw_ostream &os) const;
@@ -210,6 +216,9 @@ private:
                                   OpOperand &aliasingWrite,
                                   const DominanceInfo &domInfo) const;
 
+  /// Set of tensors that are known to bufferize to writable memory.
+  llvm::DenseSet<Value> bufferizeToWritableMemory;
+
   /// Auxiliary structure to store all the values a given value aliases with.
   /// These are the conservative cases that can further decompose into
   /// "equivalent" buffer relationships.
index 79fa998..6b5f624 100644 (file)
@@ -726,15 +726,21 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
 }
 
 /// Return true if the buffer to which `operand` would bufferize aliases a
-/// buffer that is known to not be writeable. This implies that the matching
+/// buffer that is known to not be writable. This implies that the matching
 /// OpResult cannot be bufferized inplace.
-bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
+bool BufferizationAliasInfo::aliasesNonWritableBuffer(
     OpOperand &operand) const {
-  LDBG("----Start aliasesNonWriteableBuffer\n");
+  LDBG("----Start aliasesNonWritableBuffer\n");
   LDBG("-------for -> #" << operand.getOperandNumber() << ": "
                          << printOperationInfo(operand.getOwner()) << '\n');
   for (Value v : getAliases(operand.get())) {
     LDBG("-----------examine: " << printValueInfo(v) << '\n');
+    if (bufferizesToWritableMemory(v)) {
+      LDBG("-----------Value is known to be writeable -> skip: "
+           << printValueInfo(v) << '\n');
+      continue;
+    }
+
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
       if (getInPlace(bbArg) == InPlaceSpec::True) {
         LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg)
@@ -747,15 +753,24 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
 
     if (Operation *op = v.getDefiningOp()) {
       if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
-        LDBG("-----------notWriteable\n");
+        LDBG("-----------notWritable\n");
         return true;
       }
     }
   }
-  LDBG("---->operand is writeable\n");
+  LDBG("---->operand is writable\n");
   return false;
 }
 
+bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const {
+  return bufferizeToWritableMemory.count(v) > 0;
+}
+
+/// Specify that the value is known to bufferize to writable memory.
+void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
+  bufferizeToWritableMemory.insert(v);
+}
+
 /// Return true if the buffer to which `operand` would bufferize is equivalent
 /// to some buffer write.
 bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
@@ -2184,22 +2199,22 @@ bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
        << printOperationInfo(extractSliceOp) << '\n');
 
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
-  // aliasing a write into a non-writeable buffer.
-  bool wouldCreateAliasingWriteToNonWriteableBuffer =
+  // aliasing a write into a non-writable buffer.
+  bool wouldCreateAliasingWriteToNonWritableBuffer =
       aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
-      aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
+      aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0));
 
-  if (wouldCreateAliasingWriteToNonWriteableBuffer)
-    LDBG("->the corresponding buffer is not writeable\n");
+  if (wouldCreateAliasingWriteToNonWritableBuffer)
+    LDBG("->the corresponding buffer is not writable\n");
   else
-    LDBG("->bufferizes to writeable inplace buffer\n");
+    LDBG("->bufferizes to writable inplace buffer\n");
 
   // In any of extractSliceOp.result's aliases, can we find 2 such that we hit
   // an interfering write?
   OpResult r = extractSliceOp->getResult(0);
   OpOperand &s = extractSliceOp->getOpOperand(0);
   bool foundInterference =
-      wouldCreateAliasingWriteToNonWriteableBuffer ||
+      wouldCreateAliasingWriteToNonWritableBuffer ||
       aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo);
   if (foundInterference)
     aliasInfo.bufferizeOutOfPlace(r);
@@ -2230,21 +2245,21 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
                                    << operand.getOperandNumber() << " in "
                                    << printValueInfo(result) << '\n');
 
-  // `result` must bufferize to a writeable buffer to be a candidate.
+  // `result` must bufferize to a writable buffer to be a candidate.
   // This means the operand  must not alias either:
   //   1. a function bbArg that is not inplaceable or
   //   2. a constant op.
   // to be considered for inplace bufferization
-  bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesNonWriteableBuffer(operand);
-  if (wouldCreateAliasingWriteToNonWriteableBuffer)
-    LDBG("->the corresponding buffer is not writeable\n");
+  bool wouldCreateAliasingWriteToNonWritableBuffer =
+      aliasInfo.aliasesNonWritableBuffer(operand);
+  if (wouldCreateAliasingWriteToNonWritableBuffer)
+    LDBG("->the corresponding buffer is not writable\n");
   else
-    LDBG("->bufferizes to writeable inplace buffer\n");
+    LDBG("->bufferizes to writable inplace buffer\n");
 
   assert(result == getInplaceableOpResult(operand));
   bool foundInterference =
-      wouldCreateAliasingWriteToNonWriteableBuffer ||
+      wouldCreateAliasingWriteToNonWritableBuffer ||
       aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo);
 
   if (foundInterference)
@@ -2312,6 +2327,15 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
     ops.push_back(op);
   });
 
+  // Set the function arguments marked with inplaceable to be known as
+  // bufferizing to a writeable memory.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+        bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
+    if (inplaceAttr && inplaceAttr.getValue())
+      aliasInfo.setBufferizesToWritableMemory(bbArg);
+  }
+
   LogicalResult res = inPlaceAnalysis(ops, aliasInfo, domInfo);
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');