createAliasInfoEntry(bbArg);
});
- // The return value of an scf::IfOp aliases with both yield values.
- rootOp->walk([&](scf::IfOp ifOp) {
- if (ifOp->getNumResults() > 0) {
- for (auto it : llvm::zip(ifOp.thenYield().results(),
- ifOp.elseYield().results(), ifOp.results())) {
- aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
- aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
- }
-
- // scf::IfOp always bufferizes in-place.
- for (OpResult opResult : ifOp->getResults())
- setInPlaceOpResult(opResult, InPlaceSpec::True);
+ // Set up alias sets for OpResults that must bufferize in-place. This should
+ // be done before making any other bufferization decisions.
+ rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
+ for (OpResult opResult : bufferizableOp->getOpResults()) {
+ if (opResult.getType().isa<TensorType>())
+ if (bufferizableOp.mustBufferizeInPlace(opResult)) {
+ SmallVector<OpOperand *> operands =
+ bufferizableOp.getAliasingOpOperand(opResult);
+ assert(!operands.empty() &&
+ "expected that OpResult has aliasing OpOperand");
+ for (OpOperand *operand : operands)
+ aliasInfo.unionSets(operand->get(), opResult);
+ setInPlaceOpResult(opResult, InPlaceSpec::True);
+ }
}
});
}
/// * However, adding an alias {%0, %t} would mean that the second
/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
/// would no longer be reading the result of %1.
+///
+/// If `checkConsistencyOnly` is true, this function checks if there is a
+/// read-after-write conflict without bufferizing `operand` inplace. This would
+/// indicate a problem with the current inplace bufferization decisions.
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
- const BufferizationAliasInfo &aliasInfo) {
+ const BufferizationAliasInfo &aliasInfo,
+ bool checkConsistencyOnly = false) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
getAliasingReads(usesRead, result);
getAliasingInplaceWrites(usesWrite, operand.get());
getAliasingInplaceWrites(usesWrite, result);
- if (bufferizesToMemoryWrite(operand))
+ if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
});
}
+/// Assert that the current bufferization decisions are consistent.
+static void checkAliasInfoConsistency(FuncOp funcOp,
+ const DominanceInfo &domInfo,
+ const BufferizationAliasInfo &aliasInfo) {
+ funcOp.walk([&](Operation *op) {
+ if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+ for (OpOperand &opOperand : op->getOpOperands())
+ if (opOperand.get().getType().isa<TensorType>())
+ if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
+ // If this assertion fails, there is probably an inconsistent
+ // combination of "mustBufferizeInPlace" decisions.
+ assert(!wouldCreateReadAfterWriteInterference(
+ opOperand, opResult, domInfo, aliasInfo,
+ /*checkConsistencyOnly=*/true) &&
+ "found read after write conflict before running analysis");
+ });
+}
+
LogicalResult
mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
const BufferizationOptions &options) {
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
+
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
// inplace. Therefore, we just bufferize funcOp as if none of its results were
if (bbArg.getType().isa<TensorType>())
setInPlaceFuncArgument(bbArg);
+#ifndef NDEBUG
+ checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
+#endif // NDEBUG
+
// If the analysis fails, just return.
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
options.analysisFuzzerSeed)))
return true;
}
+ bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+ // IfOp results always bufferize in-place. Since they have no OpOperands,
+ // they are mostly ignored by the analysis once alias sets are set up.
+ return true;
+ }
+
LogicalResult bufferize(Operation *op, OpBuilder &b,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,