From: Butygin Date: Fri, 12 Mar 2021 14:39:43 +0000 (+0300) Subject: [mlir] Additional folding for SelectOp X-Git-Tag: llvmorg-14-init~11810 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7219b31d40f14604c669d633b014d0cc8b707cf3;p=platform%2Fupstream%2Fllvm.git [mlir] Additional folding for SelectOp * Fold SelectOp when both true and false args are same SSA value * Fold some cmp + select patterns Differential Revision: https://reviews.llvm.org/D98576 --- diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index bd38e15..4830a51 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1360,15 +1360,38 @@ static LogicalResult verify(ReturnOp op) { //===----------------------------------------------------------------------===// OpFoldResult SelectOp::fold(ArrayRef operands) { + auto trueVal = getTrueValue(); + auto falseVal = getFalseValue(); + if (trueVal == falseVal) + return trueVal; + auto condition = getCondition(); // select true, %0, %1 => %0 if (matchPattern(condition, m_One())) - return getTrueValue(); + return trueVal; // select false, %0, %1 => %1 if (matchPattern(condition, m_Zero())) - return getFalseValue(); + return falseVal; + + if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { + auto pred = cmp.predicate(); + if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) { + auto cmpLhs = cmp.lhs(); + auto cmpRhs = cmp.rhs(); + + // %0 = cmpi eq, %arg0, %arg1 + // %1 = select %0, %arg0, %arg1 => %arg1 + + // %0 = cmpi ne, %arg0, %arg1 + // %1 = select %0, %arg0, %arg1 => %arg0 + + if ((cmpLhs == trueVal && cmpRhs == falseVal) || + (cmpRhs == trueVal && cmpLhs == falseVal)) + return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal; + } + } return nullptr; } diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index a6bf0c7..7702202 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -339,3 +339,32 @@ func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : // CHECK: %[[GENERATE:.+]] = tensor.generate // CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]] // CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: @select_same_val +// CHECK: return %arg1 +func @select_same_val(%arg0: i1, %arg1: i64) -> i64 { + %0 = select %arg0, %arg1, %arg1 : i64 + return %0 : i64 +} + +// ----- + +// CHECK-LABEL: @select_cmp_eq_select +// CHECK: return %arg1 +func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 { + %0 = cmpi eq, %arg0, %arg1 : i64 + %1 = select %0, %arg0, %arg1 : i64 + return %1 : i64 +} + +// ----- + +// CHECK-LABEL: @select_cmp_ne_select +// CHECK: return %arg0 +func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 { + %0 = cmpi ne, %arg0, %arg1 : i64 + %1 = select %0, %arg0, %arg1 : i64 + return %1 : i64 +}