From c34386e3e590438c3ae914659106d611d8035e70 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Mon, 6 May 2019 17:51:08 -0700 Subject: [PATCH] CmpFOp. Add float comparison op This closely mirrors the llvm fcmp instruction, defining 16 different predicates Constant folding is unsupported for NaN and Inf because there's no way to represent those as constants at the moment -- PiperOrigin-RevId: 246932358 --- mlir/include/mlir/IR/OpDefinition.h | 12 ++ mlir/include/mlir/StandardOps/Ops.h | 75 ++++++++++ mlir/lib/IR/Operation.cpp | 9 ++ mlir/lib/StandardOps/Ops.cpp | 233 ++++++++++++++++++++++++++++++-- mlir/test/IR/core-ops.mlir | 23 ++++ mlir/test/IR/invalid-ops.mlir | 90 ++++++++++++ mlir/test/Transforms/constant-fold.mlir | 61 +++++++++ 7 files changed, 491 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index e551140..a0f0d54 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -310,6 +310,7 @@ LogicalResult verifyZeroOperands(Operation *op); LogicalResult verifyOneOperand(Operation *op); LogicalResult verifyNOperands(Operation *op, unsigned numOperands); LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); +LogicalResult verifyOperandsAreFloatLike(Operation *op); LogicalResult verifyOperandsAreIntegerLike(Operation *op); LogicalResult verifySameTypeOperands(Operation *op); LogicalResult verifyZeroResult(Operation *op); @@ -652,6 +653,17 @@ public: } }; +/// This class verifies that all operands of the specified op have a float type, +/// a vector thereof, or a tensor thereof. +template +class OperandsAreFloatLike + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOperandsAreFloatLike(op); + } +}; + /// This class verifies that all operands of the specified op have an integer or /// index type, a vector thereof, or a tensor thereof. template diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index fd2cb88..4088108 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -255,6 +255,81 @@ public: Attribute constantFold(ArrayRef operands, MLIRContext *context); }; +/// The predicate indicates the type of the comparison to perform: +/// (un)orderedness, (in)equality and signed less/greater than (or equal to) as +/// well as predicates that are always true or false. +enum class CmpFPredicate { + FirstValidValue, + // Always false + FALSE = FirstValidValue, + // Ordered comparisons + OEQ, + OGT, + OGE, + OLT, + OLE, + ONE, + // Both ordered + ORD, + // Unordered comparisons + UEQ, + UGT, + UGE, + ULT, + ULE, + UNE, + // Any unordered + UNO, + // Always true + TRUE, + // Number of predicates. + NumPredicates +}; + +/// The "cmpf" operation compares its two operands according to the float +/// comparison rules and the predicate specified by the respective attribute. +/// The predicate defines the type of comparison: (un)orderedness, (in)equality +/// and signed less/greater than (or equal to) as well as predicates that are +/// always true or false. The operands must have the same type, and this type +/// must be a float type, or a vector or tensor thereof. The result is an i1, +/// or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, +/// the operands are always treated as signed. The u prefix indicates +/// *unordered* comparison, not unsigned comparison, so "une" means unordered or +/// not equal. For the sake of readability by humans, custom assembly form for +/// the operation uses a string-typed attribute for the predicate. The value of +/// this attribute corresponds to lower-cased name of the predicate constant, +/// e.g., "one" means "ordered not equal". The string representation of the +/// attribute is merely a syntactic sugar and is converted to an integer +/// attribute by the parser. +/// +/// %r1 = cmpf "oeq" %0, %1 : f32 +/// %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> +/// %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 +class CmpFOp + : public Op::Impl, + OpTrait::OneResult, OpTrait::ResultsAreBoolLike, + OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + CmpFPredicate getPredicate() { + return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + + static StringRef getOperationName() { return "std.cmpf"; } + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpFPredicate getPredicateByName(StringRef name); + + static void build(Builder *builder, OperationState *result, CmpFPredicate, + Value *lhs, Value *rhs); + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); + Attribute constantFold(ArrayRef operands, MLIRContext *context); +}; + /// The "cond_br" operation represents a conditional branch operation in a /// function. The operation takes variable number of operands and produces /// no results. The operand number and types for each successor must match the diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 91b32de..6074fd5 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -725,6 +725,15 @@ LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) { return success(); } +LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { + for (auto *operand : op->getOperands()) { + auto type = getTensorOrVectorElementType(operand->getType()); + if (!type.isa()) + return op->emitOpError("requires a float type"); + } + return success(); +} + LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { // Zero or one operand always have the "same" type. unsigned nOperands = op->getNumOperands(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 46baa74..969875f 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -61,9 +61,10 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*name=*/"std", context) { - addOperations(); @@ -577,7 +578,7 @@ void CallIndirectOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// CmpIOp +// General helpers for comparison ops //===----------------------------------------------------------------------===// // Return the type of the same shape (scalar, vector or tensor) containing i1. @@ -627,9 +628,13 @@ static bool checkI1SameShape(Type pattern, Type type) { llvm_unreachable("unsupported type"); } -// Returns an array of mnemonics for CmpIPredicates, indexed by values thereof. -static inline const char *const *getPredicateNames() { - static const char *predicateNames[(int)CmpIPredicate::NumPredicates]{ +//===----------------------------------------------------------------------===// +// CmpIOp +//===----------------------------------------------------------------------===// + +// Returns an array of mnemonics for CmpIPredicates indexed by values thereof. +static inline const char *const *getCmpIPredicateNames() { + static const char *predicateNames[]{ /*EQ*/ "eq", /*NE*/ "ne", /*SLT*/ "slt", @@ -639,7 +644,11 @@ static inline const char *const *getPredicateNames() { /*ULT*/ "ult", /*ULE*/ "ule", /*UGT*/ "ugt", - /*UGE*/ "uge"}; + /*UGE*/ "uge", + }; + static_assert(std::extent::value == + (size_t)CmpIPredicate::NumPredicates, + "wrong number of predicate names"); return predicateNames; } @@ -664,9 +673,9 @@ void CmpIOp::build(Builder *build, OperationState *result, CmpIPredicate predicate, Value *lhs, Value *rhs) { result->addOperands({lhs, rhs}); result->types.push_back(getI1SameShape(build, lhs->getType())); - result->addAttribute(getPredicateAttrName(), - build->getIntegerAttr(build->getIntegerType(64), - static_cast(predicate))); + result->addAttribute( + getPredicateAttrName(), + build->getI64IntegerAttr(static_cast(predicate))); } bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { @@ -717,7 +726,7 @@ void CmpIOp::print(OpAsmPrinter *p) { "unknown predicate index"); Builder b(getContext()); auto predicateStringAttr = - b.getStringAttr(getPredicateNames()[predicateValue]); + b.getStringAttr(getCmpIPredicateNames()[predicateValue]); p->printAttribute(predicateStringAttr); *p << ", "; @@ -786,6 +795,206 @@ Attribute CmpIOp::constantFold(ArrayRef operands, } //===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// + +// Returns an array of mnemonics for CmpFPredicates indexed by values thereof. +static inline const char *const *getCmpFPredicateNames() { + static const char *predicateNames[] = { + /*FALSE*/ "false", + /*OEQ*/ "oeq", + /*OGT*/ "ogt", + /*OGE*/ "oge", + /*OLT*/ "olt", + /*OLE*/ "ole", + /*ONE*/ "one", + /*ORD*/ "ord", + /*UEQ*/ "ueq", + /*UGT*/ "ugt", + /*UGE*/ "uge", + /*ULT*/ "ult", + /*ULE*/ "ule", + /*UNE*/ "une", + /*UNO*/ "uno", + /*TRUE*/ "true", + }; + static_assert(std::extent::value == + (size_t)CmpFPredicate::NumPredicates, + "wrong number of predicate names"); + return predicateNames; +} + +// Returns a value of the predicate corresponding to the given mnemonic. +// Returns NumPredicates (one-past-end) if there is no such mnemonic. +CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { + return llvm::StringSwitch(name) + .Case("false", CmpFPredicate::FALSE) + .Case("oeq", CmpFPredicate::OEQ) + .Case("ogt", CmpFPredicate::OGT) + .Case("oge", CmpFPredicate::OGE) + .Case("olt", CmpFPredicate::OLT) + .Case("ole", CmpFPredicate::OLE) + .Case("one", CmpFPredicate::ONE) + .Case("ord", CmpFPredicate::ORD) + .Case("ueq", CmpFPredicate::UEQ) + .Case("ugt", CmpFPredicate::UGT) + .Case("uge", CmpFPredicate::UGE) + .Case("ult", CmpFPredicate::ULT) + .Case("ule", CmpFPredicate::ULE) + .Case("une", CmpFPredicate::UNE) + .Case("uno", CmpFPredicate::UNO) + .Case("true", CmpFPredicate::TRUE) + .Default(CmpFPredicate::NumPredicates); +} + +void CmpFOp::build(Builder *build, OperationState *result, + CmpFPredicate predicate, Value *lhs, Value *rhs) { + result->addOperands({lhs, rhs}); + result->types.push_back(getI1SameShape(build, lhs->getType())); + result->addAttribute( + getPredicateAttrName(), + build->getI64IntegerAttr(static_cast(predicate))); +} + +bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector attrs; + Attribute predicateNameAttr; + Type type; + if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(), + attrs) || + parser->parseComma() || parser->parseOperandList(ops, 2) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(type) || + parser->resolveOperands(ops, type, result->operands)) + return true; + + if (!predicateNameAttr.isa()) + return parser->emitError(parser->getNameLoc(), + "expected string comparison predicate attribute"); + + // Rewrite string attribute to an enum value. + StringRef predicateName = predicateNameAttr.cast().getValue(); + auto predicate = getPredicateByName(predicateName); + if (predicate == CmpFPredicate::NumPredicates) + return parser->emitError(parser->getNameLoc(), + "unknown comparison predicate \"" + predicateName + + "\""); + + auto builder = parser->getBuilder(); + Type i1Type = getCheckedI1SameShape(&builder, type); + if (!i1Type) + return parser->emitError(parser->getNameLoc(), + "expected type with valid i1 shape"); + + attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + result->attributes = attrs; + + result->addTypes({i1Type}); + return false; +} + +void CmpFOp::print(OpAsmPrinter *p) { + *p << "cmpf "; + + auto predicateValue = + getAttrOfType(getPredicateAttrName()).getInt(); + assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && + predicateValue < static_cast(CmpFPredicate::NumPredicates) && + "unknown predicate index"); + Builder b(getContext()); + auto predicateStringAttr = + b.getStringAttr(getCmpFPredicateNames()[predicateValue]); + p->printAttribute(predicateStringAttr); + + *p << ", "; + p->printOperand(getOperand(0)); + *p << ", "; + p->printOperand(getOperand(1)); + p->printOptionalAttrDict(getAttrs(), + /*elidedAttrs=*/{getPredicateAttrName()}); + *p << " : " << getOperand(0)->getType(); +} + +LogicalResult CmpFOp::verify() { + auto predicateAttr = getAttrOfType(getPredicateAttrName()); + if (!predicateAttr) + return emitOpError("requires an integer attribute named 'predicate'"); + auto predicate = predicateAttr.getInt(); + if (predicate < (int64_t)CmpFPredicate::FirstValidValue || + predicate >= (int64_t)CmpFPredicate::NumPredicates) + return emitOpError("'predicate' attribute value out of range"); + + return success(); +} + +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point +// comparison predicates. +static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, + const APFloat &rhs) { + auto cmpResult = lhs.compare(rhs); + switch (predicate) { + case CmpFPredicate::FALSE: + return false; + case CmpFPredicate::OEQ: + return cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OGT: + return cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::OGE: + return cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::OLT: + return cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::OLE: + return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ONE: + return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; + case CmpFPredicate::ORD: + return cmpResult != APFloat::cmpUnordered; + case CmpFPredicate::UEQ: + return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UGT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan; + case CmpFPredicate::UGE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpGreaterThan || + cmpResult == APFloat::cmpEqual; + case CmpFPredicate::ULT: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan; + case CmpFPredicate::ULE: + return cmpResult == APFloat::cmpUnordered || + cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; + case CmpFPredicate::UNE: + return cmpResult != APFloat::cmpEqual; + case CmpFPredicate::UNO: + return cmpResult == APFloat::cmpUnordered; + case CmpFPredicate::TRUE: + return true; + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +Attribute CmpFOp::constantFold(ArrayRef operands, + MLIRContext *context) { + assert(operands.size() == 2 && "cmpf takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs || + // TODO(b/122019992) Implement and test constant folding for nan/inf when + // it is possible to have constant nan/inf + !lhs.getValue().isFinite() || !rhs.getValue().isFinite()) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val)); +} + +//===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 67c03ce..fd9dc04 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -253,6 +253,29 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) { // CHECK: %{{[0-9]+}} = xor %cst_4, %cst_4 : tensor<42xi32> %63 = xor %tci32, %tci32 : tensor<42 x i32> + %64 = constant splat, 0.> : vector<4 x f32> + %tcf32 = constant splat, 0.> : tensor<42 x f32> + %vcf32 = constant splat, 0.> : vector<4 x f32> + + // CHECK: %{{[0-9]+}} = cmpf "ogt", %{{[0-9]+}}, %{{[0-9]+}} : f32 + %65 = cmpf "ogt", %f3, %f4 : f32 + + // Predicate 0 means ordered equality comparison. + // CHECK: %{{[0-9]+}} = cmpf "oeq", %{{[0-9]+}}, %{{[0-9]+}} : f32 + %66 = "std.cmpf"(%f3, %f4) {predicate: 1} : (f32, f32) -> i1 + + // CHECK: %{{[0-9]+}} = cmpf "olt", %cst_8, %cst_8 : vector<4xf32> + %67 = cmpf "olt", %vcf32, %vcf32 : vector<4 x f32> + + // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32> + %68 = "std.cmpf"(%vcf32, %vcf32) {predicate: 1} : (vector<4 x f32>, vector<4 x f32>) -> vector<4 x i1> + + // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_7, %cst_7 : tensor<42xf32> + %69 = cmpf "oeq", %tcf32, %tcf32 : tensor<42 x f32> + + // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32> + %70 = cmpf "oeq", %vcf32, %vcf32 : vector<4 x f32> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 7d49b7f..b2bd89b 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -528,3 +528,93 @@ func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) { func @invalid_cmp_attr(%idx : i32) { // expected-error@+1 {{expected string comparison predicate attribute}} %cmp = cmpi i1, %idx, %idx : i32 + +// ----- + +func @cmpf_generic_invalid_predicate_value(%a : f32) { + // expected-error@+1 {{'predicate' attribute value out of range}} + %r = "std.cmpf"(%a, %a) {predicate: 42} : (f32, f32) -> i1 +} + +// ----- + +func @cmpf_canonical_invalid_predicate_value(%a : f32) { + // expected-error@+1 {{unknown comparison predicate "foo"}} + %r = cmpf "foo", %a, %a : f32 +} + +// ----- + +func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) { + // expected-error@+1 {{unknown comparison predicate "sge"}} + %r = cmpf "sge", %a, %a : f32 +} + +// ----- + +func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) { + // expected-error@+1 {{unknown comparison predicate "eq"}} + %r = cmpf "eq", %a, %a : f32 +} + +// ----- + +func @cmpf_canonical_no_predicate_attr(%a : f32, %b : f32) { + %r = cmpf %a, %b : f32 // expected-error {{}} +} + +// ----- + +func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) { + // expected-error@+1 {{requires an integer attribute named 'predicate'}} + %r = "std.cmpf"(%a, %b) {foo: 1} : (f32, f32) -> i1 +} + +// ----- + +func @cmpf_wrong_type(%a : i32, %b : i32) { + %r = cmpf "oeq", %a, %b : i32 // expected-error {{op requires a float type}} +} + +// ----- + +func @cmpf_generic_wrong_result_type(%a : f32, %b : f32) { + // expected-error@+1 {{op requires a bool result type}} + %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f32) -> f32 +} + +// ----- + +func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 { + %r = cmpf "oeq", %a, %b : f32 // expected-error {{prior use here}} + // expected-error@+1 {{use of value '%r' expects different type than prior uses}} + return %r : f32 +} + +// ----- + +func @cmpf_result_shape_mismatch(%a : vector<42xf32>) { + // expected-error@+1 {{op requires the same shape for all operands and results}} + %r = "std.cmpf"(%a, %a) {predicate: 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1> +} + +// ----- + +func @cmpf_operand_shape_mismatch(%a : vector<42xf32>, %b : vector<41xf32>) { + // expected-error@+1 {{op requires all operands to have the same type}} + %r = "std.cmpf"(%a, %b) {predicate: 0} : (vector<42 x f32>, vector<41 x f32>) -> vector<42 x i1> +} + +// ----- + +func @cmpf_generic_operand_type_mismatch(%a : f32, %b : f64) { + // expected-error@+1 {{op requires all operands to have the same type}} + %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f64) -> i1 +} + +// ----- + +func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-error {{prior use here}} + // expected-error@+1 {{use of value '%b' expects different type than prior uses}} + %r = cmpf "oeq", %a, %b : f32 +} diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index 198fd82..f249e99 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -330,6 +330,67 @@ func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) { // ----- +// CHECK-LABEL: func @cmpf_normal_numbers +func @cmpf_normal_numbers() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) { + %c42 = constant 42. : f32 + %cm1 = constant -1. : f32 + // CHECK-DAG: %false = constant 0 : i1 + // CHECK-DAG: %true = constant 1 : i1 + // CHECK-NEXT: return %false, + %0 = cmpf "false", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %1 = cmpf "oeq", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %2 = cmpf "ogt", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %3 = cmpf "oge", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %4 = cmpf "olt", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %5 = cmpf "ole", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %6 = cmpf "one", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %7 = cmpf "ord", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %8 = cmpf "ueq", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %9 = cmpf "ugt", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %10 = cmpf "uge", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %11 = cmpf "ult", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %12 = cmpf "ule", %c42, %cm1 : f32 + // CHECK-SAME: %true, + %13 = cmpf "une", %c42, %cm1 : f32 + // CHECK-SAME: %false, + %14 = cmpf "uno", %c42, %cm1 : f32 + // CHECK-SAME: %true + %15 = cmpf "true", %c42, %cm1 : f32 + return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 +} + +// ----- + +// CHECK-LABEL: func @cmpf_nan +func @cmpf_nans() { + // TODO(b/122019992) Add tests for nan constant folding when it's possible to + // have nan constants + return +} + +// ----- + +// CHECK-LABEL: func @cmpf_inf +func @cmpf_inf() { + // TODO(b/122019992) Add tests for inf constant folding when it's possible to + // have inf constants + return +} + +// ----- + // CHECK-LABEL: func @fold_extract_element func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) { %const_0 = constant 0 : index -- 2.7.4