From 6247988e0751422fa10d70e64939c987dd3b81d9 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 8 Sep 2022 09:38:50 +0200 Subject: [PATCH] One-shot-bufferize: fix for inconsistent while arg types in before/after. Currently, if the `before` and `after` regions of a while op have tensor args in different indices, this leads to a crash. This moves the pass-through check for args to the handling of the condition block, since that is where the results are produced, so it's also where copies must be made. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D133477 --- .../BufferizableOpInterfaceImpl.cpp | 33 ++----------------- ...erize-allow-return-allocs-no-deallocs.mlir | 20 +++++++++++ ...-shot-bufferize-tensor-copy-insertion.mlir | 4 +-- mlir/test/Dialect/SCF/one-shot-bufferize.mlir | 20 ++++------- 4 files changed, 31 insertions(+), 46 deletions(-) create mode 100644 mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index fb8c8dd3e2b8..27be21430b17 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -762,13 +762,6 @@ struct WhileOpInterface OpBuilder::InsertionGuard g(rewriter); auto whileOp = cast<scf::WhileOp>(op); auto conditionOp = whileOp.getConditionOp(); - auto yieldOp = whileOp.getYieldOp(); - - // Indices of all bbArgs that have tensor type. These are the ones that - // are bufferized. The "before" and "after" regions may have different args. - DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits()); - DenseSet<int64_t> indicesAfter = - getTensorIndices(whileOp.getAfterArguments()); // For every yielded value, is the value equivalent to its corresponding // bbArg? 
@@ -783,8 +776,9 @@ struct WhileOpInterface for (int64_t idx = 0; idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!indicesBefore.contains(idx) || - equivalentYieldsBefore.contains(idx)) { + if (!value.getType().isa<TensorType>() || + (equivalentYieldsAfter.contains(idx) && + equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); continue; } @@ -799,27 +793,6 @@ struct WhileOpInterface conditionOp.getArgsMutable().assign(beforeYieldValues); }); - // Update "after" region. - rewriter.setInsertionPoint(yieldOp); - SmallVector<Value> afterYieldValues; - for (int64_t idx = 0; - idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) { - Value value = yieldOp.getResults()[idx]; - if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) { - afterYieldValues.push_back(value); - continue; - } - FailureOr<Value> alloc = - allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value, - /*escape=*/true, state.getOptions()); - if (failed(alloc)) - return failure(); - afterYieldValues.push_back(*alloc); - } - rewriter.updateRootInPlace(yieldOp, [&]() { - yieldOp.getResultsMutable().assign(afterYieldValues); - }); - return success(); } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir new file mode 100644 index 000000000000..7e894b775fe0 --- /dev/null +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-allow-return-allocs-no-deallocs.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s \ +// RUN: -one-shot-bufferize="allow-return-allocs create-deallocs=0" \ +// RUN: -split-input-file | \ +// RUN: FileCheck %s --dump-input=always + +// A regression test to check that different before and after argument types are +// bufferized successfully. 
+func.func @different_before_after_args() -> tensor { + %true = arith.constant true + %cst = arith.constant dense<0.0> : tensor + %0 = scf.while (%arg4 = %true) : (i1) -> (tensor) { + scf.condition(%true) %cst : tensor + } do { + ^bb0(%arg4: tensor): + scf.yield %true : i1 + } + return %0 : tensor +} + +// CHECK-LABEL: @different_before_after_args \ No newline at end of file diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir index ec0ffa657d87..a5337831a0a5 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -98,9 +98,7 @@ func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>, ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>): // CHECK: } do { // CHECK: ^bb0(%[[b0:.*]]: tensor<5xi1>, %[[b1:.*]]: tensor<5xi1>): - // CHECK-DAG: %[[yield2:.*]] = bufferization.alloc_tensor() copy(%[[b1]]) {bufferization.escape = [true]} : tensor<5xi1> - // CHECK-DAG: %[[yield3:.*]] = bufferization.alloc_tensor() copy(%[[b0]]) {bufferization.escape = [true]} : tensor<5xi1> - // CHECK: scf.yield %[[yield2]], %[[yield3]] + // CHECK: scf.yield %[[b1]], %[[b0]] // CHECK: } scf.yield %b1, %b0 : tensor<5xi1>, tensor<5xi1> } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir index dab4331a2586..b37999315208 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -430,8 +430,8 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>, %idx: index) -> (tensor<5xi1>, tensor<5xi1>) { - // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]] - // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]] + // CHECK-DAG: %[[clone1:.*]] = bufferization.clone %[[arg1]] + // CHECK-DAG: %[[clone0:.*]] = bufferization.clone %[[arg0]] // CHECK: %[[loop:.*]]:2 = scf.while 
(%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} { %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1) : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) { @@ -454,19 +454,13 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>, // CHECK: } do { // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>): // CHECK: memref.store %{{.*}}, %[[b0]] - // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1> - // CHECK: memref.copy %[[b1]], %[[a3]] + // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] + // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] + // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]] // CHECK: memref.dealloc %[[b1]] - // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1> - // CHECK: memref.copy %[[b0]], %[[a2]] + // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]] // CHECK: memref.dealloc %[[b0]] - // CHECK: %[[casted3:.*]] = memref.cast %[[a3]] - // CHECK: %[[casted2:.*]] = memref.cast %[[a2]] - // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]] - // CHECK: memref.dealloc %[[a2]] - // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted3]] - // CHECK: memref.dealloc %[[a3]] - // CHECK: scf.yield %[[cloned3]], %[[cloned2]] + // CHECK: scf.yield %[[cloned1]], %[[cloned0]] // CHECK: } %pos = "dummy.some_op"() : () -> (index) %val = "dummy.another_op"() : () -> (i1) -- 2.34.1