return isWritten;
}
+bool OneShotAnalysisState::isWritable(Value value) const {
+ // TODO: Out-of-place bufferized value could be considered writable.
+ if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
+ return bufferizableOp.isWritable(value, *this);
+
+ // Query BufferizableOpInterface to see if the BlockArgument is writable.
+ if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bufferizableOp =
+ getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+ return bufferizableOp.isWritable(bbArg, *this);
+
+ // Not a bufferizable op: The conservative answer is "not writable".
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
+ const AnalysisState &state) {
// OpOperands that do not bufferize to a memory write do not write in-place.
if (!state.bufferizesToMemoryWrite(opOperand))
return false;
return aliasInfo.isInPlace(opOperand);
}
-/// Return true if, under current bufferization decisions, the buffer of `value`
-/// is not writable.
-static bool aliasesNonWritableBuffer(Value value,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- bool foundNonWritableBuffer = false;
- aliasInfo.applyOnAliases(value, [&](Value v) {
- // Query BufferizableOpInterface to see if the value is writable.
- // TODO: Out-of-place bufferized value could be considered writable.
- if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
- if (bufferizableOp && bufferizableOp.isWritable(v, state))
- return;
-
- // Query BufferizableOpInterface to see if the BlockArgument is writable.
- if (auto bbArg = v.dyn_cast<BlockArgument>())
- if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
- bbArg.getOwner()->getParentOp()))
- if (bufferizableOp.isWritable(bbArg, state))
- return;
-
- foundNonWritableBuffer = true;
- });
-
- return foundNonWritableBuffer;
-}
-
-/// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some buffer write.
-static bool aliasesInPlaceWrite(Value value,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- bool foundInplaceWrite = false;
- aliasInfo.applyOnAliases(value, [&](Value v) {
- for (auto &use : v.getUses()) {
- if (isInplaceMemoryWrite(use, aliasInfo, state)) {
- foundInplaceWrite = true;
- return;
- }
- }
- });
- return foundInplaceWrite;
-}
-
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
return false;
}
+// Helper function to iterate on aliases of `root` and capture the writes.
+static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
+ const BufferizationAliasInfo &aliasInfo,
+ const AnalysisState &state) {
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
+ for (auto &use : alias.getUses())
+ // Inplace write to a value that aliases root.
+ if (isInplaceMemoryWrite(use, aliasInfo, state))
+ res.insert(&use);
+ });
+}
+
+// Helper function to iterate on aliases of `root` and capture the reads.
+static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
+ const BufferizationAliasInfo &aliasInfo,
+ const AnalysisState &state) {
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
+ for (auto &use : alias.getUses())
+ // Read to a value that aliases root.
+ if (state.bufferizesToMemoryRead(use))
+ res.insert(&use);
+ });
+}
+
/// Return true if bufferizing `operand` inplace would create a conflict. A read
/// R and a write W of the same alias set is a conflict if inplace bufferization
/// of W changes the value read by R to a value different from the one that
OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
- // Helper function to iterate on aliases of `root` and capture the reads.
- auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
- for (auto &use : alias.getUses())
- // Read to a value that aliases root.
- if (state.bufferizesToMemoryRead(use))
- res.insert(&use);
- });
- };
-
- // Helper function to iterate on aliases of `root` and capture the writes.
- auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
- for (auto &use : alias.getUses())
- // Inplace write to a value that aliases root.
- if (isInplaceMemoryWrite(use, aliasInfo, state))
- res.insert(&use);
- });
- };
-
// Collect reads and writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesRead, usesWrite;
- getAliasingReads(usesRead, operand.get());
- getAliasingInplaceWrites(usesWrite, operand.get());
+ getAliasingReads(usesRead, operand.get(), aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
for (OpResult result : state.getAliasingOpResult(operand)) {
- getAliasingReads(usesRead, result);
- getAliasingInplaceWrites(usesWrite, result);
+ getAliasingReads(usesRead, result, aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
}
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
aliasInfo);
}
-/// Return true if bufferizing `opOperand` inplace would create a write to a
-/// non-writable buffer.
+/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
+/// non-writable tensor values. Stop searching when an out-of-place bufferized
+/// OpOperand was found (or when the OpOperand was not bufferized yet).
+/// `currentOpOperand` is assumed to be in-place, even if that decision was not
+/// materialized in `aliasInfo` yet.
static bool
-wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- // Certain buffers are not writeable:
- // 1. A function bbArg that is not inplaceable or
- // 2. A constant op.
- bool nonWritable =
- aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
- if (!nonWritable)
- return false;
+hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
+ const BufferizationAliasInfo &aliasInfo,
+ const OneShotAnalysisState &state) {
+ SmallVector<Value> worklist;
+ worklist.push_back(value);
+ while (!worklist.empty()) {
+ Value nextVal = worklist.pop_back_val();
+ if (!state.isWritable(nextVal))
+ return true;
+
+ // If `nextVal` is not a BlockArgument: End of use-def chain reached.
+ auto opResult = nextVal.dyn_cast<OpResult>();
+ if (!opResult)
+ continue;
+
+ // Follow reverse SSA use-def chain.
+ SmallVector<OpOperand *> aliasingOpOperands =
+ state.getAliasingOpOperand(opResult);
+ for (OpOperand *opOperand : aliasingOpOperands)
+ if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
+ worklist.push_back(opOperand->get());
+ }
+ return false;
+}
- // This is a problem only if the buffer is written to via some alias.
- bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
- state.bufferizesToMemoryWrite(opOperand);
+/// Return true if bufferizing `operand` inplace would create a write to a
+/// non-writable buffer.
+static bool wouldCreateWriteToNonWritableBuffer(
+ OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
+ OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
+ // Collect writes of all aliases of OpOperand and OpResult.
+ DenseSet<OpOperand *> usesWrite;
+ getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+ for (OpResult result : state.getAliasingOpResult(operand)) {
+ getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+ }
+ if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
+ usesWrite.insert(&operand);
- for (OpResult opResult : state.getAliasingOpResult(opOperand))
- hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
+ // Assuming that `operand` bufferizes in-place: For each write (to each
+ // alias), check if there is a non-writable tensor in the reverse SSA use-def
+ // chain.
+ for (OpOperand *uWrite : usesWrite)
+ if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
+ aliasInfo, state))
+ return true;
- return hasWrite;
+ return false;
}
//===----------------------------------------------------------------------===//
/// Determine if `operand` can be bufferized in-place.
static LogicalResult bufferizableInPlaceAnalysisImpl(
- OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
- const DominanceInfo &domInfo) {
+ OpOperand &operand, BufferizationAliasInfo &aliasInfo,
+ OneShotAnalysisState &state, const DominanceInfo &domInfo) {
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
/// RaW dependence violations.
static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
- AnalysisState &state,
+ OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
if (analysisFuzzerSeed) {
/// Analyze all ops that are contained in `op`.
static LogicalResult inPlaceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo,
- AnalysisState &state,
+ OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
// Collect ops so we can build our own reverse traversal.