[mlir][bufferization][NFC] Rename: "last-write" -> "definition"
authorMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 08:41:31 +0000 (09:41 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 08:51:53 +0000 (09:51 +0100)
The previous lingo was confusing. There are no writes on tensors. There are only definitions.

Also some minor cleanup and better documentation.

Differential Revision: https://reviews.llvm.org/D141790

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

index ccd2711..061f280 100644 (file)
@@ -355,7 +355,8 @@ public:
   /// traversed any further.
   ///
   /// When reaching the end of a chain (BlockArgument or Value without aliasing
-  /// OpOperands), also return the last Value of that chain.
+  /// OpOperands), also return the last Value of that chain if
+  /// `alwaysIncludeLeaves` is set.
   ///
   /// Example:
   ///
@@ -374,20 +375,41 @@ public:
   /// { 2, 7, 8, 5 }
   ///
   /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
-  SetVector<Value>
-  findValueInReverseUseDefChain(Value value,
-                                llvm::function_ref<bool(Value)> condition,
-                                bool followEquivalentOnly = false) const;
-
-  /// Find the Values of the last preceding write of a given Value.
+  SetVector<Value> findValueInReverseUseDefChain(
+      Value value, llvm::function_ref<bool(Value)> condition,
+      bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const;
+
+  /// Find the values that may define the contents of the given value at
+  /// runtime. A block argument is always a definition. An OpResult is a
+  /// definition if it bufferizes to memory write. If it does not bufferize to
+  /// a memory write but has aliasing operands, we continue the lookup on these
+  /// values.
+  ///
+  /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  /// findDefinitions(%r) = {%r} because %r bufferizes to memory write.
+  ///
+  /// Example: %r = tensor.empty() : tensor<10xf32>
+  /// findDefinitions(%r) = {} because tensor.empty does not the define the
+  /// contents of its result (i.e., it does not bufferize to a memory write)
+  /// and it has no aliasing OpOperands.
+  ///
+  /// Example:
+  /// %a = arith.constant ... : tensor<10xf32>
+  /// %b1 = tensor.insert %f into %t : tensor<50xf32>
+  /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> tensor<10xf32>
+  /// %r = arith.select %cond, %a, %b : tensor<10xf32>
+  /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues
+  /// in the operands) because their defining ops do not define the contents of
+  /// the tensor.
   ///
-  /// Note: Unknown ops are handled conservatively and assumed to be writes.
-  /// Furthermore, BlockArguments are also assumed to be writes. There is no
-  /// analysis across block boundaries.
+  /// Note: OpResults of unknown ops are handled conservatively and assumed to
+  /// be definitions.
   ///
   /// Note: When reaching an end of the reverse SSA use-def chain, that value
-  /// is returned regardless of whether it is a memory write or not.
-  SetVector<Value> findLastPrecedingWrite(Value value) const;
+  /// is included regardless of whether it is a definition or not unless
+  /// `alwaysIncludeLeaves` is unset.
+  SetVector<Value> findDefinitions(Value value,
+                                   bool alwaysIncludeLeaves = true) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;
index 8abb9c3..c24653d 100644 (file)
@@ -444,7 +444,7 @@ bool AnalysisState::isValueRead(Value value) const {
 // further.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    bool followEquivalentOnly) const {
+    bool followEquivalentOnly, bool alwaysIncludeLeaves) const {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -469,7 +469,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
         (followEquivalentOnly &&
          bufferizableOp.bufferRelation(opResult, *this) !=
              BufferRelation::Equivalent)) {
-      result.insert(value);
+      if (alwaysIncludeLeaves)
+        result.insert(value);
       continue;
     }
 
@@ -480,11 +481,12 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
   return result;
 }
 
-// Find the Values of the last preceding write of a given Value.
+// Find the values that define the contents of the given value.
 llvm::SetVector<Value>
-AnalysisState::findLastPrecedingWrite(Value value) const {
+AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const {
   return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); });
+      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+      /*followEquivalentOnly=*/false, alwaysIncludeLeaves);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
index 2e48067..8570352 100644 (file)
@@ -270,16 +270,9 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
       if (!opResult.getType().isa<TensorType>())
         continue;
 
-      // If there is no preceding memory write, the tensor contents are
+      // If there is no preceding definition, the tensor contents are
       // undefined.
-      // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-      // use-def chain, it returns that value, regardless of whether it is a
-      // memory write or not.
-      SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
-      bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
-        return this->bufferizesToMemoryWrite(lastWrite);
-      });
-      if (isUndefined)
+      if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -471,7 +464,7 @@ bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
 
 /// Annotate IR with details about the detected RaW conflict.
 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
