[mlir][bufferize][NFC] Merge AnalysisState and BufferizationAliasInfo
authorMatthias Springer <springerm@google.com>
Wed, 8 Feb 2023 07:59:15 +0000 (08:59 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 8 Feb 2023 08:12:09 +0000 (09:12 +0100)
There is no longer a need to keep the two separate. This is in preparation of reusing the same AnalysisState for tensor.empty elimination and One-Shot Bufferize (to address performance bottlenecks).

Differential Revision: https://reviews.llvm.org/D143313

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

index 63a8b38..3dc045d 100644 (file)
@@ -17,7 +17,6 @@ namespace mlir {
 namespace bufferization {
 
 struct OneShotBufferizationOptions;
-class BufferizationAliasInfo;
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 
@@ -40,108 +39,11 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
   llvm::ArrayRef<std::string> noAnalysisFuncFilter;
 };
 
-/// The BufferizationAliasInfo class maintains a list of buffer aliases and
-/// equivalence classes to support bufferization.
-class BufferizationAliasInfo {
-public:
-  explicit BufferizationAliasInfo(Operation *rootOp);
-
-  // BufferizationAliasInfo should be passed as a reference.
-  BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
-
-  /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
-  /// beginning the alias and equivalence sets only contain `v` itself.
-  void createAliasInfoEntry(Value v);
-
-  /// Insert an info entry for `newValue` and merge its alias set with that of
-  /// `alias`.
-  void insertNewBufferAlias(Value newValue, Value alias);
-
-  /// Insert an info entry for `newValue` and merge its alias set with that of
-  /// `alias`. Additionally, merge their equivalence classes.
-  void insertNewBufferEquivalence(Value newValue, Value alias);
-
-  /// Set the inPlace bufferization spec to true.
-  /// Merge result's and operand's aliasing sets and iterate to a fixed point.
-  void bufferizeInPlace(OpOperand &operand, AnalysisState &state);
-
-  /// Set the inPlace bufferization spec to false.
-  void bufferizeOutOfPlace(OpOperand &operand);
-
-  /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
-  bool areAliasingBufferizedValues(Value v1, Value v2) const {
-    return aliasInfo.isEquivalent(v1, v2);
-  }
-
-  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
-  bool areEquivalentBufferizedValues(Value v1, Value v2) const {
-    return equivalentInfo.isEquivalent(v1, v2);
-  }
-
-  /// Union the alias sets of `v1` and `v2`.
-  void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
-
-  /// Union the equivalence classes of `v1` and `v2`.
-  void unionEquivalenceClasses(Value v1, Value v2) {
-    equivalentInfo.unionSets(v1, v2);
-  }
-
-  /// Apply `fun` to all the members of the equivalence class of `v`.
-  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
-
-  /// Apply `fun` to all aliases of `v`.
-  void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
-
-  /// Mark a value as in-place bufferized.
-  void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
-
-  /// Return `true` if a value was marked as in-place bufferized.
-  bool isInPlace(OpOperand &opOperand) const;
-
-  int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; }
-  int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; }
-
-private:
-  /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
-  /// uses pointer comparison on the defining op. This is a poor man's
-  /// comparison but it's not like UnionFind needs ordering anyway.
-  struct ValueComparator {
-    bool operator()(const Value &lhs, const Value &rhs) const {
-      return lhs.getImpl() < rhs.getImpl();
-    }
-  };
-
-  using EquivalenceClassRangeType = llvm::iterator_range<
-      llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
-  /// Check that aliasInfo for `v` exists and return a reference to it.
-  EquivalenceClassRangeType getAliases(Value v) const;
-
-  /// Set of all OpResults that were decided to bufferize in-place.
-  llvm::DenseSet<OpOperand *> inplaceBufferized;
-
-  /// Auxiliary structure to store all the values a given value may alias with.
-  /// Alias information is "may be" conservative: In the presence of branches, a
-  /// value may alias with one of multiple other values. The concrete aliasing
-  /// value may not even be known at compile time. All such values are
-  /// considered to be aliases.
-  llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
-
-  /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
-  /// buffer information is "must be" conservative: Only if two values are
-  /// guaranteed to be equivalent at runtime, they said to be equivalent. It is
-  /// possible that, in the presence of branches, it cannot be determined
-  /// statically if two values are equivalent. In that case, the values are
-  /// considered to be not equivalent.
-  llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
-
-  // Bufferization statistics.
-  int64_t statNumTensorOutOfPlace = 0;
-  int64_t statNumTensorInPlace = 0;
-};
-
 /// State for analysis-enabled bufferization. This class keeps track of alias
