LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
+LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);
LogicalResult verifyZeroResult(Operation *op);
}
};
+/// This class verifies that all operands of the specified op have a float type,
+/// a vector thereof, or a tensor thereof.
+template <typename ConcreteType>
+class OperandsAreFloatLike
+ : public TraitBase<ConcreteType, OperandsAreFloatLike> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifyOperandsAreFloatLike(op);
+ }
+};
+
/// This class verifies that all operands of the specified op have an integer or
/// index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
};
+/// The predicate indicates the type of the comparison to perform:
+/// (un)orderedness, (in)equality and signed less/greater than (or equal to) as
+/// well as predicates that are always true or false.
+enum class CmpFPredicate {
+ FirstValidValue,
+ // Always false
+ FALSE = FirstValidValue,
+ // Ordered comparisons
+ OEQ,
+ OGT,
+ OGE,
+ OLT,
+ OLE,
+ ONE,
+ // Both ordered
+ ORD,
+ // Unordered comparisons
+ UEQ,
+ UGT,
+ UGE,
+ ULT,
+ ULE,
+ UNE,
+ // Any unordered
+ UNO,
+ // Always true
+ TRUE,
+ // Number of predicates.
+ NumPredicates
+};
+
+/// The "cmpf" operation compares its two operands according to the float
+/// comparison rules and the predicate specified by the respective attribute.
+/// The predicate defines the type of comparison: (un)orderedness, (in)equality
+/// and signed less/greater than (or equal to) as well as predicates that are
+/// always true or false. The operands must have the same type, and this type
+/// must be a float type, or a vector or tensor thereof. The result is an i1,
+/// or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
+/// the operands are always treated as signed. The u prefix indicates
+/// *unordered* comparison, not unsigned comparison, so "une" means unordered or
+/// not equal. For the sake of readability by humans, custom assembly form for
+/// the operation uses a string-typed attribute for the predicate. The value of
+/// this attribute corresponds to lower-cased name of the predicate constant,
+/// e.g., "one" means "ordered not equal". The string representation of the
+/// attribute is merely a syntactic sugar and is converted to an integer
+/// attribute by the parser.
+///
+/// %r1 = cmpf "oeq" %0, %1 : f32
+/// %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64>
+/// %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
+class CmpFOp
+ : public Op<CmpFOp, OpTrait::OperandsAreFloatLike,
+ OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
+ OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
+ OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
+public:
+ using Op::Op;
+
+ CmpFPredicate getPredicate() {
+ return (CmpFPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
+ .getInt();
+ }
+
+ static StringRef getOperationName() { return "std.cmpf"; }
+ static StringRef getPredicateAttrName() { return "predicate"; }
+ static CmpFPredicate getPredicateByName(StringRef name);
+
+ static void build(Builder *builder, OperationState *result, CmpFPredicate,
+ Value *lhs, Value *rhs);
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+ LogicalResult verify();
+ Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
+};
+
/// The "cond_br" operation represents a conditional branch operation in a
/// function. The operation takes variable number of operands and produces
/// no results. The operand number and types for each successor must match the
return success();
}
+LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
+ for (auto *operand : op->getOperands()) {
+ auto type = getTensorOrVectorElementType(operand->getType());
+ if (!type.isa<FloatType>())
+ return op->emitOpError("requires a float type");
+ }
+ return success();
+}
+
LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
// Zero or one operand always have the "same" type.
unsigned nOperands = op->getNumOperands();
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
- addOperations<AllocOp, BranchOp, CallOp, CallIndirectOp, CmpIOp, CondBranchOp,
- DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
- LoadOp, MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
+ addOperations<AllocOp, BranchOp, CallOp, CallIndirectOp, CmpFOp, CmpIOp,
+ CondBranchOp, DeallocOp, DimOp, DmaStartOp, DmaWaitOp,
+ ExtractElementOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
+ StoreOp, TensorCastOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
-// CmpIOp
+// General helpers for comparison ops
//===----------------------------------------------------------------------===//
// Return the type of the same shape (scalar, vector or tensor) containing i1.
llvm_unreachable("unsupported type");
}
-// Returns an array of mnemonics for CmpIPredicates, indexed by values thereof.
-static inline const char *const *getPredicateNames() {
- static const char *predicateNames[(int)CmpIPredicate::NumPredicates]{
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
+static inline const char *const *getCmpIPredicateNames() {
+ static const char *predicateNames[]{
/*EQ*/ "eq",
/*NE*/ "ne",
/*SLT*/ "slt",
/*ULT*/ "ult",
/*ULE*/ "ule",
/*UGT*/ "ugt",
- /*UGE*/ "uge"};
+ /*UGE*/ "uge",
+ };
+ static_assert(std::extent<decltype(predicateNames)>::value ==
+ (size_t)CmpIPredicate::NumPredicates,
+ "wrong number of predicate names");
return predicateNames;
}
CmpIPredicate predicate, Value *lhs, Value *rhs) {
result->addOperands({lhs, rhs});
result->types.push_back(getI1SameShape(build, lhs->getType()));
- result->addAttribute(getPredicateAttrName(),
- build->getIntegerAttr(build->getIntegerType(64),
- static_cast<int64_t>(predicate)));
+ result->addAttribute(
+ getPredicateAttrName(),
+ build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
"unknown predicate index");
Builder b(getContext());
auto predicateStringAttr =
- b.getStringAttr(getPredicateNames()[predicateValue]);
+ b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
p->printAttribute(predicateStringAttr);
*p << ", ";
}
//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+// Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
+static inline const char *const *getCmpFPredicateNames() {
+ static const char *predicateNames[] = {
+ /*FALSE*/ "false",
+ /*OEQ*/ "oeq",
+ /*OGT*/ "ogt",
+ /*OGE*/ "oge",
+ /*OLT*/ "olt",
+ /*OLE*/ "ole",
+ /*ONE*/ "one",
+ /*ORD*/ "ord",
+ /*UEQ*/ "ueq",
+ /*UGT*/ "ugt",
+ /*UGE*/ "uge",
+ /*ULT*/ "ult",
+ /*ULE*/ "ule",
+ /*UNE*/ "une",
+ /*UNO*/ "uno",
+ /*TRUE*/ "true",
+ };
+ static_assert(std::extent<decltype(predicateNames)>::value ==
+ (size_t)CmpFPredicate::NumPredicates,
+ "wrong number of predicate names");
+ return predicateNames;
+}
+
+// Returns a value of the predicate corresponding to the given mnemonic.
+// Returns NumPredicates (one-past-end) if there is no such mnemonic.
+CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
+ return llvm::StringSwitch<CmpFPredicate>(name)
+ .Case("false", CmpFPredicate::FALSE)
+ .Case("oeq", CmpFPredicate::OEQ)
+ .Case("ogt", CmpFPredicate::OGT)
+ .Case("oge", CmpFPredicate::OGE)
+ .Case("olt", CmpFPredicate::OLT)
+ .Case("ole", CmpFPredicate::OLE)
+ .Case("one", CmpFPredicate::ONE)
+ .Case("ord", CmpFPredicate::ORD)
+ .Case("ueq", CmpFPredicate::UEQ)
+ .Case("ugt", CmpFPredicate::UGT)
+ .Case("uge", CmpFPredicate::UGE)
+ .Case("ult", CmpFPredicate::ULT)
+ .Case("ule", CmpFPredicate::ULE)
+ .Case("une", CmpFPredicate::UNE)
+ .Case("uno", CmpFPredicate::UNO)
+ .Case("true", CmpFPredicate::TRUE)
+ .Default(CmpFPredicate::NumPredicates);
+}
+
+void CmpFOp::build(Builder *build, OperationState *result,
+ CmpFPredicate predicate, Value *lhs, Value *rhs) {
+ result->addOperands({lhs, rhs});
+ result->types.push_back(getI1SameShape(build, lhs->getType()));
+ result->addAttribute(
+ getPredicateAttrName(),
+ build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
+bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ SmallVector<NamedAttribute, 4> attrs;
+ Attribute predicateNameAttr;
+ Type type;
+ if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(),
+ attrs) ||
+ parser->parseComma() || parser->parseOperandList(ops, 2) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperands(ops, type, result->operands))
+ return true;
+
+ if (!predicateNameAttr.isa<StringAttr>())
+ return parser->emitError(parser->getNameLoc(),
+ "expected string comparison predicate attribute");
+
+ // Rewrite string attribute to an enum value.
+ StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
+ auto predicate = getPredicateByName(predicateName);
+ if (predicate == CmpFPredicate::NumPredicates)
+ return parser->emitError(parser->getNameLoc(),
+ "unknown comparison predicate \"" + predicateName +
+ "\"");
+
+ auto builder = parser->getBuilder();
+ Type i1Type = getCheckedI1SameShape(&builder, type);
+ if (!i1Type)
+ return parser->emitError(parser->getNameLoc(),
+ "expected type with valid i1 shape");
+
+ attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+ result->attributes = attrs;
+
+ result->addTypes({i1Type});
+ return false;
+}
+
+void CmpFOp::print(OpAsmPrinter *p) {
+ *p << "cmpf ";
+
+ auto predicateValue =
+ getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
+ assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
+ predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
+ "unknown predicate index");
+ Builder b(getContext());
+ auto predicateStringAttr =
+ b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
+ p->printAttribute(predicateStringAttr);
+
+ *p << ", ";
+ p->printOperand(getOperand(0));
+ *p << ", ";
+ p->printOperand(getOperand(1));
+ p->printOptionalAttrDict(getAttrs(),
+ /*elidedAttrs=*/{getPredicateAttrName()});
+ *p << " : " << getOperand(0)->getType();
+}
+
+LogicalResult CmpFOp::verify() {
+ auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
+ if (!predicateAttr)
+ return emitOpError("requires an integer attribute named 'predicate'");
+ auto predicate = predicateAttr.getInt();
+ if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
+ predicate >= (int64_t)CmpFPredicate::NumPredicates)
+ return emitOpError("'predicate' attribute value out of range");
+
+ return success();
+}
+
+// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+// comparison predicates.
+static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
+ const APFloat &rhs) {
+ auto cmpResult = lhs.compare(rhs);
+ switch (predicate) {
+ case CmpFPredicate::FALSE:
+ return false;
+ case CmpFPredicate::OEQ:
+ return cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::OGT:
+ return cmpResult == APFloat::cmpGreaterThan;
+ case CmpFPredicate::OGE:
+ return cmpResult == APFloat::cmpGreaterThan ||
+ cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::OLT:
+ return cmpResult == APFloat::cmpLessThan;
+ case CmpFPredicate::OLE:
+ return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::ONE:
+ return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
+ case CmpFPredicate::ORD:
+ return cmpResult != APFloat::cmpUnordered;
+ case CmpFPredicate::UEQ:
+ return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::UGT:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpGreaterThan;
+ case CmpFPredicate::UGE:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpGreaterThan ||
+ cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::ULT:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpLessThan;
+ case CmpFPredicate::ULE:
+ return cmpResult == APFloat::cmpUnordered ||
+ cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
+ case CmpFPredicate::UNE:
+ return cmpResult != APFloat::cmpEqual;
+ case CmpFPredicate::UNO:
+ return cmpResult == APFloat::cmpUnordered;
+ case CmpFPredicate::TRUE:
+ return true;
+ default:
+ llvm_unreachable("unknown comparison predicate");
+ }
+}
+
+// Constant folding hook for comparisons.
+Attribute CmpFOp::constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) {
+ assert(operands.size() == 2 && "cmpf takes two arguments");
+
+ auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
+ auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
+ if (!lhs || !rhs ||
+ // TODO(b/122019992) Implement and test constant folding for nan/inf when
+ // it is possible to have constant nan/inf
+ !lhs.getValue().isFinite() || !rhs.getValue().isFinite())
+ return {};
+
+ auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
+ return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
+}
+
+//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
// CHECK: %{{[0-9]+}} = xor %cst_4, %cst_4 : tensor<42xi32>
%63 = xor %tci32, %tci32 : tensor<42 x i32>
+ %64 = constant splat<vector<4 x f32>, 0.> : vector<4 x f32>
+ %tcf32 = constant splat<tensor<42 x f32>, 0.> : tensor<42 x f32>
+ %vcf32 = constant splat<vector<4 x f32>, 0.> : vector<4 x f32>
+
+ // CHECK: %{{[0-9]+}} = cmpf "ogt", %{{[0-9]+}}, %{{[0-9]+}} : f32
+ %65 = cmpf "ogt", %f3, %f4 : f32
+
+ // Predicate 0 means ordered equality comparison.
+ // CHECK: %{{[0-9]+}} = cmpf "oeq", %{{[0-9]+}}, %{{[0-9]+}} : f32
+ %66 = "std.cmpf"(%f3, %f4) {predicate: 1} : (f32, f32) -> i1
+
+ // CHECK: %{{[0-9]+}} = cmpf "olt", %cst_8, %cst_8 : vector<4xf32>
+ %67 = cmpf "olt", %vcf32, %vcf32 : vector<4 x f32>
+
+ // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32>
+ %68 = "std.cmpf"(%vcf32, %vcf32) {predicate: 1} : (vector<4 x f32>, vector<4 x f32>) -> vector<4 x i1>
+
+ // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_7, %cst_7 : tensor<42xf32>
+ %69 = cmpf "oeq", %tcf32, %tcf32 : tensor<42 x f32>
+
+ // CHECK: %{{[0-9]+}} = cmpf "oeq", %cst_8, %cst_8 : vector<4xf32>
+ %70 = cmpf "oeq", %vcf32, %vcf32 : vector<4 x f32>
+
return
}
func @invalid_cmp_attr(%idx : i32) {
// expected-error@+1 {{expected string comparison predicate attribute}}
%cmp = cmpi i1, %idx, %idx : i32
+
+// -----
+
+func @cmpf_generic_invalid_predicate_value(%a : f32) {
+ // expected-error@+1 {{'predicate' attribute value out of range}}
+ %r = "std.cmpf"(%a, %a) {predicate: 42} : (f32, f32) -> i1
+}
+
+// -----
+
+func @cmpf_canonical_invalid_predicate_value(%a : f32) {
+ // expected-error@+1 {{unknown comparison predicate "foo"}}
+ %r = cmpf "foo", %a, %a : f32
+}
+
+// -----
+
+func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) {
+ // expected-error@+1 {{unknown comparison predicate "sge"}}
+ %r = cmpf "sge", %a, %a : f32
+}
+
+// -----
+
+func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) {
+ // expected-error@+1 {{unknown comparison predicate "eq"}}
+ %r = cmpf "eq", %a, %a : f32
+}
+
+// -----
+
+func @cmpf_canonical_no_predicate_attr(%a : f32, %b : f32) {
+ %r = cmpf %a, %b : f32 // expected-error {{}}
+}
+
+// -----
+
+func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) {
+ // expected-error@+1 {{requires an integer attribute named 'predicate'}}
+ %r = "std.cmpf"(%a, %b) {foo: 1} : (f32, f32) -> i1
+}
+
+// -----
+
+func @cmpf_wrong_type(%a : i32, %b : i32) {
+ %r = cmpf "oeq", %a, %b : i32 // expected-error {{op requires a float type}}
+}
+
+// -----
+
+func @cmpf_generic_wrong_result_type(%a : f32, %b : f32) {
+ // expected-error@+1 {{op requires a bool result type}}
+ %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f32) -> f32
+}
+
+// -----
+
+func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
+ %r = cmpf "oeq", %a, %b : f32 // expected-error {{prior use here}}
+ // expected-error@+1 {{use of value '%r' expects different type than prior uses}}
+ return %r : f32
+}
+
+// -----
+
+func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
+ // expected-error@+1 {{op requires the same shape for all operands and results}}
+ %r = "std.cmpf"(%a, %a) {predicate: 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
+}
+
+// -----
+
+func @cmpf_operand_shape_mismatch(%a : vector<42xf32>, %b : vector<41xf32>) {
+ // expected-error@+1 {{op requires all operands to have the same type}}
+ %r = "std.cmpf"(%a, %b) {predicate: 0} : (vector<42 x f32>, vector<41 x f32>) -> vector<42 x i1>
+}
+
+// -----
+
+func @cmpf_generic_operand_type_mismatch(%a : f32, %b : f64) {
+ // expected-error@+1 {{op requires all operands to have the same type}}
+ %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f64) -> i1
+}
+
+// -----
+
+func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-error {{prior use here}}
+ // expected-error@+1 {{use of value '%b' expects different type than prior uses}}
+ %r = cmpf "oeq", %a, %b : f32
+}
// -----
+// CHECK-LABEL: func @cmpf_normal_numbers
+func @cmpf_normal_numbers() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+ %c42 = constant 42. : f32
+ %cm1 = constant -1. : f32
+ // CHECK-DAG: %false = constant 0 : i1
+ // CHECK-DAG: %true = constant 1 : i1
+ // CHECK-NEXT: return %false,
+ %0 = cmpf "false", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %1 = cmpf "oeq", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %2 = cmpf "ogt", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %3 = cmpf "oge", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %4 = cmpf "olt", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %5 = cmpf "ole", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %6 = cmpf "one", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %7 = cmpf "ord", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %8 = cmpf "ueq", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %9 = cmpf "ugt", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %10 = cmpf "uge", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %11 = cmpf "ult", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %12 = cmpf "ule", %c42, %cm1 : f32
+ // CHECK-SAME: %true,
+ %13 = cmpf "une", %c42, %cm1 : f32
+ // CHECK-SAME: %false,
+ %14 = cmpf "uno", %c42, %cm1 : f32
+ // CHECK-SAME: %true
+ %15 = cmpf "true", %c42, %cm1 : f32
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpf_nan
+func @cmpf_nans() {
+ // TODO(b/122019992) Add tests for nan constant folding when it's possible to
+ // have nan constants
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpf_inf
+func @cmpf_inf() {
+ // TODO(b/122019992) Add tests for inf constant folding when it's possible to
+ // have inf constants
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_element
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
%const_0 = constant 0 : index