Currently, if the `before` and `after` regions of a while op have
tensor args at different indices, bufferization crashes.
This moves the pass-through check for args into the handling of the
condition block: that is where the results are produced, so it is
also where copies must be made.
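For illustration, a minimal IR shape of the kind exercised by the new
regression test below (a hypothetical reduction; only the position of the
tensor arg differs between the two regions) used to hit the crash:

  func.func @repro() -> tensor<f32> {
    %true = arith.constant true
    %cst = arith.constant dense<0.0> : tensor<f32>
    // The "before" region carries only an i1 bbArg, but the condition
    // yields a tensor, so the "after" region's bbArg 0 is a tensor while
    // the "before" region's bbArg 0 is not.
    %0 = scf.while (%arg0 = %true) : (i1) -> (tensor<f32>) {
      scf.condition(%true) %cst : tensor<f32>
    } do {
    ^bb0(%arg0: tensor<f32>):
      scf.yield %true : i1
    }
    return %0 : tensor<f32>
  }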
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D133477
OpBuilder::InsertionGuard g(rewriter);
auto whileOp = cast<scf::WhileOp>(op);
auto conditionOp = whileOp.getConditionOp();
- auto yieldOp = whileOp.getYieldOp();
-
- // Indices of all bbArgs that have tensor type. These are the ones that
- // are bufferized. The "before" and "after" regions may have different args.
- DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
- DenseSet<int64_t> indicesAfter =
- getTensorIndices(whileOp.getAfterArguments());
// For every yielded value, is the value equivalent to its corresponding
// bbArg?
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!indicesBefore.contains(idx) ||
- equivalentYieldsBefore.contains(idx)) {
+ if (!value.getType().isa<TensorType>() ||
+ (equivalentYieldsAfter.contains(idx) &&
+ equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
continue;
}
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });
- // Update "after" region.
- rewriter.setInsertionPoint(yieldOp);
- SmallVector<Value> afterYieldValues;
- for (int64_t idx = 0;
- idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
- Value value = yieldOp.getResults()[idx];
- if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
- afterYieldValues.push_back(value);
- continue;
- }
- FailureOr<Value> alloc =
- allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
- /*escape=*/true, state.getOptions());
- if (failed(alloc))
- return failure();
- afterYieldValues.push_back(*alloc);
- }
- rewriter.updateRootInPlace(yieldOp, [&]() {
- yieldOp.getResultsMutable().assign(afterYieldValues);
- });
-
return success();
}
--- /dev/null
+// RUN: mlir-opt %s \
+// RUN: -one-shot-bufferize="allow-return-allocs create-deallocs=0" \
+// RUN: -split-input-file | \
+// RUN: FileCheck %s --dump-input=always
+
+// A regression test checking that a while op whose "before" and "after"
+// regions have different argument types bufferizes successfully.
+func.func @different_before_after_args() -> tensor<f32> {
+ %true = arith.constant true
+ %cst = arith.constant dense<0.0> : tensor<f32>
+ %0 = scf.while (%arg4 = %true) : (i1) -> (tensor<f32>) {
+ scf.condition(%true) %cst : tensor<f32>
+ } do {
+ ^bb0(%arg4: tensor<f32>):
+ scf.yield %true : i1
+ }
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @different_before_after_args
\ No newline at end of file
^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
// CHECK: } do {
// CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>):
- // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {bufferization.escape = [true]} : tensor<5xi1>
- // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {bufferization.escape = [true]} : tensor<5xi1>
- // CHECK: scf.yield %[[yield2]], %[[yield3]]
+ // CHECK: scf.yield %[[b1]], %[[b0]]
// CHECK: }
scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1>
}
%idx: index)
-> (tensor<5xi1>, tensor<5xi1>)
{
- // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]]
- // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]]
+ // CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[arg1]]
+ // CHECK-DAG: %[[clone0:.*]] = bufferization.clone %[[arg0]]
// CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} {
%r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
: (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
// CHECK: } do {
// CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
// CHECK: memref.store %{{.*}}, %[[b0]]
- // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
- // CHECK: memref.copy %[[b1]], %[[a3]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[b1]]
+ // CHECK: %[[casted0:.*]] = memref.cast %[[b0]]
+ // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[b1]]
- // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
- // CHECK: memref.copy %[[b0]], %[[a2]]
+ // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[b0]]
- // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
- // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
- // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
- // CHECK: memref.dealloc %[[a2]]
- // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted3]]
- // CHECK: memref.dealloc %[[a3]]
- // CHECK: scf.yield %[[cloned3]], %[[cloned2]]
+ // CHECK: scf.yield %[[cloned1]], %[[cloned0]]
// CHECK: }
%pos = "dummy.some_op"() : () -> (index)
%val = "dummy.another_op"() : () -> (i1)