From: Matthias Springer Date: Fri, 6 May 2022 08:22:44 +0000 (+0900) Subject: [mlir][scf] Implement BufferizableOpInterface for scf::WhileOp X-Git-Tag: upstream/15.0.7~8472 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a5d09c637261252393a015e7858efd85c9166e32;p=platform%2Fupstream%2Fllvm.git [mlir][scf] Implement BufferizableOpInterface for scf::WhileOp This follows the same implementation strategy as scf::ForOp and common functionality is extracted into helper functions. This implementation works well in cases where each yielded value (from either body/condition region) is equivalent to the corresponding bbArg of the parent block. In that case, each OpResult of the loop may be aliasing with the corresponding OpOperand of the loop (and with no other OpOperand). In the absence of said equivalence relationship, new buffer copies must be inserted, so that the aliasing OpOperand/OpResult contract of scf::WhileOp is honored. In essence, by yielding a newly allocated buffer, we can enforce the specified may-alias relationship. (Newly allocated buffers cannot alias with any OpOperands of the loop.) Differential Revision: https://reviews.llvm.org/D124929 --- diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index b6d2001..39af7d3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -271,7 +271,7 @@ static DenseSet getTensorIndices(ValueRange values) { /// Helper function for loop bufferization. Return the indices of all /// bbArg/yielded value pairs who's buffer relation is "Equivalent". -DenseSet getEquivalentBuffers(ValueRange bbArgs, +DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, ValueRange yieldedValues, const AnalysisState &state) { DenseSet result; @@ -403,6 +403,18 @@ SmallVector getYieldedValues(RewriterBase &rewriter, ValueRange values, }); } +/// Helper function for loop bufferization. Given a list of bbArgs of the new +/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into +/// ToTensorOps, so that the block body can be moved over to the new op. +SmallVector +getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, + const DenseSet &tensorIndices) { + return convertTensorValues( + bbArgs, tensorIndices, [&](Value val, int64_t index) { + return rewriter.create(val.getLoc(), val); + }); +} + /// Bufferization of scf.for. Replace with a new scf.for that operates on /// memrefs. struct ForOpInterface @@ -486,10 +498,8 @@ struct ForOpInterface // Set up new iter_args. The loop body uses tensors, so wrap the (memref) // iter_args of the new loop in ToTensorOps. rewriter.setInsertionPointToStart(loopBody); - SmallVector iterArgs = convertTensorValues( - newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) { - return rewriter.create(val.getLoc(), val); - }); + SmallVector iterArgs = + getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices); iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar()); // Erase terminator if present. @@ -546,6 +556,187 @@ struct ForOpInterface } }; +/// Bufferization of scf.while. Replace with a new scf.while that operates on +/// memrefs. +struct WhileOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Tensor iter_args of scf::WhileOps are always considered as a read. + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Tensor iter_args of scf::WhileOps are always considered as a write. + return true; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + auto whileOp = cast(op); + return {whileOp->getResult(opOperand.getOperandNumber())}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + // WhileOp results are equivalent to their corresponding init_args if the + // corresponding iter_args and yield values are equivalent (for both the + // "before" and the "after" block). + unsigned int resultNumber = opResult.getResultNumber(); + auto whileOp = cast(op); + + auto conditionOp = whileOp.getConditionOp(); + BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber]; + Value conditionOperand = conditionOp.getArgs()[resultNumber]; + bool equivCondition = + state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand); + + auto yieldOp = whileOp.getYieldOp(); + BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber]; + Value yieldOperand = yieldOp.getOperand(resultNumber); + bool equivYield = + state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand); + + return equivCondition && equivYield ? BufferRelation::Equivalent + : BufferRelation::None; + } + + bool isWritable(Operation *op, Value value, + const AnalysisState &state) const { + // Interestingly, scf::WhileOp's bbArg can **always** be viewed + // inplace from the perspective of ops nested under: + // 1. Either the matching iter operand is not bufferized inplace and an + // alloc + optional copy makes the bbArg itself inplaceable. + // 2. Or the matching iter operand is bufferized inplace and bbArg just + // bufferizes to that too. + return true; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + BufferizationState &state) const { + auto whileOp = cast(op); + + assert(whileOp.getBefore().getBlocks().size() == 1 && + "regions with multiple blocks not supported"); + Block *beforeBody = &whileOp.getBefore().front(); + assert(whileOp.getAfter().getBlocks().size() == 1 && + "regions with multiple blocks not supported"); + Block *afterBody = &whileOp.getAfter().front(); + + // Indices of all iter_args that have tensor type. These are the ones that + // are bufferized. + DenseSet indices = getTensorIndices(whileOp.getInits()); + // For every yielded value, is the value equivalent to its corresponding + // bbArg? + DenseSet equivalentYieldsBefore = getEquivalentBuffers( + whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(), + state.getAnalysisState()); + DenseSet equivalentYieldsAfter = getEquivalentBuffers( + whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), + state.getAnalysisState()); + + // The new memref init_args of the loop. + SmallVector initArgs = + getBuffers(rewriter, whileOp->getOpOperands(), state); + if (initArgs.size() != indices.size()) + return failure(); + + // Construct a new scf.while op with memref instead of tensor values. + ValueRange argsRange(initArgs); + TypeRange argsTypes(argsRange); + auto newWhileOp = + rewriter.create(whileOp.getLoc(), argsTypes, initArgs); + // Add before/after regions to the new op. + SmallVector bbArgLocs(initArgs.size(), whileOp.getLoc()); + Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock(); + newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs); + Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock(); + newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs); + + // Set up new iter_args and move the loop condition block to the new op. + // The old block uses tensors, so wrap the (memref) bbArgs of the new block + // in ToTensorOps. + rewriter.setInsertionPointToStart(newBeforeBody); + SmallVector newBeforeArgs = getBbArgReplacements( + rewriter, newWhileOp.getBeforeArguments(), indices); + rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs); + + // Update scf.condition of new loop. + auto newConditionOp = newWhileOp.getConditionOp(); + rewriter.setInsertionPoint(newConditionOp); + SmallVector newConditionArgs = + getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices, + equivalentYieldsBefore, state); + newConditionOp.getArgsMutable().assign(newConditionArgs); + + // Set up new iter_args and move the loop body block to the new op. + // The old block uses tensors, so wrap the (memref) bbArgs of the new block + // in ToTensorOps. + rewriter.setInsertionPointToStart(newAfterBody); + SmallVector newAfterArgs = + getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices); + rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs); + + // Update scf.yield of the new loop. + auto newYieldOp = newWhileOp.getYieldOp(); + rewriter.setInsertionPoint(newYieldOp); + SmallVector newYieldValues = + getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices, + equivalentYieldsAfter, state); + newYieldOp.getResultsMutable().assign(newYieldValues); + + // Replace loop results. + replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults()); + + return success(); + } + + /// Assert that yielded values of an scf.while op are equivalent to their + /// corresponding bbArgs. In that case, the buffer relations of the + /// corresponding OpResults are "Equivalent". + /// + /// If this is not the case, allocs+copies are inserted and yielded from + /// the loop. This could be a performance problem, so it must be explicitly + /// activated with `alloc-return-allocs`. + /// + /// Not: In contrast to scf::ForOp, scf::WhileOp has two regions and the + /// equivalence condition must be checked for both. + LogicalResult verifyAnalysis(Operation *op, + const AnalysisState &state) const { + auto whileOp = cast(op); + const auto &options = + static_cast(state.getOptions()); + if (options.allowReturnAllocs) + return success(); + + auto conditionOp = whileOp.getConditionOp(); + for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { + if (!it.value().getType().isa()) + continue; + if (!state.areEquivalentBufferizedValues( + it.value(), conditionOp->getBlock()->getArgument(it.index()))) + return conditionOp->emitError() + << "Condition arg #" << it.index() + << " is not equivalent to the corresponding iter bbArg"; + } + + auto yieldOp = whileOp.getYieldOp(); + for (const auto &it : llvm::enumerate(yieldOp.getResults())) { + if (!it.value().getType().isa()) + continue; + if (!state.areEquivalentBufferizedValues( + it.value(), yieldOp->getBlock()->getArgument(it.index()))) + return yieldOp->emitError() + << "Yield operand #" << it.index() + << " is not equivalent to the corresponding iter bbArg"; + } + + return success(); + } +}; + /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so /// this is for analysis only. struct YieldOpInterface @@ -581,7 +772,7 @@ struct YieldOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto yieldOp = cast(op); - if (!isa( + if (!isa( yieldOp->getParentOp())) return yieldOp->emitError("unsupported scf::YieldOp parent"); return success(); @@ -598,6 +789,7 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels( ExecuteRegionOp::attachInterface(*ctx); ForOp::attachInterface(*ctx); IfOp::attachInterface(*ctx); + WhileOp::attachInterface(*ctx); YieldOp::attachInterface(*ctx); }); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir index d92eaff..8ab2877 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -110,6 +110,54 @@ func.func @scf_for(%A : tensor, // ----- +func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>, + %arg1: tensor<5xi1>, + %idx: index) -> (i1, i1) +{ + %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) + : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { + %condition = tensor.extract %w0[%idx] : tensor<5xi1> + // expected-error @+1 {{Condition arg #0 is not equivalent to the corresponding iter bbArg}} + scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1> + } do { + ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): + %pos = "dummy.some_op"() : () -> (index) + %val = "dummy.another_op"() : () -> (i1) + %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1> + scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1> + } + + %v0 = tensor.extract %r0[%idx] : tensor<5xi1> + %v1 = tensor.extract %r1[%idx] : tensor<5xi1> + return %v0, %v1 : i1, i1 +} + +// ----- + +func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>, + %arg1: tensor<5xi1>, + %idx: index) -> (i1, i1) +{ + %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) + : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { + %condition = tensor.extract %w0[%idx] : tensor<5xi1> + scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1> + } do { + ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): + %pos = "dummy.some_op"() : () -> (index) + %val = "dummy.another_op"() : () -> (i1) + %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1> + // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}} + scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1> + } + + %v0 = tensor.extract %r0[%idx] : tensor<5xi1> + %v1 = tensor.extract %r1[%idx] : tensor<5xi1> + return %v0, %v1 : i1, i1 +} + +// ----- + func.func private @fun_with_side_effects(%A: tensor {bufferization.writable = true}) func.func @foo(%A: tensor {bufferization.writable = true}) -> (tensor) { diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir index 1b6fd99..22b5e41 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -1,12 +1,12 @@ -// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s // Run fuzzer with different seeds. -// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null -// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null -// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null // Test bufferization using memref types that have no layout map. -// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> @@ -328,3 +328,124 @@ func.func @scf_for_swapping_yields( // CHECK: return %[[r0]], %[[r1]] return %f0, %f1: f32, f32 } + +// ----- + +// CHECK-LABEL: func @scf_while( +// CHECK-SAME: %[[arg0:.*]]: memref +func.func @scf_while(%arg0: tensor, %idx: index) -> tensor { + // CHECK: scf.while : () -> () { + %res = scf.while (%arg1 = %arg0) : (tensor) -> tensor { + // CHECK: %[[condition:.*]] = memref.load %[[arg0]] + // CHECK: scf.condition(%[[condition]]) + %condition = tensor.extract %arg1[%idx] : tensor + scf.condition(%condition) %arg1 : tensor + } do { + ^bb0(%arg2: tensor): + // CHECK: } do { + // CHECK: memref.store %{{.*}}, %[[arg0]] + // CHECK: scf.yield + // CHECK: } + %pos = "dummy.some_op"() : () -> (index) + %val = "dummy.another_op"() : () -> (i1) + %1 = tensor.insert %val into %arg2[%pos] : tensor + scf.yield %1 : tensor + } + + // CHECK: return + return %res : tensor +} + +// ----- + +// The loop condition yields non-equivalent buffers. + +// CHECK-LABEL: func @scf_while_non_equiv_condition( +// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}> +func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>, + %arg1: tensor<5xi1>, + %idx: index) + -> (tensor<5xi1>, tensor<5xi1>) +{ + // These allocation used to be inside the scf.while loop, but they were + // hoisted. + // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} { + %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) + : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { + // CHECK: %[[condition:.*]] = memref.load %[[w0]] + // CHECK: memref.copy %[[w1]], %[[a1]] + // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] + // CHECK: memref.copy %[[w0]], %[[a0]] + // CHECK: %[[casted0:.*]] = memref.cast %[[a0]] + // CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]] + %condition = tensor.extract %w0[%idx] : tensor<5xi1> + scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1> + } do { + ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): + // CHECK: } do { + // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}): + // CHECK: memref.store %{{.*}}, %[[b0]] + // CHECK: scf.yield %[[b0]], %[[b1]] + // CHECK: } + %pos = "dummy.some_op"() : () -> (index) + %val = "dummy.another_op"() : () -> (i1) + %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1> + scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1> + } + + // CHECK: return %[[loop]]#0, %[[loop]]#1 + return %r0, %r1 : tensor<5xi1>, tensor<5xi1> +} + +// ----- + +// Both the loop condition and the loop buffer yield non-equivalent buffers. + +// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body( +// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}> +func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>, + %arg1: tensor<5xi1>, + %idx: index) + -> (tensor<5xi1>, tensor<5xi1>) +{ + // These allocation used to be inside the scf.while loop, but they were + // hoisted. + // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1> + // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} { + %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) + : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { + // CHECK: %[[condition:.*]] = memref.load %[[w0]] + // CHECK: memref.copy %[[w1]], %[[a3]] + // CHECK: %[[casted3:.*]] = memref.cast %[[a3]] + // CHECK: memref.copy %[[w0]], %[[a2]] + // CHECK: %[[casted2:.*]] = memref.cast %[[a2]] + // CHECK: scf.condition(%[[condition]]) %[[casted3]], %[[casted2]] + %condition = tensor.extract %w0[%idx] : tensor<5xi1> + scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1> + } do { + ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): + // CHECK: } do { + // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}): + // CHECK: memref.store %{{.*}}, %[[b0]] + // CHECK: memref.copy %[[b1]], %[[a1]] + // CHECK: %[[casted1:.*]] = memref.cast %[[a1]] + // CHECK: memref.copy %[[b0]], %[[a0]] + // CHECK: %[[casted0:.*]] = memref.cast %[[a0]] + // CHECK: scf.yield %[[casted1]], %[[casted0]] + // CHECK: } + %pos = "dummy.some_op"() : () -> (index) + %val = "dummy.another_op"() : () -> (i1) + %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1> + scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1> + } + + // CHECK-DAG: memref.dealloc %[[a0]] + // CHECK-DAG: memref.dealloc %[[a1]] + // CHECK: return %[[loop]]#0, %[[loop]]#1 + return %r0, %r1 : tensor<5xi1>, tensor<5xi1> +}