[mlir][bufferization][NFC] Cache definitions of read tensors
authorMatthias Springer <springerm@google.com>
Thu, 9 Feb 2023 08:10:48 +0000 (09:10 +0100)
committerMatthias Springer <springerm@google.com>
Thu, 9 Feb 2023 08:27:39 +0000 (09:27 +0100)
This is to avoid unnecessary traversals of the IR.

Differential Revision: https://reviews.llvm.org/D143408

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

index f1f43e5..8a7d8f0 100644 (file)
@@ -122,6 +122,13 @@ public:
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
+  /// Find the definitions of the given tensor value or retrieve them from the
+  /// cache.
+  const SetVector<Value> &findDefinitionsCached(Value value);
+
+  /// Reset cached data structures.
+  void resetCache();
+
   /// Union the alias sets of `v1` and `v2`.
   void unionAliasSets(Value v1, Value v2);
 
@@ -226,6 +233,9 @@ private:
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
+  /// Cache definitions of tensor values.
+  DenseMap<Value, SetVector<Value>> cachedDefinitions;
+
   /// Set of all OpResults that were decided to bufferize in-place.
   llvm::DenseSet<OpOperand *> inplaceBufferized;
 
index 44986a6..d964be9 100644 (file)
@@ -16,6 +16,7 @@ namespace mlir {
 namespace bufferization {
 class AnalysisState;
 struct BufferizationStatistics;
+class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
 /// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
@@ -36,7 +37,7 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
 ///   following the aliasing  OpOperand, that eventually ends at a single
 ///   tensor::EmptyOp.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    bufferization::AnalysisState &state,
+                                    OneShotAnalysisState &state,
                                     AnchorMatchFn anchorMatchFunc,
                                     RewriteFn rewriteFunc);
 
@@ -44,7 +45,7 @@ LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
 LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
index 76dfa20..1579cfd 100644 (file)
@@ -105,7 +105,7 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single tensor::EmptyOp.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state,
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
   OpBuilder::InsertionGuard g(rewriter);
 
@@ -153,6 +153,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
 
       // Replace the tensor::EmptyOp.
       rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
+      state.resetCache();
     }
 
     // Advance to the next operation.
@@ -189,7 +190,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
 ///   tensor::EmptyOp.
 template <typename OpTy>
 static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   return eliminateEmptyTensors(
       rewriter, op, state,
       /*anchorMatchFunc=*/
@@ -224,7 +225,7 @@ static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
 
 LogicalResult
 mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
              tensor::InsertSliceOp>(rewriter, op, state)))
     return failure();
index 8a7d660..02ef3a6 100644 (file)
@@ -222,7 +222,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitions(opResult).empty())
+      if (findDefinitionsCached(opResult).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -473,7 +473,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
     // In the above example, if uRead is the OpOperand of reading_op, the
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-    SetVector<Value> definitions = state.findDefinitions(uRead->get());
+    const SetVector<Value> &definitions =
+        state.findDefinitionsCached(uRead->get());
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()
@@ -769,6 +770,19 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
+// Find the values that define the contents of the given value.
+const llvm::SetVector<Value> &
+OneShotAnalysisState::findDefinitionsCached(Value value) {
+  if (!cachedDefinitions.count(value)) {
+    cachedDefinitions[value] = findValueInReverseUseDefChain(
+        value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+        /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+  }
+  return cachedDefinitions[value];
+}
+
+void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); }
+
 /// Determine if `operand` can be bufferized in-place.
 static LogicalResult
 bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,