NFC: Convert CmpIPredicate in StandardOps to use EnumAttr
authorLei Zhang <antiagainst@google.com>
Fri, 15 Nov 2019 18:16:33 +0000 (10:16 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 15 Nov 2019 18:17:31 +0000 (10:17 -0800)
This turns several hand-written functions to auto-generated ones.

PiperOrigin-RevId: 280684326

mlir/bindings/python/pybind.cpp
mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt
mlir/include/mlir/Dialect/StandardOps/Ops.h
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/EDSC/Builders.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/IR/invalid-ops.mlir

index 61f42af..a458837 100644 (file)
@@ -870,37 +870,37 @@ PYBIND11_MODULE(pybind, m) {
         .def("__lt__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::slt, lhs.value,
                                                   rhs.value);
              })
         .def("__le__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::sle, lhs.value,
                                                   rhs.value);
              })
         .def("__gt__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::sgt, lhs.value,
                                                   rhs.value);
              })
         .def("__ge__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::sge, lhs.value,
                                                   rhs.value);
              })
         .def("__eq__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::eq, lhs.value,
                                                   rhs.value);
              })
         .def("__ne__",
              [](PythonValueHandle lhs,
                 PythonValueHandle rhs) -> PythonValueHandle {
-               return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
+               return ValueHandle::create<CmpIOp>(CmpIPredicate::ne, lhs.value,
                                                   rhs.value);
              })
         .def("__invert__",
index 670676f..b653479 100644 (file)
@@ -1,4 +1,6 @@
 set(LLVM_TARGET_DEFINITIONS Ops.td)
 mlir_tablegen(Ops.h.inc -gen-op-decls)
 mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRStandardOpsIncGen)
index fd69534..7798162 100644 (file)
@@ -30,6 +30,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.h.inc"
+
 namespace mlir {
 class AffineMap;
 class Builder;
@@ -43,27 +46,6 @@ public:
 };
 
 /// The predicate indicates the type of the comparison to perform:
-/// (in)equality; (un)signed less/greater than (or equal to).
-enum class CmpIPredicate {
-  FirstValidValue,
-  // (In)equality comparisons.
-  EQ = FirstValidValue,
-  NE,
-  // Signed comparisons.
-  SLT,
-  SLE,
-  SGT,
-  SGE,
-  // Unsigned comparisons.
-  ULT,
-  ULE,
-  UGT,
-  UGE,
-  // Number of predicates.
-  NumPredicates
-};
-
-/// The predicate indicates the type of the comparison to perform:
 /// (un)orderedness, (in)equality and less/greater than (or equal to) as
 /// well as predicates that are always true or false.
 enum class CmpFPredicate {
index bfd3916..eb7ebbb 100644 (file)
@@ -345,6 +345,24 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
   let hasCanonicalizer = 1;
 }
 
+def CMPI_P_EQ  : I64EnumAttrCase<"eq", 0>;
+def CMPI_P_NE  : I64EnumAttrCase<"ne", 1>;
+def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
+def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
+def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
+def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
+def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
+def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
+def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
+def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
+
+def CmpIPredicateAttr : I64EnumAttr<
+    "CmpIPredicate", "",
+    [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
+     CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
+  let cppNamespace = "::mlir";
+}
+
 def CmpIOp : Std_Op<"cmpi",
     [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
   let summary = "integer comparison operation";
@@ -369,7 +387,11 @@ def CmpIOp : Std_Op<"cmpi",
       %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
   }];
 
-  let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
+  let arguments = (ins
+      CmpIPredicateAttr:$predicate,
+      IntegerLike:$lhs,
+      IntegerLike:$rhs
+  );
   let results = (outs BoolLike);
 
   let builders = [OpBuilder<
@@ -388,6 +410,8 @@ def CmpIOp : Std_Op<"cmpi",
     }
   }];
 
+  let verifier = [{ return success(); }];
+
   let hasFolder = 1;
 }
 
index 98d80ed..4935d2d 100644 (file)
@@ -97,7 +97,7 @@ public:
     Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
     Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
     Value *isRemainderNegative =
-        builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
+        builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
     Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
     Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
                                              correctedRemainder, remainder);
