From c3728d28821e212bd3658261e58e744421668720 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Sat, 21 Jan 2023 01:50:40 -0500 Subject: [PATCH] [mlir] support !shape.value_shape when replace WithOp in OutlineShapeComputationPass. Fixes #60069 https://github.com/llvm/llvm-project/issues/60069 In case like: %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape %2 = shape.value_of %1 : tensor cannot replace %2 with %arg1. Transform it into %2 = shape.value_of %arg1 : tensor Differential Revision: https://reviews.llvm.org/D142275 --- .../Shape/Transforms/OutlineShapeComputation.cpp | 21 ++++++++++++++++++--- mlir/test/Dialect/Shape/arg_with_shape.mlir | 16 ++++++++++++++++ .../Dialect/Shape/outline-shape-computation.mlir | 10 ++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/Shape/arg_with_shape.mlir diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp index 372ec80..f23a090 100644 --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -232,9 +232,24 @@ void OutlineShapeComputationPass::runOnOperation() { for (shape::WithOp withOp : allWithOps) { Value value = withOp.getOperand(); - for (Operation *user : withOp.getResult().getUsers()) { - if (Value valueOf = llvm::dyn_cast(user)) - valueOf.replaceAllUsesExcept(value, withOp); + for (Operation *user : + llvm::make_early_inc_range(withOp.getResult().getUsers())) { + if (auto valueOf = llvm::dyn_cast(user)) { + // For pattern like + // %1 = shape.with_shape %arg1, %0 + // %2 = shape.value_of %1 + // because shape.value doesn't care the shape, the shape.with_shape is + // redundant. + // If type of %arg1 and %2 has same type, just + // replaced %2 with %arg1. + // If type of %arg1 has different type like !shape.value_shape, + // transform into + // %2 = shape.value_of %arg1 + if (valueOf.getType() == value.getType()) + valueOf.replaceAllUsesWith(value); + else + valueOf.setOperand(value); + } } } diff --git a/mlir/test/Dialect/Shape/arg_with_shape.mlir b/mlir/test/Dialect/Shape/arg_with_shape.mlir new file mode 100644 index 0000000..089c503 --- /dev/null +++ b/mlir/test/Dialect/Shape/arg_with_shape.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -outline-shape-computation -split-input-file %s 2>%t | FileCheck %s + +func.func @func1(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %1 = shape.shape_of %arg1 : !shape.value_shape -> !shape.shape + %2 = shape.meet %0, %1 : !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} +// Make sure with_shape used by call not crash. +// CHECK-LABEL:func.func @func +func.func @func(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape + %2 = call @func1(%arg0, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape + return %2 : !shape.shape +} diff --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir index 73f6e3a..4aa15a9 100644 --- a/mlir/test/Dialect/Shape/outline-shape-computation.mlir +++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir @@ -207,3 +207,13 @@ func.func @multiple_reused(%arg0: tensor, %arg1: tensor) -> (t // CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index // CHECK-DAG: return %[[V5]] : !shape.shape +// Make sure redundant with_shape is removed when with_shape input is !shape.value_shape. +func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor { + %1 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %2 = shape.with_shape %arg1, %1 : !shape.value_shape, !shape.shape + %3 = shape.value_of %2 : tensor + return %3 : tensor +} +// CHECK-LABEL:func.func @value_shape_with_shape +// CHECK-NEXT:%0 = shape.value_of %arg1 : tensor +// CHECK-NEXT:return %0 : tensor -- 2.7.4