}];
}
+// The predicate indicates the type of the comparison to perform:
+// (un)orderedness, (in)equality and less/greater than (or equal to) as
+// well as predicates that are always true or false.
+def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
+def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">;
+def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">;
+def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">;
+def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">;
+def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">;
+def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">;
+def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">;
+def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">;
+def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">;
+def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">;
+def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">;
+def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">;
+def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">;
+def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">;
+def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
+
+def CmpFPredicateAttr : I64EnumAttr<
+ "CmpFPredicate", "",
+ [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
+ CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
+ CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
+ let cppNamespace = "::mlir";
+}
+
def CmpFOp : Std_Op<"cmpf",
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
TypesMatchWith<
%r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
}];
- let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+ let arguments = (ins
+ CmpFPredicateAttr:$predicate,
+ FloatLike:$lhs,
+ FloatLike:$rhs
+ );
let results = (outs BoolLike:$result);
let builders = [OpBuilder<
}
}];
+ let verifier = [{ return success(); }];
+
let hasFolder = 1;
+
+ let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
// CmpFOp
//===----------------------------------------------------------------------===//
-// Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
-static inline const char *const *getCmpFPredicateNames() {
- static const char *predicateNames[] = {
- /*AlwaysFalse*/ "false",
- /*OEQ*/ "oeq",
- /*OGT*/ "ogt",
- /*OGE*/ "oge",
- /*OLT*/ "olt",
- /*OLE*/ "ole",
- /*ONE*/ "one",
- /*ORD*/ "ord",
- /*UEQ*/ "ueq",
- /*UGT*/ "ugt",
- /*UGE*/ "uge",
- /*ULT*/ "ult",
- /*ULE*/ "ule",
- /*UNE*/ "une",
- /*UNO*/ "uno",
- /*AlwaysTrue*/ "true",
- };
- static_assert(std::extent<decltype(predicateNames)>::value ==
- (size_t)CmpFPredicate::NumPredicates,
- "wrong number of predicate names");
- return predicateNames;
-}
-
-// Returns a value of the predicate corresponding to the given mnemonic.
-// Returns NumPredicates (one-past-end) if there is no such mnemonic.
-CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
- return llvm::StringSwitch<CmpFPredicate>(name)
- .Case("false", CmpFPredicate::AlwaysFalse)
- .Case("oeq", CmpFPredicate::OEQ)
- .Case("ogt", CmpFPredicate::OGT)
- .Case("oge", CmpFPredicate::OGE)
- .Case("olt", CmpFPredicate::OLT)
- .Case("ole", CmpFPredicate::OLE)
- .Case("one", CmpFPredicate::ONE)
- .Case("ord", CmpFPredicate::ORD)
- .Case("ueq", CmpFPredicate::UEQ)
- .Case("ugt", CmpFPredicate::UGT)
- .Case("uge", CmpFPredicate::UGE)
- .Case("ult", CmpFPredicate::ULT)
- .Case("ule", CmpFPredicate::ULE)
- .Case("une", CmpFPredicate::UNE)
- .Case("uno", CmpFPredicate::UNO)
- .Case("true", CmpFPredicate::AlwaysTrue)
- .Default(CmpFPredicate::NumPredicates);
-}
-
static void buildCmpFOp(Builder *build, OperationState &result,
CmpFPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
-static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 2> ops;
- SmallVector<NamedAttribute, 4> attrs;
- Attribute predicateNameAttr;
- Type type;
- if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
- attrs) ||
- parser.parseComma() || parser.parseOperandList(ops, 2) ||
- parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
- parser.resolveOperands(ops, type, result.operands))
- return failure();
-
- if (!predicateNameAttr.isa<StringAttr>())
- return parser.emitError(parser.getNameLoc(),
- "expected string comparison predicate attribute");
-
- // Rewrite string attribute to an enum value.
- StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
- auto predicate = CmpFOp::getPredicateByName(predicateName);
- if (predicate == CmpFPredicate::NumPredicates)
- return parser.emitError(parser.getNameLoc(),
- "unknown comparison predicate \"" + predicateName +
- "\"");
-
- auto builder = parser.getBuilder();
- Type i1Type = getCheckedI1SameShape(type);
- if (!i1Type)
- return parser.emitError(parser.getNameLoc(),
- "expected type with valid i1 shape");
-
- attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
- result.attributes = attrs;
-
- result.addTypes({i1Type});
- return success();
-}
-
-static void print(OpAsmPrinter &p, CmpFOp op) {
- p << "cmpf ";
-
- auto predicateValue =
- op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
- assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
- predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
- "unknown predicate index");
- p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
- << ", " << op.rhs();
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
- p << " : " << op.lhs().getType();
-}
-
-static LogicalResult verify(CmpFOp op) {
- auto predicateAttr =
- op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
- if (!predicateAttr)
- return op.emitOpError("requires an integer attribute named 'predicate'");
- auto predicate = predicateAttr.getInt();
- if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
- predicate >= (int64_t)CmpFPredicate::NumPredicates)
- return op.emitOpError("'predicate' attribute value out of range");
-
- return success();
-}
-
-// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
-// comparison predicates.
+/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+/// comparison predicates.
static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs) {
auto cmpResult = lhs.compare(rhs);
// -----
func @cmpf_generic_invalid_predicate_value(%a : f32) {
- // expected-error@+1 {{'predicate' attribute value out of range}}
+ // expected-error@+1 {{attribute 'predicate' failed to satisfy constraint: allowed 64-bit integer cases}}
%r = "std.cmpf"(%a, %a) {predicate = 42} : (f32, f32) -> i1
}
// -----
func @cmpf_canonical_invalid_predicate_value(%a : f32) {
- // expected-error@+1 {{unknown comparison predicate "foo"}}
+ // expected-error@+1 {{invalid predicate attribute specification: "foo"}}
%r = cmpf "foo", %a, %a : f32
}
// -----
func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) {
- // expected-error@+1 {{unknown comparison predicate "sge"}}
+ // expected-error@+1 {{invalid predicate attribute specification: "sge"}}
%r = cmpf "sge", %a, %a : f32
}
// -----
func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) {
- // expected-error@+1 {{unknown comparison predicate "eq"}}
+ // expected-error@+1 {{invalid predicate attribute specification: "eq"}}
%r = cmpf "eq", %a, %a : f32
}
// -----
func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) {
- // expected-error@+1 {{requires an integer attribute named 'predicate'}}
+ // expected-error@+1 {{requires attribute 'predicate'}}
%r = "std.cmpf"(%a, %b) {foo = 1} : (f32, f32) -> i1
}
// -----
func @cmpf_wrong_type(%a : i32, %b : i32) {
- %r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}}
+ %r = cmpf "oeq", %a, %b : i32 // expected-error {{must be floating-point-like}}
}
// -----