-/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
-/// in-place.
+/// sets, equivalence sets, in-place OpOperands and other things.
+///
+/// Note: Modifying the IR generally invalidates the result of the analysis.
+/// Adding new operations is safe if they are analyzed subsequently.
 class OneShotAnalysisState : public AnalysisState {
 public:
   OneShotAnalysisState(Operation *op,
@@ -161,11 +63,11 @@ public:
         AnalysisState::getOptions());
   }
 
-  /// Return a reference to the BufferizationAliasInfo.
-  BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+  /// Apply `fun` to all the members of the equivalence class of `v`.
+  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
-  /// Return `true` if the given OpResult has been decided to bufferize inplace.
-  bool isInPlace(OpOperand &opOperand) const override;
+  /// Apply `fun` to all aliases of `v`.
+  void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
 
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
@@ -173,12 +75,16 @@ public:
   /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
   bool areAliasingBufferizedValues(Value v1, Value v2) const override;
 
-  /// Return `true` if the given tensor has undefined contents.
-  bool hasUndefinedContents(OpOperand *opOperand) const override;
+  /// Mark the given OpOperand as in-place and merge the results' and operand's
+  /// aliasing sets.
+  void bufferizeInPlace(OpOperand &operand);
 
-  /// Return true if the given tensor (or an aliasing tensor) is yielded from
-  /// the containing block. Also include all aliasing tensors in the same block.
-  bool isTensorYielded(Value tensor) const override;
+  /// Mark the given OpOperand as out-of-place.
+  void bufferizeOutOfPlace(OpOperand &operand);
+
+  /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
+  /// beginning the alias and equivalence sets only contain `v` itself.
+  void createAliasInfoEntry(Value v);
 
   /// Find all tensor values in the given operation that have undefined contents
   /// and store them in `undefinedTensorUses`.
@@ -188,6 +94,19 @@ public:
   /// `yieldedTensors`. Also include all aliasing tensors in the same block.
   void gatherYieldedTensors(Operation *op);
 
+  int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; }
+  int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; }
+
+  /// Return `true` if the given tensor has undefined contents.
+  bool hasUndefinedContents(OpOperand *opOperand) const override;
+
+  /// Return `true` if the given OpResult has been decided to bufferize inplace.
+  bool isInPlace(OpOperand &opOperand) const override;
+
+  /// Return true if the given tensor (or an aliasing tensor) is yielded from
+  /// the containing block. Also include all aliasing tensors in the same block.
+  bool isTensorYielded(Value tensor) const override;
+
   /// Return true if the buffer of the given tensor value is written to. Must
   /// not be called for values inside not yet analyzed functions.
   bool isValueWritten(Value value) const;
@@ -195,6 +114,12 @@ public:
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
+  /// Union the alias sets of `v1` and `v2`.
+  void unionAliasSets(Value v1, Value v2);
+
+  /// Union the equivalence classes of `v1` and `v2`.
+  void unionEquivalenceClasses(Value v1, Value v2);
+
   /// Base class for OneShotAnalysisState extensions that allow
   /// OneShotAnalysisState to contain user-specified information in the state
   /// object. Clients are expected to derive this class, add the desired fields,
@@ -279,9 +204,41 @@ public:
   }
 
 private:
