From 416679615d8349a4cf17e57f7bea1f8111d699e5 Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Fri, 17 Sep 2021 10:06:57 -0700 Subject: [PATCH] [mlir] Linalg hoisting should ignore uses outside the loop Differential Revision: https://reviews.llvm.org/D109859 --- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 2 +- mlir/test/Dialect/Linalg/hoisting.mlir | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 2c487f0..4229887 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -457,7 +457,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); for (auto &use : transferRead.source().getUses()) { - if (!dom.properlyDominates(loop, use.getOwner())) + if (!loop->isAncestor(use.getOwner())) continue; if (use.getOwner() == transferRead.getOperation() || use.getOwner() == transferWrite.getOperation()) diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 9c03c46..959f254 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -39,9 +39,11 @@ func @hoist_vector_transfer_pairs( // CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32> // CHECK: } // CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref +// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref) -> () // CHECK: scf.yield {{.*}} : vector<1xf32> // CHECK: } // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref +// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref) -> () scf.for %i = %lb to %ub step %step { scf.for %j = %lb to %ub step %step { %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref, vector<1xf32> @@ -66,7 +68,9 @@ func @hoist_vector_transfer_pairs( vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref "some_crippling_use"(%memref3) : (memref) -> () } + "unrelated_use"(%memref0) : (memref) -> () } + "unrelated_use"(%memref1) : (memref) -> () return } -- 2.7.4