.def("__lt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::slt, lhs.value,
rhs.value);
})
.def("__le__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::sle, lhs.value,
rhs.value);
})
.def("__gt__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::sgt, lhs.value,
rhs.value);
})
.def("__ge__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::sge, lhs.value,
rhs.value);
})
.def("__eq__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::eq, lhs.value,
rhs.value);
})
.def("__ne__",
[](PythonValueHandle lhs,
PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
+ return ValueHandle::create<CmpIOp>(CmpIPredicate::ne, lhs.value,
rhs.value);
})
.def("__invert__",
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRStandardOpsIncGen)
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.h.inc"
+
namespace mlir {
class AffineMap;
class Builder;
};
/// The predicate indicates the type of the comparison to perform:
-/// (in)equality; (un)signed less/greater than (or equal to).
-enum class CmpIPredicate {
- FirstValidValue,
- // (In)equality comparisons.
- EQ = FirstValidValue,
- NE,
- // Signed comparisons.
- SLT,
- SLE,
- SGT,
- SGE,
- // Unsigned comparisons.
- ULT,
- ULE,
- UGT,
- UGE,
- // Number of predicates.
- NumPredicates
-};
-
-/// The predicate indicates the type of the comparison to perform:
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
/// well as predicates that are always true or false.
enum class CmpFPredicate {
let hasCanonicalizer = 1;
}
+def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
+def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
+def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
+def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
+def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
+def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
+def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
+def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
+def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
+def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
+
+def CmpIPredicateAttr : I64EnumAttr<
+ "CmpIPredicate", "",
+ [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
+ CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
+ let cppNamespace = "::mlir";
+}
+
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
let summary = "integer comparison operation";
%r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
}];
- let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
+ let arguments = (ins
+ CmpIPredicateAttr:$predicate,
+ IntegerLike:$lhs,
+ IntegerLike:$rhs
+ );
let results = (outs BoolLike);
let builders = [OpBuilder<
}
}];
+ let verifier = [{ return success(); }];
+
let hasFolder = 1;
}
Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *isRemainderNegative =
- builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
+ builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
correctedRemainder, remainder);
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
Value *negative =
- builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
+ builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
Value *dividend =
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
Value *nonPositive =
- builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
+ builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
Value *dividend =
boundOperands);
if (!lbValues)
return nullptr;
- return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
builder);
}
boundOperands);
if (!ubValues)
return nullptr;
- return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
+ return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
builder);
}
operandsRef.drop_front(numDims));
if (!affResult)
return matchFailure();
- auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE;
+ auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
Value *cmpVal =
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
cond =
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value *>(), endBlock,
cmpIOpOperands.rhs()); \
return matchSuccess();
- DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
- DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
- DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
- DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
- DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
- DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
+ DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
+ DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
+ DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
+ DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
+ DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
+ DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
#undef DISPATCH
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
+
using namespace mlir;
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
-// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
-static inline const char *const *getCmpIPredicateNames() {
- static const char *predicateNames[]{
- /*EQ*/ "eq",
- /*NE*/ "ne",
- /*SLT*/ "slt",
- /*SLE*/ "sle",
- /*SGT*/ "sgt",
- /*SGE*/ "sge",
- /*ULT*/ "ult",
- /*ULE*/ "ule",
- /*UGT*/ "ugt",
- /*UGE*/ "uge",
- };
- static_assert(std::extent<decltype(predicateNames)>::value ==
- (size_t)CmpIPredicate::NumPredicates,
- "wrong number of predicate names");
- return predicateNames;
-}
-
-// Returns a value of the predicate corresponding to the given mnemonic.
-// Returns NumPredicates (one-past-end) if there is no such mnemonic.
-CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
- return llvm::StringSwitch<CmpIPredicate>(name)
- .Case("eq", CmpIPredicate::EQ)
- .Case("ne", CmpIPredicate::NE)
- .Case("slt", CmpIPredicate::SLT)
- .Case("sle", CmpIPredicate::SLE)
- .Case("sgt", CmpIPredicate::SGT)
- .Case("sge", CmpIPredicate::SGE)
- .Case("ult", CmpIPredicate::ULT)
- .Case("ule", CmpIPredicate::ULE)
- .Case("ugt", CmpIPredicate::UGT)
- .Case("uge", CmpIPredicate::UGE)
- .Default(CmpIPredicate::NumPredicates);
-}
-
static void buildCmpIOp(Builder *build, OperationState &result,
CmpIPredicate predicate, Value *lhs, Value *rhs) {
result.addOperands({lhs, rhs});
// Rewrite string attribute to an enum value.
StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
- auto predicate = CmpIOp::getPredicateByName(predicateName);
- if (predicate == CmpIPredicate::NumPredicates)
+ Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
+ if (!predicate.hasValue())
return parser.emitError(parser.getNameLoc())
<< "unknown comparison predicate \"" << predicateName << "\"";
return parser.emitError(parser.getNameLoc(),
"expected type with valid i1 shape");
- attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+ attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
result.attributes = attrs;
result.addTypes({i1Type});
static void print(OpAsmPrinter &p, CmpIOp op) {
p << "cmpi ";
+ Builder b(op.getContext());
auto predicateValue =
op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
- assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
- predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
- "unknown predicate index");
- Builder b(op.getContext());
- auto predicateStringAttr =
- b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
- p.printAttribute(predicateStringAttr);
+ p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
+ << '"';
p << ", ";
p.printOperand(op.lhs());
p << " : " << op.lhs()->getType();
}
-static LogicalResult verify(CmpIOp op) {
- auto predicateAttr =
- op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
- if (!predicateAttr)
- return op.emitOpError("requires an integer attribute named 'predicate'");
- auto predicate = predicateAttr.getInt();
- if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
- predicate >= (int64_t)CmpIPredicate::NumPredicates)
- return op.emitOpError("'predicate' attribute value out of range");
-
- return success();
-}
-
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
// comparison predicates.
static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs) {
switch (predicate) {
- case CmpIPredicate::EQ:
+ case CmpIPredicate::eq:
return lhs.eq(rhs);
- case CmpIPredicate::NE:
+ case CmpIPredicate::ne:
return lhs.ne(rhs);
- case CmpIPredicate::SLT:
+ case CmpIPredicate::slt:
return lhs.slt(rhs);
- case CmpIPredicate::SLE:
+ case CmpIPredicate::sle:
return lhs.sle(rhs);
- case CmpIPredicate::SGT:
+ case CmpIPredicate::sgt:
return lhs.sgt(rhs);
- case CmpIPredicate::SGE:
+ case CmpIPredicate::sge:
return lhs.sge(rhs);
- case CmpIPredicate::ULT:
+ case CmpIPredicate::ult:
return lhs.ult(rhs);
- case CmpIPredicate::ULE:
+ case CmpIPredicate::ule:
return lhs.ule(rhs);
- case CmpIPredicate::UGT:
+ case CmpIPredicate::ugt:
return lhs.ugt(rhs);
- case CmpIPredicate::UGE:
+ case CmpIPredicate::uge:
return lhs.uge(rhs);
default:
llvm_unreachable("unknown comparison predicate");
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
- : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
+ : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
- : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
+ : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
:
// TODO(ntv,zinenko): signed by default, how about unsigned?
- createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
+ createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
- : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
+ : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
- : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
+ : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
- : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
+ : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
}
// Insert newForOp before the terminator of `t`.
OpBuilder b(t.getBodyBuilder());
Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
- Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::SLT,
+ Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
forOp.upperBound(), stepped);
Value *ub =
b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
func @func_with_ops(i32) {
^bb0(%a : i32):
- // expected-error@+1 {{'predicate' attribute value out of range}}
+ // expected-error@+1 {{failed to satisfy constraint: allowed 64-bit integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}
%r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1
}
func @func_with_ops(i32, i32) {
^bb0(%a : i32, %b : i32):
- // expected-error@+1 {{requires an integer attribute named 'predicate'}}
+ // expected-error@+1 {{requires attribute 'predicate'}}
%r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1
}