/// traversed any further.
///
/// When reaching the end of a chain (BlockArgument or Value without aliasing
- /// OpOperands), also return the last Value of that chain.
+ /// OpOperands), also return the last Value of that chain if
+ /// `alwaysIncludeLeaves` is set.
///
/// Example:
///
/// { 2, 7, 8, 5 }
///
/// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
- SetVector<Value>
- findValueInReverseUseDefChain(Value value,
- llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly = false) const;
-
- /// Find the Values of the last preceding write of a given Value.
+ SetVector<Value> findValueInReverseUseDefChain(
+ Value value, llvm::function_ref<bool(Value)> condition,
+ bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const;
+
+ /// Find the values that may define the contents of the given value at
+ /// runtime. A block argument is always a definition. An OpResult is a
+ /// definition if it bufferizes to memory write. If it does not bufferize to
+ /// a memory write but has aliasing operands, we continue the lookup on these
+ /// values.
+ ///
+ /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ /// findDefinitions(%r) = {%r} because %r bufferizes to memory write.
+ ///
+ /// Example: %r = tensor.empty() : tensor<10xf32>
+ /// findDefinitions(%r) = {} because tensor.empty does not the define the
+ /// contents of its result (i.e., it does not bufferize to a memory write)
+ /// and it has no aliasing OpOperands.
+ ///
+ /// Example:
+ /// %a = arith.constant ... : tensor<10xf32>
+ /// %b1 = tensor.insert %f into %t : tensor<50xf32>
+ /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> tensor<10xf32>
+ /// %r = arith.select %cond, %a, %b : tensor<10xf32>
+ /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues
+ /// in the operands) because their defining ops do not define the contents of
+ /// the tensor.
///
- /// Note: Unknown ops are handled conservatively and assumed to be writes.
- /// Furthermore, BlockArguments are also assumed to be writes. There is no
- /// analysis across block boundaries.
+ /// Note: OpResults of unknown ops are handled conservatively and assumed to
+ /// be definitions.
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
- /// is returned regardless of whether it is a memory write or not.
- SetVector<Value> findLastPrecedingWrite(Value value) const;
+ /// is included regardless of whether it is a definition or not unless
+ /// `alwaysIncludeLeaves` is unset.
+ SetVector<Value> findDefinitions(Value value,
+ bool alwaysIncludeLeaves = true) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
if (!opResult.getType().isa<TensorType>())
continue;
- // If there is no preceding memory write, the tensor contents are
+ // If there is no preceding definition, the tensor contents are
// undefined.
- // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
- // use-def chain, it returns that value, regardless of whether it is a
- // memory write or not.
- SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
- bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
- return this->bufferizesToMemoryWrite(lastWrite);
- });
- if (isUndefined)
+ if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
- Value lastWrite) {
+ Value definition) {
static uint64_t counter = 0;
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
readingOp->setAttr(readAttr, b.getUnitAttr());
- if (auto opResult = lastWrite.dyn_cast<OpResult>()) {
- std::string lastWriteAttr = id + "[LAST-WRITE: result " +
- std::to_string(opResult.getResultNumber()) +
- "]";
- opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ if (auto opResult = definition.dyn_cast<OpResult>()) {
+ std::string defAttr =
+ id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
+ opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
} else {
- auto bbArg = lastWrite.cast<BlockArgument>();
- std::string lastWriteAttr =
- id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
- bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr());
+ auto bbArg = definition.cast<BlockArgument>();
+ std::string defAttr =
+ id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
+ bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
}
}
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
-/// the result of a write W1. But because of bufferization decisions, R actually
-/// reads another write W2.
+/// the result of a definition W1. But because of bufferization decisions, R
+/// actually reads another definition W2.
static bool hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
// %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
// %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
//
- // In the above example, if uRead is the OpOperand of reading_op, lastWrite
- // is %0. Note that operations that create an alias but do not write (such
- // as ExtractSliceOp) are skipped.
- SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
+ // In the above example, if uRead is the OpOperand of reading_op, the
+ // definition is %0. Note that operations that create an alias but do not
+ // bufferize to a memory write (such as ExtractSliceOp) are skipped.
+ SetVector<Value> definitions = state.findDefinitions(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
}
}
- // Check all possible last writes.
- for (Value lastWrite : lastWrites) {
- LLVM_DEBUG(llvm::dbgs() << " * lastWrite = " << lastWrite << "\n");
+ // Check all possible definitions.
+ for (Value definition : definitions) {
+ LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
- // No conflict if the conflicting write happens before the last
- // write.
- if (Operation *writingOp = lastWrite.getDefiningOp()) {
+ // No conflict if the conflicting write happens before the definition.
+ if (Operation *writingOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
// conflictingWritingOp happens before writingOp. No conflict.
LLVM_DEBUG(llvm::dbgs()
- << " no conflict: write happens before last write\n");
+ << " no conflict: write happens before definition\n");
continue;
}
// No conflict if conflictingWritingOp is contained in writingOp.
if (writingOp->isProperAncestor(conflictingWritingOp)) {
LLVM_DEBUG(
llvm::dbgs()
- << " no conflict: write is contained in last write\n");
+ << " no conflict: write is contained in definition\n");
continue;
}
} else {
- auto bbArg = lastWrite.cast<BlockArgument>();
+ auto bbArg = definition.cast<BlockArgument>();
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
- LLVM_DEBUG(llvm::dbgs() << " no conflict: last write is bbArg "
+ LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
"and write happens outside of block\n");
// conflictingWritingOp happens outside of the block. No
// conflict.
}
}
- // No conflict if the conflicting write and the last write are the same
+ // No conflict if the conflicting write and the definition are the same
// use.
SmallVector<OpResult> aliasingOpResult =
state.getAliasingOpResult(*uConflictingWrite);
- if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) {
+ if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == definition) {
LLVM_DEBUG(llvm::dbgs()
- << " no conflict: last write and write are same\n");
+ << " no conflict: definition and write are same\n");
continue;
}
// All requirements are met. Conflict found!
if (options.printConflicts)
- annotateConflict(uRead, uConflictingWrite, lastWrite);
+ annotateConflict(uRead, uConflictingWrite, definition);
LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
return true;
}
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
-/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
-/// would no longer be reading the result of %1.
+/// TransferWriteOp overwrites the result of the first one. Therefore, the
+/// TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would