[mlir][bufferize][NFC] Optimize read-only tensor detection
authorMatthias Springer <springerm@google.com>
Thu, 9 Feb 2023 08:00:52 +0000 (09:00 +0100)
committerMatthias Springer <springerm@google.com>
Thu, 9 Feb 2023 08:07:14 +0000 (09:07 +0100)
Check alias sets instead of traversing the IR.

Differential Revision: https://reviews.llvm.org/D143500

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

index 6420feb..8a7d660 100644 (file)
@@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
   }
 }
 
-/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
-/// non-writable tensor values. Stop searching when an out-of-place bufferized
-/// OpOperand was found (or when the OpOperand was not bufferized yet).
-/// `currentOpOperand` is assumed to be in-place, even if that decision was not
-/// materialized in `aliasInfo` yet.
-static bool
-hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
-                                      const OneShotAnalysisState &state) {
-  SmallVector<Value> worklist;
-  worklist.push_back(value);
-  while (!worklist.empty()) {
-    Value nextVal = worklist.pop_back_val();
-    if (!state.isWritable(nextVal)) {
-      if (state.getOptions().printConflicts)
-        annotateNonWritableTensor(nextVal);
-      return true;
-    }
-
-    // If `nextVal` is not a BlockArgument: End of use-def chain reached.
-    auto opResult = nextVal.dyn_cast<OpResult>();
-    if (!opResult)
-      continue;
-
-    // Follow reverse SSA use-def chain.
-    AliasingOpOperandList aliasingOpOperands =
-        state.getAliasingOpOperands(opResult);
-    for (OpOperand *opOperand : aliasingOpOperands)
-      if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
-        worklist.push_back(opOperand->get());
-  }
-  return false;
-}
-
 /// Return true if bufferizing `operand` inplace would create a write to a
 /// non-writable buffer.
 static bool
 wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                     OneShotAnalysisState &state,
                                     bool checkConsistencyOnly = false) {
-  // Collect writes of all aliases of OpOperand and OpResult.
-  DenseSet<OpOperand *> usesWrite;
-  getAliasingInplaceWrites(usesWrite, operand.get(), state);
-  for (OpResult result : state.getAliasingOpResults(operand)) {
-    getAliasingInplaceWrites(usesWrite, result, state);
+  bool foundWrite =
+      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
+
+  if (!foundWrite) {
+    // Collect writes of all aliases of OpOperand and OpResult.
+    DenseSet<OpOperand *> usesWrite;
+    getAliasingInplaceWrites(usesWrite, operand.get(), state);
+    for (OpResult result : state.getAliasingOpResults(operand))
+      getAliasingInplaceWrites(usesWrite, result, state);
+    foundWrite = !usesWrite.empty();
   }
-  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
-    usesWrite.insert(&operand);
 
-  // Assuming that `operand` bufferizes in-place: For each write (to each
-  // alias), check if there is a non-writable tensor in the reverse SSA use-def
-  // chain.
-  for (OpOperand *uWrite : usesWrite) {
-    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
-      LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
-      return true;
+  if (!foundWrite)
+    return false;
+
+  // Look for a read-only tensor among all aliases.
+  bool foundReadOnly = false;
+  auto checkReadOnly = [&](Value v) {
+    if (!state.isWritable(v)) {
+      foundReadOnly = true;
+      if (state.getOptions().printConflicts)
+        annotateNonWritableTensor(v);
     }
+  };
+  state.applyOnAliases(operand.get(), checkReadOnly);
+  for (OpResult result : state.getAliasingOpResults(operand))
+    state.applyOnAliases(result, checkReadOnly);
+  if (foundReadOnly) {
+    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+    return true;
   }
 
   return false;