@@ -134,7 +134,7 @@ public:
     Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
     Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
     Value *negative =
-        builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
+        builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
     Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
     Value *dividend =
         builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
@@ -173,7 +173,7 @@ public:
     Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
     Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
     Value *nonPositive =
-        builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
+        builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
     Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
     Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
     Value *dividend =
@@ -277,7 +277,7 @@ Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
                                   boundOperands);
   if (!lbValues)
     return nullptr;
-  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
+  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
                                  builder);
 }
 
@@ -290,7 +290,7 @@ Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
                                   boundOperands);
   if (!ubValues)
     return nullptr;
-  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
+  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
                                  builder);
 }
 
@@ -352,7 +352,7 @@ public:
                                           operandsRef.drop_front(numDims));
       if (!affResult)
         return matchFailure();
-      auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE;
+      auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
       Value *cmpVal =
           rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
       cond =
index e8ab53b..08ee320 100644 (file)
@@ -205,7 +205,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   // With the body block done, we can fill in the condition block.
   rewriter.setInsertionPointToEnd(conditionBlock);
   auto comparison =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
 
   rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
                                 ArrayRef<Value *>(), endBlock,
index 2fd6e75..74d1352 100644 (file)
@@ -90,12 +90,12 @@ public:
         cmpIOpOperands.rhs());                                                 \
     return matchSuccess();
 
-      DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
-      DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
-      DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
-      DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
-      DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
-      DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
+      DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
+      DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
+      DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
+      DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
+      DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
+      DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
 
 #undef DISPATCH
 
index bf0cb75..c4abee3 100644 (file)
@@ -35,6 +35,9 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
+// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/StandardOps/OpsEnums.cpp.inc"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -699,43 +702,6 @@ static Type getI1SameShape(Builder *build, Type type) {
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-// Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
-static inline const char *const *getCmpIPredicateNames() {
-  static const char *predicateNames[]{
-      /*EQ*/ "eq",
-      /*NE*/ "ne",
-      /*SLT*/ "slt",
-      /*SLE*/ "sle",
-      /*SGT*/ "sgt",
-      /*SGE*/ "sge",
-      /*ULT*/ "ult",
-      /*ULE*/ "ule",
-      /*UGT*/ "ugt",
-      /*UGE*/ "uge",
-  };
-  static_assert(std::extent<decltype(predicateNames)>::value ==
-                    (size_t)CmpIPredicate::NumPredicates,
-                "wrong number of predicate names");
-  return predicateNames;
-}
-
-// Returns a value of the predicate corresponding to the given mnemonic.
-// Returns NumPredicates (one-past-end) if there is no such mnemonic.
-CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
-  return llvm::StringSwitch<CmpIPredicate>(name)
-      .Case("eq", CmpIPredicate::EQ)
-      .Case("ne", CmpIPredicate::NE)
-      .Case("slt", CmpIPredicate::SLT)
-      .Case("sle", CmpIPredicate::SLE)
-      .Case("sgt", CmpIPredicate::SGT)
-      .Case("sge", CmpIPredicate::SGE)
-      .Case("ult", CmpIPredicate::ULT)
-      .Case("ule", CmpIPredicate::ULE)
-      .Case("ugt", CmpIPredicate::UGT)
-      .Case("uge", CmpIPredicate::UGE)
-      .Default(CmpIPredicate::NumPredicates);
-}
-
 static void buildCmpIOp(Builder *build, OperationState &result,
                         CmpIPredicate predicate, Value *lhs, Value *rhs) {
   result.addOperands({lhs, rhs});
@@ -763,8 +729,8 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
 
   // Rewrite string attribute to an enum value.
   StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
-  auto predicate = CmpIOp::getPredicateByName(predicateName);
-  if (predicate == CmpIPredicate::NumPredicates)
+  Optional<CmpIPredicate> predicate = symbolizeCmpIPredicate(predicateName);
+  if (!predicate.hasValue())
     return parser.emitError(parser.getNameLoc())
            << "unknown comparison predicate \"" << predicateName << "\"";
 
@@ -774,7 +740,7 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(parser.getNameLoc(),
                             "expected type with valid i1 shape");
 
-  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(*predicate));
   result.attributes = attrs;
 
   result.addTypes({i1Type});
