This is to avoid unnecessary traversals of the IR.
Differential Revision: https://reviews.llvm.org/D143408
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
+ /// Find the definitions of the given tensor value or retrieve them from the
+ /// cache.
+ const SetVector<Value> &findDefinitionsCached(Value value);
+
+ /// Reset cached data structures.
+ void resetCache();
+
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2);
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
+ /// Cache definitions of tensor values.
+ DenseMap<Value, SetVector<Value>> cachedDefinitions;
+
/// Set of all OpOperands that were decided to bufferize in-place.
llvm::DenseSet<OpOperand *> inplaceBufferized;
namespace bufferization {
class AnalysisState;
struct BufferizationStatistics;
+class OneShotAnalysisState;
struct OneShotBufferizationOptions;
/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
/// following the aliasing OpOperand, that eventually ends at a single
/// tensor::EmptyOp.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- bufferization::AnalysisState &state,
+ OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc,
RewriteFn rewriteFunc);
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single tensor::EmptyOp.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, AnalysisState &state,
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
OpBuilder::InsertionGuard g(rewriter);
// Replace the tensor::EmptyOp.
rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
+ state.resetCache();
}
// Advance to the next operation.
/// tensor::EmptyOp.
template <typename OpTy>
static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
return eliminateEmptyTensors(
rewriter, op, state,
/*anchorMatchFunc=*/
LogicalResult
mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
- RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
tensor::InsertSliceOp>(rewriter, op, state)))
return failure();
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitions(opResult).empty())
+ if (findDefinitionsCached(opResult).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
- SetVector<Value> definitions = state.findDefinitions(uRead->get());
+ const SetVector<Value> &definitions =
+ state.findDefinitionsCached(uRead->get());
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
// Bufferization analyses.
//===----------------------------------------------------------------------===//
+/// Return the values that define the contents of `value`, memoized per value.
+///
+/// On a cache miss the definitions are computed with
+/// `findValueInReverseUseDefChain`, using `bufferizesToMemoryWrite` as the
+/// predicate, and stored in `cachedDefinitions`. Later calls with the same
+/// value return the stored set until `resetCache()` is called.
+///
+/// NOTE(review): the returned reference points into `cachedDefinitions`.
+/// Presumably it must not be held across another call that may insert into
+/// the cache (DenseMap can relocate its entries when it grows) -- confirm
+/// against callers.
+const llvm::SetVector<Value> &
+OneShotAnalysisState::findDefinitionsCached(Value value) {
+  // Hash `value` once and reuse the iterator, instead of looking it up again
+  // for the insertion and for the final access.
+  auto it = cachedDefinitions.find(value);
+  if (it == cachedDefinitions.end()) {
+    // The traversal result is fully computed before the map is touched.
+    it = cachedDefinitions
+             .try_emplace(value,
+                          findValueInReverseUseDefChain(
+                              value,
+                              [&](Value v) {
+                                return this->bufferizesToMemoryWrite(v);
+                              },
+                              /*followEquivalentOnly=*/false,
+                              /*alwaysIncludeLeaves=*/false))
+             .first;
+  }
+  return it->second;
+}
+
+/// Discard all memoized definition sets, so that subsequent
+/// `findDefinitionsCached` queries recompute their results.
+void OneShotAnalysisState::resetCache() {
+  cachedDefinitions.clear();
+}
+
/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,