buffer-deallocation: consider aliases introduced by arith.select.
authorJohannes Reifferscheid <jreiffers@google.com>
Tue, 23 Aug 2022 12:21:40 +0000 (14:21 +0200)
committerJohannes Reifferscheid <jreiffers@google.com>
Tue, 23 Aug 2022 12:37:02 +0000 (14:37 +0200)
Currently, buffer deallocation considers arith.select to be
non-aliasing, which results in deallocs being inserted incorrectly. Since
arith.select doesn't implement any useful interfaces, this change just handles
it explicitly. Eventually this should probably be fixed properly, if this pass
is going to be used long term.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D132460

mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
mlir/lib/Analysis/CMakeLists.txt
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 5b2b31d..80f538d 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/BufferViewFlowAnalysis.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
@@ -51,9 +52,9 @@ void BufferViewFlowAnalysis::remove(const SmallPtrSetImpl<Value> &aliasValues) {
 /// successor regions and branch-like return operations from nested regions.
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
-  auto registerDependencies = [&](auto values, auto dependencies) {
-    for (auto entry : llvm::zip(values, dependencies))
-      this->dependencies[std::get<0>(entry)].insert(std::get<1>(entry));
+  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
+    for (auto [value, dep] : llvm::zip(values, dependencies))
+      this->dependencies[value].insert(dep);
   };
 
   // Add additional dependencies created by view changes to the alias list.
@@ -119,4 +120,10 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       }
     }
   });
+
+  // TODO: This should be an interface.
+  op->walk([&](arith::SelectOp selectOp) {
+    registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()});
+    registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()});
+  });
 }
index 4ead4c7..701584c 100644 (file)
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
   mlir-headers
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
index 240cc2a..61493d9 100644 (file)
@@ -1298,3 +1298,19 @@ func.func @while_three_arg(%arg0: index) {
 // CHECK-NEXT: return
   return
 }
+
+// -----
+
+func.func @select_aliases(%arg0: index, %arg1: memref<?xi8>, %arg2: i1) {
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: arith.select
+  // CHECK: test.copy
+  // CHECK: memref.dealloc
+  // CHECK: memref.dealloc
+  %0 = memref.alloc(%arg0) : memref<?xi8>
+  %1 = memref.alloc(%arg0) : memref<?xi8>
+  %2 = arith.select %arg2, %0, %1 : memref<?xi8>
+  test.copy(%2, %arg1) : (memref<?xi8>, memref<?xi8>)
+  return
+}
index 7d45b2e..cefa2a9 100644 (file)
@@ -5879,6 +5879,7 @@ cc_library(
     ),
     includes = ["include"],
     deps = [
+        ":ArithmeticDialect",
         ":CallOpInterfaces",
         ":ControlFlowInterfaces",
         ":DataLayoutInterfaces",