-  /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
-  /// functions and `runOneShotBufferize` may access this object.
-  BufferizationAliasInfo aliasInfo;
+  /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
+  /// pointer comparison on the defining op. This is a poor man's comparison
+  /// but it's not like UnionFind needs ordering anyway.
+  struct ValueComparator {
+    bool operator()(const Value &lhs, const Value &rhs) const {
+      return lhs.getImpl() < rhs.getImpl();
+    }
+  };
+
+  using EquivalenceClassRangeType = llvm::iterator_range<
+      llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
+  /// Check that aliasInfo for `v` exists and return a reference to it.
+  EquivalenceClassRangeType getAliases(Value v) const;
+
+  /// Set of all OpResults that were decided to bufferize in-place.
+  llvm::DenseSet<OpOperand *> inplaceBufferized;
+
+  /// Auxiliary structure to store all the values a given value may alias with.
+  /// Alias information is "may be" conservative: In the presence of branches, a
+  /// value may alias with one of multiple other values. The concrete aliasing
+  /// value may not even be known at compile time. All such values are
+  /// considered to be aliases.
+  llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
+
+  /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
+  /// buffer information is "must be" conservative: Only if two values are
+  /// guaranteed to be equivalent at runtime, they said to be equivalent. It is
+  /// possible that, in the presence of branches, it cannot be determined
+  /// statically if two values are equivalent. In that case, the values are
+  /// considered to be not equivalent.
+  llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
+
+  // Bufferization statistics.
+  int64_t statNumTensorOutOfPlace = 0;
+  int64_t statNumTensorInPlace = 0;
 
   /// A set of all tensors (and maybe aliasing tensors) that yielded from a
   /// block.
index 3cd84c5..4231b85 100644 (file)
@@ -71,7 +71,7 @@ static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 //===----------------------------------------------------------------------===//
 // Bufferization-specific attribute manipulation.
 // These are for testing and debugging only. Bufferization information is stored
-// in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR is
+// in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is
 // annotated with the results of the analysis, so that they can be checked in
 // tests.
 //===----------------------------------------------------------------------===//
@@ -98,11 +98,14 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
 }
 
 //===----------------------------------------------------------------------===//
-// BufferizationAliasInfo
+// OneShotAnalysisState
 //===----------------------------------------------------------------------===//
 
-BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
-  rootOp->walk([&](Operation *op) {
+OneShotAnalysisState::OneShotAnalysisState(
+    Operation *op, const OneShotBufferizationOptions &options)
+    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
+  // Set up alias sets.
+  op->walk([&](Operation *op) {
     for (Value v : op->getResults())
       if (v.getType().isa<TensorType>())
         createAliasInfoEntry(v);
@@ -112,55 +115,20 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
           if (bbArg.getType().isa<TensorType>())
             createAliasInfoEntry(bbArg);
   });
-}
-
-/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
-/// beginning the alias and equivalence sets only contain `v` itself.
-void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
-  aliasInfo.insert(v);
-  equivalentInfo.insert(v);
-}
-
-/// Insert an info entry for `newValue` and merge its alias set with that of
-/// `alias`.
-void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
-  createAliasInfoEntry(newValue);
-  aliasInfo.unionSets(newValue, alias);
-}
-
-/// Insert an info entry for `newValue` and merge its alias set with that of
-/// `alias`. Additionally, merge their equivalence classes.
-void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
-                                                        Value alias) {
-  insertNewBufferAlias(newValue, alias);
-  equivalentInfo.unionSets(newValue, alias);
-}
-
-/// Return `true` if a value was marked as in-place bufferized.
-bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
-  return inplaceBufferized.contains(&operand);
-}
-
-/// Set the inPlace bufferization spec to true.
-void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
-                                              AnalysisState &state) {
-  if (inplaceBufferized.contains(&operand))
-    return;
-  markInPlace(operand);
-  for (OpResult result : state.getAliasingOpResults(operand))
-    aliasInfo.unionSets(result, operand.get());
-  ++statNumTensorInPlace;
-}
 
