/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
-DenseSet<int64_t> getEquivalentBuffers(ValueRange bbArgs,
+DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
ValueRange yieldedValues,
const AnalysisState &state) {
DenseSet<int64_t> result;
});
}
+/// Helper function for loop bufferization. Given a list of bbArgs of the new
+/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
+/// ToTensorOps, so that the block body can be moved over to the new op.
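+///
+/// A sketch of the IR this produces (names illustrative): a memref bbArg %m
+/// whose index is listed in `tensorIndices` is replaced by
+///
+///   %t = bufferization.to_tensor %m : memref<5xf32>
+///
+/// while all other bbArgs are passed through unchanged.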
+SmallVector<Value>
+getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
+ const DenseSet<int64_t> &tensorIndices) {
+ return convertTensorValues(
+ bbArgs, tensorIndices, [&](Value val, int64_t index) {
+ return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
+ });
+}
+
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
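+///
+/// E.g. (illustrative sketch, layout details elided):
+///
+///   %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) { ... }
+///
+/// becomes a loop whose iter_args carry the corresponding buffers:
+///
+///   %r = scf.for ... iter_args(%m = %buf) -> (memref<?xf32, #map>) { ... }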
struct ForOpInterface
// Set up new iter_args. The loop body uses tensors, so wrap the (memref)
// iter_args of the new loop in ToTensorOps.
rewriter.setInsertionPointToStart(loopBody);
- SmallVector<Value> iterArgs = convertTensorValues(
- newForOp.getRegionIterArgs(), indices, [&](Value val, int64_t index) {
- return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
- });
+ SmallVector<Value> iterArgs =
+ getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
// Erase terminator if present.
}
};
+/// Bufferization of scf.while. Replace with a new scf.while that operates on
+/// memrefs.
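+///
+/// Sketch (illustrative; layout maps and loop bodies elided):
+///
+///   %r = scf.while (%w = %t) : (tensor<5xi1>) -> tensor<5xi1> { ... }
+///
+/// is replaced by a loop over the tensor's buffer, with ToTensorOps wrapping
+/// the new bbArgs so that the old tensor-based blocks can be moved over:
+///
+///   %r = scf.while (%w = %m)
+///       : (memref<5xi1, #map>) -> memref<5xi1, #map> { ... }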
+struct WhileOpInterface
+ : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
+ scf::WhileOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered a read.
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+    // Tensor iter_args of scf::WhileOps are always considered a write.
+ return true;
+ }
+
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto whileOp = cast<scf::WhileOp>(op);
+ return {whileOp->getResult(opOperand.getOperandNumber())};
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ // WhileOp results are equivalent to their corresponding init_args if the
+ // corresponding iter_args and yield values are equivalent (for both the
+ // "before" and the "after" block).
+ unsigned int resultNumber = opResult.getResultNumber();
+ auto whileOp = cast<scf::WhileOp>(op);
+
+ auto conditionOp = whileOp.getConditionOp();
+ BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
+ Value conditionOperand = conditionOp.getArgs()[resultNumber];
+ bool equivCondition =
+ state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
+
+ auto yieldOp = whileOp.getYieldOp();
+ BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
+ Value yieldOperand = yieldOp.getOperand(resultNumber);
+ bool equivYield =
+ state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
+
+ return equivCondition && equivYield ? BufferRelation::Equivalent
+ : BufferRelation::None;
+ }
+
+ bool isWritable(Operation *op, Value value,
+ const AnalysisState &state) const {
+    // Interestingly, scf::WhileOp's bbArg can **always** be viewed in-place
+    // from the perspective of the ops nested under it:
+ // 1. Either the matching iter operand is not bufferized inplace and an
+ // alloc + optional copy makes the bbArg itself inplaceable.
+ // 2. Or the matching iter operand is bufferized inplace and bbArg just
+ // bufferizes to that too.
+ return true;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ BufferizationState &state) const {
+ auto whileOp = cast<scf::WhileOp>(op);
+
+ assert(whileOp.getBefore().getBlocks().size() == 1 &&
+ "regions with multiple blocks not supported");
+ Block *beforeBody = &whileOp.getBefore().front();
+ assert(whileOp.getAfter().getBlocks().size() == 1 &&
+ "regions with multiple blocks not supported");
+ Block *afterBody = &whileOp.getAfter().front();
+
+ // Indices of all iter_args that have tensor type. These are the ones that
+ // are bufferized.
+ DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
+ // For every yielded value, is the value equivalent to its corresponding
+ // bbArg?
+ DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
+ whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
+ state.getAnalysisState());
+ DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
+ whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
+ state.getAnalysisState());
+
+ // The new memref init_args of the loop.
+ SmallVector<Value> initArgs =
+ getBuffers(rewriter, whileOp->getOpOperands(), state);
+ if (initArgs.size() != indices.size())
+ return failure();
+
+ // Construct a new scf.while op with memref instead of tensor values.
+ ValueRange argsRange(initArgs);
+ TypeRange argsTypes(argsRange);
+ auto newWhileOp =
+ rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
+ // Add before/after regions to the new op.
+ SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
+ Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
+ newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
+ Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
+ newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);
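+    // At this point, the new op is an empty skeleton (sketch):
+    //
+    //   scf.while (%w0 = %m0, ...) : (memref<...>, ...) -> (memref<...>, ...) {
+    //   ^bb0(%b0: memref<...>, ...):  // "before" block, populated below
+    //   } do {
+    //   ^bb0(%a0: memref<...>, ...):  // "after" block, populated below
+    //   }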
+
+ // Set up new iter_args and move the loop condition block to the new op.
+ // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+ // in ToTensorOps.
+ rewriter.setInsertionPointToStart(newBeforeBody);
+ SmallVector<Value> newBeforeArgs = getBbArgReplacements(
+ rewriter, newWhileOp.getBeforeArguments(), indices);
+ rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
+
+ // Update scf.condition of new loop.
+ auto newConditionOp = newWhileOp.getConditionOp();
+ rewriter.setInsertionPoint(newConditionOp);
+ SmallVector<Value> newConditionArgs =
+ getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
+ equivalentYieldsBefore, state);
+ newConditionOp.getArgsMutable().assign(newConditionArgs);
+
+ // Set up new iter_args and move the loop body block to the new op.
+ // The old block uses tensors, so wrap the (memref) bbArgs of the new block
+ // in ToTensorOps.
+ rewriter.setInsertionPointToStart(newAfterBody);
+ SmallVector<Value> newAfterArgs =
+ getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
+ rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
+
+ // Update scf.yield of the new loop.
+ auto newYieldOp = newWhileOp.getYieldOp();
+ rewriter.setInsertionPoint(newYieldOp);
+ SmallVector<Value> newYieldValues =
+ getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
+ equivalentYieldsAfter, state);
+ newYieldOp.getResultsMutable().assign(newYieldValues);
+
+ // Replace loop results.
+ replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
+
+ return success();
+ }
+
+  /// Check that the yielded values of an scf.while op are equivalent to
+  /// their corresponding bbArgs. In that case, the buffer relation of each
+  /// corresponding OpResult is "Equivalent".
+  ///
+  /// If this is not the case, allocs + copies are inserted and yielded from
+  /// the loop. This could be a performance problem, so it must be explicitly
+  /// activated with `allow-return-allocs`.
+  ///
+  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
+  /// equivalence condition must be checked for both.
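+  ///
+  /// For example, yielding the values swapped (as in the test case
+  /// `scf.condition(%condition) %w1, %w0`) makes them non-equivalent and,
+  /// without `allow-return-allocs`, produces an error such as
+  /// "Condition arg #0 is not equivalent to the corresponding iter bbArg".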
+ LogicalResult verifyAnalysis(Operation *op,
+ const AnalysisState &state) const {
+ auto whileOp = cast<scf::WhileOp>(op);
+ const auto &options =
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+ if (options.allowReturnAllocs)
+ return success();
+
+ auto conditionOp = whileOp.getConditionOp();
+ for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+ if (!it.value().getType().isa<TensorType>())
+ continue;
+ if (!state.areEquivalentBufferizedValues(
+ it.value(), conditionOp->getBlock()->getArgument(it.index())))
+ return conditionOp->emitError()
+ << "Condition arg #" << it.index()
+ << " is not equivalent to the corresponding iter bbArg";
+ }
+
+ auto yieldOp = whileOp.getYieldOp();
+ for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+ if (!it.value().getType().isa<TensorType>())
+ continue;
+ if (!state.areEquivalentBufferizedValues(
+ it.value(), yieldOp->getBlock()->getArgument(it.index())))
+ return yieldOp->emitError()
+ << "Yield operand #" << it.index()
+ << " is not equivalent to the corresponding iter bbArg";
+ }
+
+ return success();
+ }
+};
+
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
- if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
+ if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
yieldOp->getParentOp()))
return yieldOp->emitError("unsupported scf::YieldOp parent");
return success();
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
+ WhileOp::attachInterface<WhileOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
});
}
// -----
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+ %arg1: tensor<5xi1>,
+ %idx: index) -> (i1, i1)
+{
+ %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+ : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+ %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+ // expected-error @+1 {{Condition arg #0 is not equivalent to the corresponding iter bbArg}}
+ scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+ } do {
+ ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+ %pos = "dummy.some_op"() : () -> (index)
+ %val = "dummy.another_op"() : () -> (i1)
+ %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+ scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+ }
+
+ %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+ %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+ return %v0, %v1 : i1, i1
+}
+
+// -----
+
+func.func @scf_while_non_equiv_yield(%arg0: tensor<5xi1>,
+ %arg1: tensor<5xi1>,
+ %idx: index) -> (i1, i1)
+{
+ %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+ : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+ %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+ scf.condition(%condition) %w0, %w1 : tensor<5xi1>, tensor<5xi1>
+ } do {
+ ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+ %pos = "dummy.some_op"() : () -> (index)
+ %val = "dummy.another_op"() : () -> (i1)
+ %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+ // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+ scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+ }
+
+ %v0 = tensor.extract %r0[%idx] : tensor<5xi1>
+ %v1 = tensor.extract %r1[%idx] : tensor<5xi1>
+ return %v0, %v1 : i1, i1
+}
+
+// -----
+
func.func private @fun_with_side_effects(%A: tensor<?xf32> {bufferization.writable = true})
func.func @foo(%A: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>) {
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file | FileCheck %s
// Run fuzzer with different seeds.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null
// Test bufferization using memref types that have no layout map.
-// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs fully-dynamic-layout-maps=0 bufferize-function-boundaries" -split-input-file -o /dev/null
// CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK: return %[[r0]], %[[r1]]
return %f0, %f1: f32, f32
}
+
+// -----
+
+// CHECK-LABEL: func @scf_while(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xi1, #{{.*}}>
+func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
+ // CHECK: scf.while : () -> () {
+ %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+ // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
+ // CHECK: scf.condition(%[[condition]])
+ %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
+ scf.condition(%condition) %arg1 : tensor<?xi1>
+ } do {
+ ^bb0(%arg2: tensor<?xi1>):
+ // CHECK: } do {
+ // CHECK: memref.store %{{.*}}, %[[arg0]]
+ // CHECK: scf.yield
+ // CHECK: }
+ %pos = "dummy.some_op"() : () -> (index)
+ %val = "dummy.another_op"() : () -> (i1)
+ %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
+ scf.yield %1 : tensor<?xi1>
+ }
+
+ // CHECK: return
+ return %res : tensor<?xi1>
+}
+
+// -----
+
+// The loop condition yields non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
+ %arg1: tensor<5xi1>,
+ %idx: index)
+ -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+ // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+ %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+ : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+ // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+ // CHECK: memref.copy %[[w1]], %[[a1]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+ // CHECK: memref.copy %[[w0]], %[[a0]]
+ // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+ // CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]]
+ %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+ scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+ } do {
+ ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+ // CHECK: } do {
+  // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}>):
+ // CHECK: memref.store %{{.*}}, %[[b0]]
+ // CHECK: scf.yield %[[b0]], %[[b1]]
+ // CHECK: }
+ %pos = "dummy.some_op"() : () -> (index)
+ %val = "dummy.another_op"() : () -> (i1)
+ %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+ scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
+ }
+
+ // CHECK: return %[[loop]]#0, %[[loop]]#1
+ return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}
+
+// -----
+
+// Both the loop condition and the loop body yield non-equivalent buffers.
+
+// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
+func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
+ %arg1: tensor<5xi1>,
+ %idx: index)
+ -> (tensor<5xi1>, tensor<5xi1>)
+{
+  // These allocations used to be inside the scf.while loop, but they were
+  // hoisted.
+ // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
+ // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
+ %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
+ : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
+ // CHECK: %[[condition:.*]] = memref.load %[[w0]]
+ // CHECK: memref.copy %[[w1]], %[[a3]]
+ // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
+ // CHECK: memref.copy %[[w0]], %[[a2]]
+ // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
+ // CHECK: scf.condition(%[[condition]]) %[[casted3]], %[[casted2]]
+ %condition = tensor.extract %w0[%idx] : tensor<5xi1>
+ scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
+ } do {
+ ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
+ // CHECK: } do {
+  // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}>):
+ // CHECK: memref.store %{{.*}}, %[[b0]]
+ // CHECK: memref.copy %[[b1]], %[[a1]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
+ // CHECK: memref.copy %[[b0]], %[[a0]]
+ // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
+ // CHECK: scf.yield %[[casted1]], %[[casted0]]
+ // CHECK: }
+ %pos = "dummy.some_op"() : () -> (index)
+ %val = "dummy.another_op"() : () -> (i1)
+ %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
+ scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
+ }
+
+ // CHECK-DAG: memref.dealloc %[[a0]]
+ // CHECK-DAG: memref.dealloc %[[a1]]
+ // CHECK: return %[[loop]]#0, %[[loop]]#1
+ return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
+}