@@ -784,15 +750,11 @@ static ParseResult parseCmpIOp(OpAsmParser &parser, OperationState &result) {
 static void print(OpAsmPrinter &p, CmpIOp op) {
   p << "cmpi ";
 
+  Builder b(op.getContext());
   auto predicateValue =
       op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
-  assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
-         predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
-         "unknown predicate index");
-  Builder b(op.getContext());
-  auto predicateStringAttr =
-      b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
-  p.printAttribute(predicateStringAttr);
+  p << '"' << stringifyCmpIPredicate(static_cast<CmpIPredicate>(predicateValue))
+    << '"';
 
   p << ", ";
   p.printOperand(op.lhs());
@@ -803,43 +765,30 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
   p << " : " << op.lhs()->getType();
 }
 
-static LogicalResult verify(CmpIOp op) {
-  auto predicateAttr =
-      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
-  if (!predicateAttr)
-    return op.emitOpError("requires an integer attribute named 'predicate'");
-  auto predicate = predicateAttr.getInt();
-  if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
-      predicate >= (int64_t)CmpIPredicate::NumPredicates)
-    return op.emitOpError("'predicate' attribute value out of range");
-
-  return success();
-}
-
 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 // comparison predicates.
 static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
                               const APInt &rhs) {
   switch (predicate) {
-  case CmpIPredicate::EQ:
+  case CmpIPredicate::eq:
     return lhs.eq(rhs);
-  case CmpIPredicate::NE:
+  case CmpIPredicate::ne:
     return lhs.ne(rhs);
-  case CmpIPredicate::SLT:
+  case CmpIPredicate::slt:
     return lhs.slt(rhs);
-  case CmpIPredicate::SLE:
+  case CmpIPredicate::sle:
     return lhs.sle(rhs);
-  case CmpIPredicate::SGT:
+  case CmpIPredicate::sgt:
     return lhs.sgt(rhs);
-  case CmpIPredicate::SGE:
+  case CmpIPredicate::sge:
     return lhs.sge(rhs);
-  case CmpIPredicate::ULT:
+  case CmpIPredicate::ult:
     return lhs.ult(rhs);
-  case CmpIPredicate::ULE:
+  case CmpIPredicate::ule:
     return lhs.ule(rhs);
-  case CmpIPredicate::UGT:
+  case CmpIPredicate::ugt:
     return lhs.ugt(rhs);
-  case CmpIPredicate::UGE:
+  case CmpIPredicate::uge:
     return lhs.uge(rhs);
   default:
     llvm_unreachable("unknown comparison predicate");
index 4046a7c..9d7ca8c 100644 (file)
@@ -460,13 +460,13 @@ ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
-             : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
 }
 ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
-             : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
 }
 ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
@@ -474,23 +474,23 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
              ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
              :
              // TODO(ntv,zinenko): signed by default, how about unsigned?
-             createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
+             createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
 }
 ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
-             : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
 }
 ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
-             : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
 }
 ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
   auto type = lhs.getType();
   return type.isa<FloatType>()
              ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
-             : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
+             : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
 }
index 405116e..0ee1220 100644 (file)
@@ -747,7 +747,7 @@ static Loops stripmineSink(loop::ForOp forOp, Value *factor,
     // Insert newForOp before the terminator of `t`.
     OpBuilder b(t.getBodyBuilder());
     Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
-    Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::SLT,
+    Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
                                    forOp.upperBound(), stepped);
     Value *ub =
         b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
index 74dd412..c8fffc9 100644 (file)
@@ -201,7 +201,7 @@ func @func_with_ops(i32) {
 
 func @func_with_ops(i32) {
 ^bb0(%a : i32):
-  // expected-error@+1 {{'predicate' attribute value out of range}}
+  // expected-error@+1 {{failed to satisfy constraint: allowed 64-bit integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}
   %r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1
 }
 
@@ -241,7 +241,7 @@ func @func_with_ops(i32, i32) {
 
 func @func_with_ops(i32, i32) {
 ^bb0(%a : i32, %b : i32):
-  // expected-error@+1 {{requires an integer attribute named 'predicate'}}
+  // expected-error@+1 {{requires attribute 'predicate'}}
   %r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1
 }