From 9235e597a40b423a298ce415eb922462e7f0b765 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 20 Apr 2022 18:43:49 +0900 Subject: [PATCH] [mlir][bufferize] Fix missing copies when writing to a buffer in a loop Writes into tensors that are defined outside of a repetitive region, but with the write happening inside of the repetitive region were previously not considered conflicts. This was incorrect. E.g.: ``` %0 = ... : tensor<?xf32> scf.for ... { "reading_op"(%0) : tensor<?xf32> %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> ... } ``` In the above example, "writing_op" should be out-of-place. This commit fixes the bufferization for any op that declares its repetitive semantics via RegionBranchOpInterface. --- .../Bufferization/Transforms/OneShotAnalysis.cpp | 83 ++++++++++++++++++++-- .../comprehensive-module-bufferize-analysis.mlir | 55 ++++++++++++++ 2 files changed, 133 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index d88f36d..8ae5c1c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -325,6 +325,20 @@ static bool happensBefore(Operation *a, Operation *b, return false; } +/// For each given value, find the closest enclosing repetitive region. If this +/// is the same region for each value, return it. Otherwise return None. +/// Note: If there is no enclosing repetitive region, return nullptr. +static Optional<Region *> +getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) { + if (values.empty()) + return None; + Region *r = getEnclosingRepetitiveRegion(values.front()); + for (Value value : values.drop_front()) + if (getEnclosingRepetitiveRegion(value) != r) + return None; + return r; +} + /// Annotate IR with details about the detected RaW conflict. 
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value lastWrite) { @@ -371,6 +385,15 @@ static bool hasReadAfterWriteInterference( AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { const BufferizationOptions &options = state.getOptions(); + // Gather all written aliases. + SmallVector<Value> writtenAliases; + for (OpOperand *uWrite : usesWrite) + writtenAliases.push_back(uWrite->get()); + // Find the inner-most enclosing repetitive region of each alias. If this is + // the same region for every alias, save it in `repetitiveRegionOfWrites`. + Optional<Region *> repetitiveRegionOfWrites = + getCommonEnclosingRepetitiveRegion(writtenAliases); + for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); @@ -393,15 +416,60 @@ static bool hasReadAfterWriteInterference( // met for uConflictingWrite to be an actual conflict. Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + // Check if conflictingWritingOp is in the same repetitive region as all + // written aliases. If this is not the case, there is no meaningful + // `happensBefore` relationship because conflictingWritingOp may be + // executed multiple times. E.g.: + // + // %0 = ... : tensor<?xf32> + // scf.for ... { + // "reading_op"(%0) : tensor<?xf32> + // %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> + // ... + // } + // + // In the above example, reading_op happens before writing_op according to + // op dominance. However, both ops may happen multiple times; in + // particular, the second execution of reading_op happens after the first + // execution of writing_op. This is problematic if the tensor they operate + // on (%0) is defined outside of the loop. + // + // Counter example: + // + // scf.for ... { + // %0 = ... : tensor<?xf32> + // "reading_op"(%0) : tensor<?xf32> + // %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> + // ... 
+ // } + // + // In this example, %0 is in the same repetitive region as + // conflictingWritingOp, so op dominance can be used to compute the + // `happensBefore` relationship. + // + // Note: iter_args of loops are not aliases of their respective block + // arguments, so op dominance can be used when analyzing ops that operate + // on them. + bool canUseOpDominance = + repetitiveRegionOfWrites == + getEnclosingRepetitiveRegion(conflictingWritingOp); + // No conflict if the readingOp dominates conflictingWritingOp, i.e., the // write is not visible when reading. - if (happensBefore(readingOp, conflictingWritingOp, domInfo)) + // + // Note: If ops are executed multiple times (e.g., because they are inside + // a loop), there may be no meaningful `happensBefore` relationship. + if (canUseOpDominance && + happensBefore(readingOp, conflictingWritingOp, domInfo)) continue; // No conflict if the reading use equals the use of the conflicting write. - // A use cannot conflict with itself. Note: Just being the same op is not - // enough. It has to be the same use. - if (uConflictingWrite == uRead) + // A use cannot conflict with itself. + // + // Note: Just being the same op is not enough. It has to be the same use. + // Note: If the op is executed multiple times (e.g., because it is inside + // a loop), it may be conflicting with itself. + if (canUseOpDominance && uConflictingWrite == uRead) continue; // No conflict if the op interface says so. @@ -416,7 +484,12 @@ static bool hasReadAfterWriteInterference( continue; // Ops are not conflicting if they are in mutually exclusive regions. - if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) + // + // Note: If ops are executed multiple times (e.g., because they are inside + // a loop), mutually exclusive regions may be executed multiple + // times. + if (canUseOpDominance && + insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) continue; // Check all possible last writes. 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir index b12eee1..ea088e9 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -1786,3 +1786,58 @@ func @write_after_select_no_conflict( return %f, %w : f32, tensor<?xf32> } + +// ----- + +// CHECK-LABEL: func @write_to_same_tensor_in_loop_out_of_place( +func @write_to_same_tensor_in_loop_out_of_place( + %A : tensor<?xf32> {linalg.inplaceable = true}, + %B : tensor<?xf32> {linalg.inplaceable = true}, + %lb : index, %ub : index, %step : index, %sz: index) + -> (tensor<?xf32>) +{ + // CHECK: scf.for {{.*}} { + %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) { + %i2 = arith.index_cast %i : index to i32 + %i3 = arith.sitofp %i2 : i32 to f32 + // The tensor.insert is out-of-place because the %B is written multiple + // times inside a loop. 
+ // CHECK: tensor.insert + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]} + %B2 = tensor.insert %i3 into %B[%i] : tensor<?xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]} + %A2 = tensor.insert_slice %B2 into %t[%i][%sz][1] : tensor<?xf32> into tensor<?xf32> + scf.yield %A2 : tensor<?xf32> + } + // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]} + + return %r0 : tensor<?xf32> +} + +// ----- + +// CHECK-LABEL: func @write_to_same_tensor_in_loop_in_place( +func @write_to_same_tensor_in_loop_in_place( + %A : tensor<?xf32> {linalg.inplaceable = true}, + %lb : index, %ub : index, %step : index, %sz: index) + -> (tensor<?xf32>) +{ + // CHECK: scf.for {{.*}} { + %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) { + %B = linalg.init_tensor [%sz] : tensor<?xf32> + %i2 = arith.index_cast %i : index to i32 + %i3 = arith.sitofp %i2 : i32 to f32 + // The tensor.insert is in-place because the %B is defined inside the loop. + // CHECK: tensor.insert + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]} + %B2 = tensor.insert %i3 into %B[%i] : tensor<?xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]} + %A2 = tensor.insert_slice %B2 into %t[%i][%sz][1] : tensor<?xf32> into tensor<?xf32> + scf.yield %A2 : tensor<?xf32> + } + // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]} + + return %r0 : tensor<?xf32> +} -- 2.7.4