-                             Value lastWrite) {
+                             Value definition) {
   static uint64_t counter = 0;
   Operation *readingOp = uRead->getOwner();
   Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -489,16 +482,15 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
       id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
   readingOp->setAttr(readAttr, b.getUnitAttr());
 
-  if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
-    std::string lastWriteAttr = id + "[LAST-WRITE: result " +
-                                std::to_string(opResult.getResultNumber()) +
-                                "]";
-    opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+  if (auto opResult = definition.dyn_cast<OpResult>()) {
+    std::string defAttr =
+        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
+    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
   } else {
-    auto bbArg = lastWrite.cast<BlockArgument>();
-    std::string lastWriteAttr =
-        id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
-    bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+    auto bbArg = definition.cast<BlockArgument>();
+    std::string defAttr =
+        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
+    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
   }
 }
 
@@ -507,8 +499,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
 /// all given writes bufferize inplace.
 ///
 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
-/// the result of a write W1. But because of bufferization decisions, R actually
-/// reads another write W2.
+/// the result of a definition W1. But because of bufferization decisions, R
+/// actually reads another definition W2.
 static bool hasReadAfterWriteInterference(
     const DenseSet<OpOperand *> &usesRead,
     const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
@@ -529,10 +521,10 @@ static bool hasReadAfterWriteInterference(
     // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
     // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
     //
-    // In the above example, if uRead is the OpOperand of reading_op, lastWrite
-    // is %0. Note that operations that create an alias but do not write (such
-    // as ExtractSliceOp) are skipped.
-    SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
+    // In the above example, if uRead is the OpOperand of reading_op, the
+    // definition is %0. Note that operations that create an alias but do not
+    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
+    SetVector<Value> definitions = state.findDefinitions(uRead->get());
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
@@ -611,31 +603,30 @@ static bool hasReadAfterWriteInterference(
         }
       }
 
-      // Check all possible last writes.
-      for (Value lastWrite : lastWrites) {
-        LLVM_DEBUG(llvm::dbgs() << "  * lastWrite = " << lastWrite << "\n");
+      // Check all possible definitions.
+      for (Value definition : definitions) {
+        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");
 
-        // No conflict if the conflicting write happens before the last
-        // write.
-        if (Operation *writingOp = lastWrite.getDefiningOp()) {
+        // No conflict if the conflicting write happens before the definition.
+        if (Operation *writingOp = definition.getDefiningOp()) {
           if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
             // conflictingWritingOp happens before writingOp. No conflict.
             LLVM_DEBUG(llvm::dbgs()
-                       << "    no conflict: write happens before last write\n");
+                       << "    no conflict: write happens before definition\n");
             continue;
           }
           // No conflict if conflictingWritingOp is contained in writingOp.
           if (writingOp->isProperAncestor(conflictingWritingOp)) {
             LLVM_DEBUG(
                 llvm::dbgs()
-                << "    no conflict: write is contained in last write\n");
+                << "    no conflict: write is contained in definition\n");
             continue;
           }
         } else {
-          auto bbArg = lastWrite.cast<BlockArgument>();
+          auto bbArg = definition.cast<BlockArgument>();
           Block *block = bbArg.getOwner();
           if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
-            LLVM_DEBUG(llvm::dbgs() << "    no conflict: last write is bbArg "
+            LLVM_DEBUG(llvm::dbgs() << "    no conflict: definition is bbArg "
                                        "and write happens outside of block\n");
             // conflictingWritingOp happens outside of the block. No
             // conflict.
@@ -643,20 +634,20 @@ static bool hasReadAfterWriteInterference(
           }
         }
 
-        // No conflict if the conflicting write and the last write are the same
+        // No conflict if the conflicting write and the definition are the same
         // use.
         SmallVector<OpResult> aliasingOpResult =
             state.getAliasingOpResult(*uConflictingWrite);
-        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) {
+        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == definition) {
           LLVM_DEBUG(llvm::dbgs()
-                     << "    no conflict: last write and write are same\n");
+                     << "    no conflict: definition and write are same\n");
           continue;
         }
 
         // All requirements are met. Conflict found!
 
         if (options.printConflicts)
-          annotateConflict(uRead, uConflictingWrite, lastWrite);
+          annotateConflict(uRead, uConflictingWrite, definition);
         LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
         return true;
       }
@@ -734,8 +725,8 @@ static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
 /// conflict because:
 /// * According to SSA use-def chains, we expect to read the result of %1.
 /// * However, adding an alias {%0, %t} would mean that the second
-///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
-///   would no longer be reading the result of %1.
+///   TransferWriteOp overwrites the result of the first one. Therefore, the
+///   TransferReadOp would no longer be reading the result of %1.
 ///
 /// If `checkConsistencyOnly` is true, this function checks if there is a
 /// read-after-write conflict without bufferizing `operand` inplace. This would
index 5d0d44f..2f57844 100644 (file)
@@ -712,7 +712,7 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
     // In the above example:
     // uRead             = OpOperand 0 (%1) of vector.transfer_read
     // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-    // lastWrite         = %1
+    // definition        = %1
     //
     // This is not a conflict because the InsertSliceOp overwrites the
     // memory segment of %1 with the exact same data. (Effectively, there