[mlir][bufferize][NFC] OneShotAnalysis: Expose analysis hooks from AnalysisState
authorMatthias Springer <springerm@google.com>
Wed, 8 Feb 2023 08:22:48 +0000 (09:22 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 8 Feb 2023 08:28:08 +0000 (09:28 +0100)
This is in preparation of reusing the same AnalysisState for tensor.empty elimination and One-Shot Bufferize (to address performance bottlenecks).

Differential Revision: https://reviews.llvm.org/D143379

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

index 3dc045d..f1f43e5 100644 (file)
@@ -14,6 +14,8 @@
 #include <string>
 
 namespace mlir {
+class DominanceInfo;
+
 namespace bufferization {
 
 struct OneShotBufferizationOptions;
@@ -63,6 +65,12 @@ public:
         AnalysisState::getOptions());
   }
 
+  /// Analyze the given op and its nested ops.
+  LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo);
+
+  /// Analyze a single op (without nested ops).
+  LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo);
+
   /// Apply `fun` to all the members of the equivalence class of `v`.
   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
index 4231b85..6420feb 100644 (file)
@@ -812,61 +812,13 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
   return success();
 }
 
-/// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
-/// reverse and bufferize ops greedily. This is a good starter heuristic.
-///
-/// Even if an op does not read or write, it may still create an alias when
-/// bufferized in-place. An example of such ops is tensor.extract_slice.
-///
-/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
-///
-/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
-/// cannot change the flow of information for either the source or the
-/// result buffers.
-///
-/// When bufferized inplace, an ExtractSliceOp does not by itself create any
-/// read or write from memory. Instead, it has the effect of merging the alias
-/// sets of the source and the result buffers.
-///
-/// An analysis is required to ensure inplace bufferization would not result in
-/// RaW dependence violations.
-static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
-                                     OneShotAnalysisState &state,
-                                     const DominanceInfo &domInfo,
-                                     unsigned analysisFuzzerSeed = 0) {
-  if (analysisFuzzerSeed) {
-    // This is a fuzzer. For testing purposes only. Randomize the order in which
-    // operations are analyzed. The bufferization quality is likely worse, but
-    // we want to make sure that no assertions are triggered anywhere.
-    std::mt19937 g(analysisFuzzerSeed);
-    llvm::shuffle(ops.begin(), ops.end(), g);
-  }
-
-  // Analyze a single op.
-  auto analyzeOp = [&](Operation *op) {
-    for (OpOperand &opOperand : op->getOpOperands())
-      if (opOperand.get().getType().isa<TensorType>())
-        if (failed(bufferizableInPlaceAnalysisImpl(opOperand, state, domInfo)))
-          return failure();
-    return success();
-  };
-
-  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
-      state.getOptions().analysisHeuristic;
-  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
-    // Default: Walk ops in reverse for better interference analysis.
-    for (Operation *op : reverse(ops))
-      if (failed(analyzeOp(op)))
-        return failure();
-  } else if (heuristic ==
-             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
-    for (Operation *op : ops)
-      if (failed(analyzeOp(op)))
+LogicalResult
+OneShotAnalysisState::analyzeSingleOp(Operation *op,
+                                      const DominanceInfo &domInfo) {
+  for (OpOperand &opOperand : op->getOpOperands())
+    if (opOperand.get().getType().isa<TensorType>())
+      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
         return failure();
-  } else {
-    llvm_unreachable("unsupported heuristic");
-  }
-
   return success();
 }
 
@@ -877,23 +829,6 @@ static bool hasTensorSemantics(Operation *op) {
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Analyze all ops that are contained in `op`.
-static LogicalResult inPlaceAnalysis(Operation *op,
-                                     OneShotAnalysisState &state,
-                                     const DominanceInfo &domInfo,
-                                     unsigned analysisFuzzerSeed = 0) {
-  // Collect ops so we can build our own reverse traversal.
-  SmallVector<Operation *> ops;
-  op->walk([&](Operation *op) {
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(op))
-      return;
-    ops.push_back(op);
-  });
-
-  return inPlaceAnalysis(ops, state, domInfo, analysisFuzzerSeed);
-}
-
 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                 OneShotAnalysisState &state) {
@@ -924,6 +859,45 @@ static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
   equivalenceAnalysis(ops, state);
 }
 
+LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
+                                              const DominanceInfo &domInfo) {
+  // Collect ops so we can build our own reverse traversal.
+  SmallVector<Operation *> ops;
+  op->walk([&](Operation *op) {
+    // No tensors => no buffers.
+    if (!hasTensorSemantics(op))
+      return;
+    ops.push_back(op);
+  });
+
+  if (getOptions().analysisFuzzerSeed) {
+    // This is a fuzzer. For testing purposes only. Randomize the order in which
+    // operations are analyzed. The bufferization quality is likely worse, but
+    // we want to make sure that no assertions are triggered anywhere.
+    std::mt19937 g(getOptions().analysisFuzzerSeed);
+    llvm::shuffle(ops.begin(), ops.end(), g);
+  }
+
+  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
+      getOptions().analysisHeuristic;
+  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
+    // Default: Walk ops in reverse for better interference analysis.
+    for (Operation *op : reverse(ops))
+      if (failed(analyzeSingleOp(op, domInfo)))
+        return failure();
+  } else if (heuristic ==
+             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
+    for (Operation *op : ops)
+      if (failed(analyzeSingleOp(op, domInfo)))
+        return failure();
+  } else {
+    llvm_unreachable("unsupported heuristic");
+  }
+
+  equivalenceAnalysis(op, *this);
+  return success();
+}
+
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult checkAliasInfoConsistency(Operation *op,
                                                const DominanceInfo &domInfo,
@@ -1060,7 +1034,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
     return failure();
 
   // If the analysis fails, just return.
-  if (failed(inPlaceAnalysis(op, state, domInfo, options.analysisFuzzerSeed)))
+  if (failed(state.analyzeOp(op, domInfo)))
     return failure();
 
   if (statistics) {
@@ -1068,8 +1042,6 @@ LogicalResult bufferization::analyzeOp(Operation *op,
     statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
   }
 
-  equivalenceAnalysis(op, state);
-
   bool failedAnalysis = false;
   if (!options.allowReturnAllocs)
     failedAnalysis |= failed(assertNoAllocsReturned(op, state));