//===----------------------------------------------------------------------===//
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+ auto trueVal = getTrueValue();
+ auto falseVal = getFalseValue();
+ if (trueVal == falseVal)
+ return trueVal;
+
auto condition = getCondition();
// select true, %0, %1 => %0
if (matchPattern(condition, m_One()))
- return getTrueValue();
+ return trueVal;
// select false, %0, %1 => %1
if (matchPattern(condition, m_Zero()))
- return getFalseValue();
+ return falseVal;
+
+ if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
+ auto pred = cmp.predicate();
+ if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
+ auto cmpLhs = cmp.lhs();
+ auto cmpRhs = cmp.rhs();
+
+ // %0 = cmpi eq, %arg0, %arg1
+ // %1 = select %0, %arg0, %arg1 => %arg1
+
+ // %0 = cmpi ne, %arg0, %arg1
+ // %1 = select %0, %arg0, %arg1 => %arg0
+
+ if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
+ (cmpRhs == trueVal && cmpLhs == falseVal))
+ return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
+ }
+ }
return nullptr;
}
// CHECK: %[[GENERATE:.+]] = tensor.generate
// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: @select_same_val
+// CHECK: return %arg1
+func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
+ %0 = select %arg0, %arg1, %arg1 : i64
+ return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_eq_select
+// CHECK: return %arg1
+func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = cmpi eq, %arg0, %arg1 : i64
+ %1 = select %0, %arg0, %arg1 : i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_ne_select
+// CHECK: return %arg0
+func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = cmpi ne, %arg0, %arg1 : i64
+ %1 = select %0, %arg0, %arg1 : i64
+ return %1 : i64
+}