-/// Set the inPlace bufferization spec to false.
-void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
-  assert(!inplaceBufferized.contains(&operand) &&
-         "OpOperand was already decided to bufferize inplace");
-  ++statNumTensorOutOfPlace;
+  // Mark OpOperands in-place that must bufferize in-place.
+  op->walk([&](BufferizableOpInterface bufferizableOp) {
+    if (!options.isOpAllowed(bufferizableOp))
+      return WalkResult::skip();
+    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
+      if (opOperand.get().getType().isa<TensorType>())
+        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
+          bufferizeInPlace(opOperand);
+    return WalkResult::advance();
+  });
 }
 
-/// Apply `fun` to all the members of the equivalence class of `v`.
-void BufferizationAliasInfo::applyOnEquivalenceClass(
+void OneShotAnalysisState::applyOnEquivalenceClass(
     Value v, function_ref<void(Value)> fun) const {
   auto leaderIt = equivalentInfo.findLeader(v);
   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
@@ -169,66 +137,48 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
   }
 }
 
-/// Apply `fun` to all aliases of `v`.
-void BufferizationAliasInfo::applyOnAliases(
-    Value v, function_ref<void(Value)> fun) const {
+void OneShotAnalysisState::applyOnAliases(Value v,
+                                          function_ref<void(Value)> fun) const {
   auto leaderIt = aliasInfo.findLeader(v);
   for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
     fun(*mit);
   }
 }
 
-BufferizationAliasInfo::EquivalenceClassRangeType
-BufferizationAliasInfo::getAliases(Value v) const {
-  DenseSet<Value> res;
-  auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
-  for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
-       mit != meit; ++mit) {
-    res.insert(static_cast<Value>(*mit));
-  }
-  return BufferizationAliasInfo::EquivalenceClassRangeType(
-      aliasInfo.member_begin(it), aliasInfo.member_end());
+bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
+                                                         Value v2) const {
+  return equivalentInfo.isEquivalent(v1, v2);
 }
 
-//===----------------------------------------------------------------------===//
-// OneShotAnalysisState
-//===----------------------------------------------------------------------===//
-
-OneShotAnalysisState::OneShotAnalysisState(
-    Operation *op, const OneShotBufferizationOptions &options)
-    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()),
-      aliasInfo(op) {
-  // Set up alias sets for OpResults that must bufferize in-place. This should
-  // be done before making any other bufferization decisions.
-  op->walk([&](BufferizableOpInterface bufferizableOp) {
-    if (!options.isOpAllowed(bufferizableOp))
-      return WalkResult::skip();
-    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
-      if (opOperand.get().getType().isa<TensorType>())
-        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
-          aliasInfo.bufferizeInPlace(opOperand, *this);
-    return WalkResult::advance();
-  });
+bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
+                                                       Value v2) const {
+  return aliasInfo.isEquivalent(v1, v2);
 }
 
-bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
-  return aliasInfo.isInPlace(opOperand);
+void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
+  if (inplaceBufferized.contains(&operand))
+    return;
+  inplaceBufferized.insert(&operand);
+  for (OpResult result : getAliasingOpResults(operand))
+    aliasInfo.unionSets(result, operand.get());
+  ++statNumTensorInPlace;
 }
 
-bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
-                                                         Value v2) const {
-  return aliasInfo.areEquivalentBufferizedValues(v1, v2);
+void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
+  assert(!inplaceBufferized.contains(&operand) &&
+         "OpOperand was already decided to bufferize inplace");
+  ++statNumTensorOutOfPlace;
 }
 
-bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
-                                                       Value v2) const {
-  return aliasInfo.areAliasingBufferizedValues(v1, v2);
+void OneShotAnalysisState::createAliasInfoEntry(Value v) {
+  aliasInfo.insert(v);
+  equivalentInfo.insert(v);
 }
 
 // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
 // to ensure that such information is available during bufferization time.
-// Alias information can no longer be queried through BufferizationAliasInfo
-// once we have started modifying the IR.
+// Alias information can no longer be queried once we have started modifying
+// the IR.
 void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
   op->walk([&](Operation *returnOp) {
     if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
@@ -242,7 +192,7 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
 
       // Add all aliases of the returned value. But only the ones that are in
       // the same block.
-      aliasInfo.applyOnAliases(returnVal, [&](Value v) {
+      applyOnAliases(returnVal, [&](Value v) {
         if (auto bbArg = v.dyn_cast<BlockArgument>()) {
           if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
             yieldedTensors.insert(bbArg);
@@ -285,13 +235,17 @@ bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
   return undefinedTensorUses.contains(opOperand);
 }
 
+bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
+  return inplaceBufferized.contains(&opOperand);
+}
+
 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
   return yieldedTensors.contains(tensor);
 }
 
 bool OneShotAnalysisState::isValueWritten(Value value) const {
   bool isWritten = false;
-  aliasInfo.applyOnAliases(value, [&](Value val) {
+  applyOnAliases(value, [&](Value val) {
     for (OpOperand &use : val.getUses())
       if (isInPlace(use) && bufferizesToMemoryWrite(use))
         isWritten = true;
@@ -314,6 +268,14 @@ bool OneShotAnalysisState::isWritable(Value value) const {
   return false;
 }
 
+void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
+  aliasInfo.unionSets(v1, v2);
+}
+
+void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
+  equivalentInfo.unionSets(v1, v2);
+}
+
 OneShotAnalysisState::Extension::~Extension() = default;
 
 //===----------------------------------------------------------------------===//
@@ -322,13 +284,12 @@ OneShotAnalysisState::Extension::~Extension() = default;
 
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
-                                 const BufferizationAliasInfo &aliasInfo,
-                                 const AnalysisState &state) {
+                                 const OneShotAnalysisState &state) {
   // OpOperands that do not bufferize to a memory write do not write in-place.
   if (!state.bufferizesToMemoryWrite(opOperand))
     return false;
   // Check current bufferization decisions.
-  return aliasInfo.isInPlace(opOperand);
+  return state.isInPlace(opOperand);
 }
 
 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
@@ -489,10 +450,11 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
 /// the result of a definition W1. But because of bufferization decisions, R
 /// actually reads another definition W2.
-static bool hasReadAfterWriteInterference(
-    const DenseSet<OpOperand *> &usesRead,
-    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
-    AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
+static bool
+hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
+                              const DenseSet<OpOperand *> &usesWrite,
+                              const DominanceInfo &domInfo,
+                              OneShotAnalysisState &state) {
   const BufferizationOptions &options = state.getOptions();
 
   for (OpOperand *uRead : usesRead) {
@@ -654,21 +616,19 @@ static bool hasReadAfterWriteInterference(
 
 // Helper function to iterate on aliases of `root` and capture the writes.
 static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
-                                     const BufferizationAliasInfo &aliasInfo,
-                                     const AnalysisState &state) {
-  aliasInfo.applyOnAliases(root, [&](Value alias) {
+                                     const OneShotAnalysisState &state) {
+  state.applyOnAliases(root, [&](Value alias) {
     for (auto &use : alias.getUses())
       // Inplace write to a value that aliases root.
-      if (isInplaceMemoryWrite(use, aliasInfo, state))
+      if (isInplaceMemoryWrite(use, state))
         res.insert(&use);
   });
 }
 
 // Helper function to iterate on aliases of `root` and capture the reads.
 static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
-                             const BufferizationAliasInfo &aliasInfo,
-                             const AnalysisState &state) {
-  aliasInfo.applyOnAliases(root, [&](Value alias) {
+                             const OneShotAnalysisState &state) {
+  state.applyOnAliases(root, [&](Value alias) {
     for (auto &use : alias.getUses()) {
       // Read of a value that aliases root.
       if (state.bufferizesToMemoryRead(use)) {
@@ -731,22 +691,20 @@ static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
 /// OpResult. In that case, only the consistency of bufferization decisions
 /// involving aliases of the given OpOperand are checked.
 static bool wouldCreateReadAfterWriteInterference(
-    OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
-    const BufferizationAliasInfo &aliasInfo,
-    bool checkConsistencyOnly = false) {
+    OpOperand &operand, const DominanceInfo &domInfo,
+    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
   // Collect reads and writes of all aliases of OpOperand and OpResult.
   DenseSet<OpOperand *> usesRead, usesWrite;
-  getAliasingReads(usesRead, operand.get(), aliasInfo, state);
-  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+  getAliasingReads(usesRead, operand.get(), state);
+  getAliasingInplaceWrites(usesWrite, operand.get(), state);
   for (OpResult result : state.getAliasingOpResults(operand)) {
-    getAliasingReads(usesRead, result, aliasInfo, state);
-    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+    getAliasingReads(usesRead, result, state);
+    getAliasingInplaceWrites(usesWrite, result, state);
   }
   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
-  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
-                                       aliasInfo);
+  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
 }
 
 /// Annotate IR with details about the detected non-writability conflict.
@@ -773,7 +731,6 @@ static void annotateNonWritableTensor(Value value) {
 /// materialized in `aliasInfo` yet.
 static bool
 hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
-                                      const BufferizationAliasInfo &aliasInfo,
                                       const OneShotAnalysisState &state) {
   SmallVector<Value> worklist;
   worklist.push_back(value);
@@ -794,7 +751,7 @@ hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
     AliasingOpOperandList aliasingOpOperands =
         state.getAliasingOpOperands(opResult);
     for (OpOperand *opOperand : aliasingOpOperands)
-      if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
+      if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
         worklist.push_back(opOperand->get());
   }
   return false;
@@ -802,14 +759,15 @@ hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
 
 /// Return true if bufferizing `operand` inplace would create a write to a
 /// non-writable buffer.
-static bool wouldCreateWriteToNonWritableBuffer(
-    OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
-    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
+static bool
+wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
+                                    OneShotAnalysisState &state,
+                                    bool checkConsistencyOnly = false) {
   // Collect writes of all aliases of OpOperand and OpResult.
   DenseSet<OpOperand *> usesWrite;
-  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+  getAliasingInplaceWrites(usesWrite, operand.get(), state);
   for (OpResult result : state.getAliasingOpResults(operand)) {
-    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+    getAliasingInplaceWrites(usesWrite, result, state);
   }
   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
@@ -818,8 +776,7 @@ static bool wouldCreateWriteToNonWritableBuffer(
   // alias), check if there is a non-writable tensor in the reverse SSA use-def
   // chain.
   for (OpOperand *uWrite : usesWrite) {
-    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
-                                              aliasInfo, state)) {
+    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
       LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
       return true;
     }
@@ -833,22 +790,22 @@ static bool wouldCreateWriteToNonWritableBuffer(
 //===----------------------------------------------------------------------===//
 
 /// Determine if `operand` can be bufferized in-place.
-static LogicalResult bufferizableInPlaceAnalysisImpl(
-    OpOperand &operand, BufferizationAliasInfo &aliasInfo,
-    OneShotAnalysisState &state, const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
+                                const DominanceInfo &domInfo) {
   LLVM_DEBUG(
       llvm::dbgs() << "//===-------------------------------------------===//\n"
                    << "Analyzing operand #" << operand.getOperandNumber()
                    << " of " << *operand.getOwner() << "\n");
 
   bool foundInterference =
-      wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
-      wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
+      wouldCreateWriteToNonWritableBuffer(operand, state) ||
+      wouldCreateReadAfterWriteInterference(operand, domInfo, state);
 
   if (foundInterference)
-    aliasInfo.bufferizeOutOfPlace(operand);
+    state.bufferizeOutOfPlace(operand);
   else
-    aliasInfo.bufferizeInPlace(operand, state);
+    state.bufferizeInPlace(operand);
 
   LLVM_DEBUG(llvm::dbgs()
              << "//===-------------------------------------------===//\n");
@@ -874,7 +831,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
 /// An analysis is required to ensure inplace bufferization would not result in
 /// RaW dependence violations.
 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
-                                     BufferizationAliasInfo &aliasInfo,
                                      OneShotAnalysisState &state,
                                      const DominanceInfo &domInfo,
                                      unsigned analysisFuzzerSeed = 0) {
@@ -890,8 +846,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
   auto analyzeOp = [&](Operation *op) {
     for (OpOperand &opOperand : op->getOpOperands())
       if (opOperand.get().getType().isa<TensorType>())
-        if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, state,
-                                                   domInfo)))
+        if (failed(bufferizableInPlaceAnalysisImpl(opOperand, state, domInfo)))
           return failure();
     return success();
   };
@@ -924,7 +879,6 @@ static bool hasTensorSemantics(Operation *op) {
 
 /// Analyze all ops that are contained in `op`.
 static LogicalResult inPlaceAnalysis(Operation *op,
-                                     BufferizationAliasInfo &aliasInfo,
                                      OneShotAnalysisState &state,
                                      const DominanceInfo &domInfo,
                                      unsigned analysisFuzzerSeed = 0) {
@@ -937,13 +891,12 @@ static LogicalResult inPlaceAnalysis(Operation *op,
     ops.push_back(op);
   });
 
-  return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
+  return inPlaceAnalysis(ops, state, domInfo, analysisFuzzerSeed);
 }
 
 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
-                                BufferizationAliasInfo &aliasInfo,
-                                AnalysisState &state) {
+                                OneShotAnalysisState &state) {
   for (Operation *op : ops)
     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
       for (OpResult opResult : op->getOpResults())
@@ -953,14 +906,12 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
             if (state.isInPlace(*opOperand))
               if (bufferizableOp.bufferRelation(opResult, state) ==
                   BufferRelation::Equivalent)
-                aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
+                state.unionEquivalenceClasses(opResult, opOperand->get());
 }
 
 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
 /// in `op`.
-static void equivalenceAnalysis(Operation *op,
-                                BufferizationAliasInfo &aliasInfo,
-                                AnalysisState &state) {
+static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
   SmallVector<Operation *> ops;
   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -970,14 +921,13 @@ static void equivalenceAnalysis(Operation *op,
     ops.push_back(op);
   });
 
-  equivalenceAnalysis(ops, aliasInfo, state);
+  equivalenceAnalysis(ops, state);
 }
 
 /// Assert that the current bufferization decisions are consistent.
-static LogicalResult
-checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
-                          AnalysisState &state,
-                          const BufferizationAliasInfo &aliasInfo) {
+static LogicalResult checkAliasInfoConsistency(Operation *op,
+                                               const DominanceInfo &domInfo,
+                                               OneShotAnalysisState &state) {
   const BufferizationOptions &options = state.getOptions();
 
   WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
@@ -995,7 +945,7 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
     for (OpOperand &opOperand : op->getOpOperands()) {
       if (opOperand.get().getType().isa<TensorType>()) {
         if (wouldCreateReadAfterWriteInterference(
-                opOperand, domInfo, state, aliasInfo,
+                opOperand, domInfo, state,
                 /*checkConsistencyOnly=*/true)) {
           // This error can happen if certain "mustBufferizeInPlace" interface
           // methods are implemented incorrectly, such that the IR already has
@@ -1015,13 +965,12 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
 /// Annotate the IR with the result of the analysis. For testing/debugging only.
 static void
 annotateOpsWithBufferizationMarkers(Operation *op,
-                                    const BufferizationAliasInfo &aliasInfo,
-                                    const BufferizationOptions &options) {
+                                    const OneShotAnalysisState &state) {
   // Add __inplace_operands_attr__.
   op->walk([&](Operation *op) {
     for (OpOperand &opOperand : op->getOpOperands())
       if (opOperand.get().getType().isa<TensorType>())
-        setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
+        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
   });
 }
 
@@ -1056,12 +1005,12 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 // TODO: Remove buffer deallocation from One-Shot Bufferize and fix the buffer
 // deallocation pass.
 static LogicalResult assertNoAllocsReturned(Operation *op,
-                                            const BufferizationOptions &options,
-                                            BufferizationAliasInfo &aliasInfo) {
+                                            const OneShotAnalysisState &state) {
   LogicalResult status = success();
   DominanceInfo domInfo(op);
   op->walk([&](Operation *returnOp) {
-    if (!isRegionReturnLike(returnOp) || !options.isOpAllowed(returnOp))
+    if (!isRegionReturnLike(returnOp) ||
+        !state.getOptions().isOpAllowed(returnOp))
       return WalkResult::advance();
 
     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
@@ -1071,7 +1020,7 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
         continue;
 
       bool foundEquivValue = false;
-      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
+      state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
         if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
           Operation *definingOp = bbArg.getOwner()->getParentOp();
           if (definingOp->isProperAncestor(returnOp))
@@ -1105,27 +1054,25 @@ LogicalResult bufferization::analyzeOp(Operation *op,
                                        OneShotAnalysisState &state,
                                        BufferizationStatistics *statistics) {
   DominanceInfo domInfo(op);
-  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
   const OneShotBufferizationOptions &options = state.getOptions();
 
-  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
+  if (failed(checkAliasInfoConsistency(op, domInfo, state)))
     return failure();
 
   // If the analysis fails, just return.
-  if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
-                             options.analysisFuzzerSeed)))
+  if (failed(inPlaceAnalysis(op, state, domInfo, options.analysisFuzzerSeed)))
     return failure();
 
   if (statistics) {
-    statistics->numTensorInPlace = aliasInfo.getStatNumTensorInPlace();
-    statistics->numTensorOutOfPlace = aliasInfo.getStatNumTensorOutOfPlace();
+    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
+    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
   }
 
-  equivalenceAnalysis(op, aliasInfo, state);
+  equivalenceAnalysis(op, state);
 
   bool failedAnalysis = false;
   if (!options.allowReturnAllocs)
-    failedAnalysis |= failed(assertNoAllocsReturned(op, options, aliasInfo));
+    failedAnalysis |= failed(assertNoAllocsReturned(op, state));
 
   // Gather some extra analysis data.
   state.gatherYieldedTensors(op);
@@ -1142,7 +1089,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly)
-    annotateOpsWithBufferizationMarkers(op, aliasInfo, options);
+    annotateOpsWithBufferizationMarkers(op, state);
 
   return success(!failedAnalysis);
 }
index 943efe8..9562ac5 100644 (file)
@@ -250,7 +250,6 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 /// analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
 static void equivalenceAnalysis(func::FuncOp funcOp,
-                                BufferizationAliasInfo &aliasInfo,
                                 OneShotAnalysisState &state,
                                 FuncAnalysisState &funcState) {
   funcOp->walk([&](func::CallOp callOp) {
@@ -268,7 +267,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
         continue;
       Value returnVal = callOp.getResult(returnIdx);
       Value argVal = callOp->getOperand(bbargIdx);
-      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
+      state.unionEquivalenceClasses(returnVal, argVal);
     }
 
     return WalkResult::advance();
@@ -365,7 +364,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   assert(state.getOptions().bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
-  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 
   // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;
@@ -385,7 +383,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.startFunctionAnalysis(funcOp);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, aliasInfo, state, funcState);
+    equivalenceAnalysis(funcOp, state, funcState);
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, state, statistics)))