return BoolAttr::get(getContext(), val);
}
+class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
+public:
+ using OpRewritePattern<CmpFOp>::OpRewritePattern;
+
+ static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
+ bool isUnsigned) {
+ using namespace arith;
+ switch (pred) {
+ case CmpFPredicate::UEQ:
+ case CmpFPredicate::OEQ:
+ return CmpIPredicate::eq;
+ case CmpFPredicate::UGT:
+ case CmpFPredicate::OGT:
+ return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
+ case CmpFPredicate::UGE:
+ case CmpFPredicate::OGE:
+ return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
+ case CmpFPredicate::ULT:
+ case CmpFPredicate::OLT:
+ return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
+ case CmpFPredicate::ULE:
+ case CmpFPredicate::OLE:
+ return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
+ case CmpFPredicate::UNE:
+ case CmpFPredicate::ONE:
+ return CmpIPredicate::ne;
+ default:
+ llvm_unreachable("Unexpected predicate!");
+ }
+ }
+
+ LogicalResult matchAndRewrite(CmpFOp op,
+ PatternRewriter &rewriter) const override {
+ FloatAttr flt;
+ if (!matchPattern(op.getRhs(), m_Constant(&flt)))
+ return failure();
+
+ const APFloat &rhs = flt.getValue();
+
+ // Don't attempt to fold a nan.
+ if (rhs.isNaN())
+ return failure();
+
+ // Get the width of the mantissa. We don't want to hack on conversions that
+ // might lose information from the integer, e.g. "i64 -> float"
+ FloatType floatTy = op.getRhs().getType().cast<FloatType>();
+ int mantissaWidth = floatTy.getFPMantissaWidth();
+ if (mantissaWidth <= 0)
+ return failure();
+
+ bool isUnsigned;
+ Value intVal;
+
+ if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
+ isUnsigned = false;
+ intVal = si.getIn();
+ } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
+ isUnsigned = true;
+ intVal = ui.getIn();
+ } else {
+ return failure();
+ }
+
+ // Check to see that the input is converted from an integer type that is
+ // small enough that preserves all bits.
+ auto intTy = intVal.getType().cast<IntegerType>();
+ auto intWidth = intTy.getWidth();
+
+ // Number of bits representing values, as opposed to the sign
+ auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
+
+ // Following test does NOT adjust intWidth downwards for signed inputs,
+ // because the most negative value still requires all the mantissa bits
+ // to distinguish it from one less than that value.
+ if ((int)intWidth > mantissaWidth) {
+ // Conversion would lose accuracy. Check if loss can impact comparison.
+ int exponent = ilogb(rhs);
+ if (exponent == APFloat::IEK_Inf) {
+ int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
+ if (maxExponent < (int)valueBits) {
+ // Conversion could create infinity.
+ return failure();
+ }
+ } else {
+ // Note that if rhs is zero or NaN, then Exp is negative
+ // and first condition is trivially false.
+ if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
+ // Conversion could affect comparison.
+ return failure();
+ }
+ }
+ }
+
+ // Convert to equivalent cmpi predicate
+ CmpIPredicate pred;
+ switch (op.getPredicate()) {
+ case CmpFPredicate::ORD:
+ // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ case CmpFPredicate::UNO:
+ // Int to fp conversion doesn't create a nan (uno checks either is a nan)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ default:
+ pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
+ break;
+ }
+
+ if (!isUnsigned) {
+ // If the rhs value is > SignedMax, fold the comparison. This handles
+ // +INF and large values.
+ APFloat signedMax(rhs.getSemantics());
+ signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
+ APFloat::rmNearestTiesToEven);
+ if (signedMax < rhs) { // smax < 13123.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
+ pred == CmpIPredicate::sle)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ } else {
+ // If the rhs value is > UnsignedMax, fold the comparison. This handles
+ // +INF and large values.
+ APFloat unsignedMax(rhs.getSemantics());
+ unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
+ APFloat::rmNearestTiesToEven);
+ if (unsignedMax < rhs) { // umax < 13123.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
+ pred == CmpIPredicate::ule)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ }
+
+ if (!isUnsigned) {
+ // See if the rhs value is < SignedMin.
+ APFloat signedMin(rhs.getSemantics());
+ signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
+ APFloat::rmNearestTiesToEven);
+ if (signedMin > rhs) { // smin > 12312.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
+ pred == CmpIPredicate::sge)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ } else {
+ // See if the rhs value is < UnsignedMin.
+ APFloat unsignedMin(rhs.getSemantics());
+ unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
+ APFloat::rmNearestTiesToEven);
+ if (unsignedMin > rhs) { // umin > 12312.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
+ pred == CmpIPredicate::uge)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ }
+
+ // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
+ // [0, UMAX], but it may still be fractional. See if it is fractional by
+ // casting the FP value to the integer value and back, checking for
+ // equality. Don't do this for zero, because -0.0 is not fractional.
+ bool ignored;
+ APSInt rhsInt(intWidth, isUnsigned);
+ if (APFloat::opInvalidOp ==
+ rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
+ // Undefined behavior invoked - the destination type can't represent
+ // the input constant.
+ return failure();
+ }
+
+ if (!rhs.isZero()) {
+ APFloat apf(floatTy.getFloatSemantics(),
+ APInt::getZero(floatTy.getWidth()));
+ apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
+
+ bool equal = apf == rhs;
+ if (!equal) {
+ // If we had a comparison against a fractional value, we have to adjust
+ // the compare predicate and sometimes the value. rhsInt is rounded
+ // towards zero at this point.
+ switch (pred) {
+ default:
+ llvm_unreachable("Unexpected integer comparison!");
+ case CmpIPredicate::ne: // (float)int != 4.4 --> true
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ case CmpIPredicate::eq: // (float)int == 4.4 --> false
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ case CmpIPredicate::ule:
+ // (float)int <= 4.4 --> int <= 4
+ // (float)int <= -4.4 --> false
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ break;
+ case CmpIPredicate::sle:
+ // (float)int <= 4.4 --> int <= 4
+ // (float)int <= -4.4 --> int < -4
+ if (rhs.isNegative())
+ pred = CmpIPredicate::slt;
+ break;
+ case CmpIPredicate::ult:
+ // (float)int < -4.4 --> false
+ // (float)int < 4.4 --> int <= 4
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ pred = CmpIPredicate::ule;
+ break;
+ case CmpIPredicate::slt:
+ // (float)int < -4.4 --> int < -4
+ // (float)int < 4.4 --> int <= 4
+ if (!rhs.isNegative())
+ pred = CmpIPredicate::sle;
+ break;
+ case CmpIPredicate::ugt:
+ // (float)int > 4.4 --> int > 4
+ // (float)int > -4.4 --> true
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ }
+ break;
+ case CmpIPredicate::sgt:
+ // (float)int > 4.4 --> int > 4
+ // (float)int > -4.4 --> int >= -4
+ if (rhs.isNegative())
+ pred = CmpIPredicate::sge;
+ break;
+ case CmpIPredicate::uge:
+ // (float)int >= -4.4 --> true
+ // (float)int >= 4.4 --> int > 4
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ }
+ pred = CmpIPredicate::ugt;
+ break;
+ case CmpIPredicate::sge:
+ // (float)int >= -4.4 --> int >= -4
+ // (float)int >= 4.4 --> int > 4
+ if (!rhs.isNegative())
+ pred = CmpIPredicate::sgt;
+ break;
+ }
+ }
+ }
+
+ // Lower this FP comparison into an appropriate integer version of the
+ // comparison.
+ rewriter.replaceOpWithNewOp<CmpIOp>(
+ op, pred, intVal,
+ rewriter.create<ConstantOp>(
+ op.getLoc(), intVal.getType(),
+ rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+ return success();
+ }
+};
+
+void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.insert<CmpFIntToFPConst>(context);
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
%res = arith.sitofp %c0 : i32 to f32
return %res : f32
}
+
+// -----
+
+// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
+// When inst combining an FCMP with the LHS coming from a arith.uitofp instruction, we
+// can lower it to signed ICMP instructions.
+
+// CHECK-LABEL: @test1(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test1(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ole, %1, %cst : f64
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ule, %[[arg0]], %[[c0]] : i32
+ return %2 : i1
+}
+
+// CHECK-LABEL: @test2(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test2(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf olt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ult, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test3(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test3(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf oge, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi uge, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test4(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test4(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ogt, %1, %cst : f64
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ugt, %[[arg0]], %[[c0]] : i32
+ return %2 : i1
+}
+
+// CHECK-LABEL: @test5(
+func @test5(%arg0: i32) -> i1 {
+ %cst = arith.constant -4.400000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ogt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[true:.+]] = arith.constant true
+ // CHECK: return %[[true]] : i1
+}
+
+// CHECK-LABEL: @test6(
+func @test6(%arg0: i32) -> i1 {
+ %cst = arith.constant -4.400000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf olt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[false:.+]] = arith.constant false
+ // CHECK: return %[[false]] : i1
+}
+
+// Check that optimizing unsigned >= comparisons correctly distinguishes
+// positive and negative constants.
+// CHECK-LABEL: @test7(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test7(%arg0: i32) -> i1 {
+ %cst = arith.constant 3.200000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf oge, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c3:.+]] = arith.constant 3 : i32
+ // CHECK: arith.cmpi ugt, %[[arg0]], %[[c3]] : i32
+}