Move the definitions of CmpIOp, CmpFOp, and SelectOp to the ODG framework.
authorRiver Riddle <riverriddle@google.com>
Sat, 25 May 2019 01:01:20 +0000 (18:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:01:42 +0000 (20:01 -0700)
--

PiperOrigin-RevId: 249928953

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/IR/Operation.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/test/IR/invalid-ops.mlir

index 780d92f..99be357 100644 (file)
@@ -459,6 +459,12 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Common type constraints
 //===----------------------------------------------------------------------===//
 
+// Type constraint for bool-like types: bools, vectors of bools, tensors of
+// bools.
+def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
+                                  TensorOf<[I1]>.predicate]>,
+    "bool-like">;
+
 // Type constraint for integer-like types: integers, indices, vectors of
 // integers, tensors of integers.
 def IntegerLike : TypeConstraint<Or<[AnyInteger.predicate, Index.predicate,
@@ -824,6 +830,8 @@ def Commutative      : NativeOpTrait<"IsCommutative">;
 def FloatLikeResults : NativeOpTrait<"ResultsAreFloatLike">;
 // Op has no side effect.
 def NoSideEffect     : NativeOpTrait<"HasNoSideEffect">;
+// Op has the same operand type.
+def SameOperandType  : NativeOpTrait<"SameTypeOperands">;
 // Op has same operand and result shape.
 def SameValueShape   : NativeOpTrait<"SameOperandsAndResultShape">;
 // Op has the same operand and result type.
index 6db6fe0..18008f2 100644 (file)
@@ -37,9 +37,6 @@ public:
   StandardOpsDialect(MLIRContext *context);
 };
 
-#define GET_OP_CLASSES
-#include "mlir/StandardOps/Ops.h.inc"
-
 /// The predicate indicates the type of the comparison to perform:
 /// (in)equality; (un)signed less/greater than (or equal to).
 enum class CmpIPredicate {
@@ -61,49 +58,6 @@ enum class CmpIPredicate {
   NumPredicates
 };
 
-/// The "cmpi" operation compares its two operands according to the integer
-/// comparison rules and the predicate specified by the respective attribute.
-/// The predicate defines the type of comparison: (in)equality, (un)signed
-/// less/greater than (or equal to).  The operands must have the same type, and
-/// this type must be an integer type, a vector or a tensor thereof.  The result
-/// is an i1, or a vector/tensor thereof having the same shape as the inputs.
-/// Since integers are signless, the predicate also explicitly indicates
-/// whether to interpret the operands as signed or unsigned integers for
-/// less/greater than comparisons.  For the sake of readability by humans,
-/// custom assembly form for the operation uses a string-typed attribute for
-/// the predicate.  The value of this attribute corresponds to lower-cased name
-/// of the predicate constant, e.g., "slt" means "signed less than".  The string
-/// representation of the attribute is merely a syntactic sugar and is converted
-/// to an integer attribute by the parser.
-///
-///   %r1 = cmpi "eq" %0, %1 : i32
-///   %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
-///   %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
-class CmpIOp
-    : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
-                OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
-                OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
-                OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  CmpIPredicate getPredicate() {
-    return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
-        .getInt();
-  }
-
-  static StringRef getOperationName() { return "std.cmpi"; }
-  static StringRef getPredicateAttrName() { return "predicate"; }
-  static CmpIPredicate getPredicateByName(StringRef name);
-
-  static void build(Builder *builder, OperationState *result, CmpIPredicate,
-                    Value *lhs, Value *rhs);
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-  OpFoldResult fold(ArrayRef<Attribute> operands);
-};
-
 /// The predicate indicates the type of the comparison to perform:
 /// (un)orderedness, (in)equality and signed less/greater than (or equal to) as
 /// well as predicates that are always true or false.
@@ -135,49 +89,8 @@ enum class CmpFPredicate {
   NumPredicates
 };
 
-/// The "cmpf" operation compares its two operands according to the float
-/// comparison rules and the predicate specified by the respective attribute.
-/// The predicate defines the type of comparison: (un)orderedness, (in)equality
-/// and signed less/greater than (or equal to) as well as predicates that are
-/// always true or false.  The operands must have the same type, and this type
-/// must be a float type, or a vector or tensor thereof.  The result is an i1,
-/// or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
-/// the operands are always treated as signed. The u prefix indicates
-/// *unordered* comparison, not unsigned comparison, so "une" means unordered or
-/// not equal. For the sake of readability by humans, custom assembly form for
-/// the operation uses a string-typed attribute for the predicate.  The value of
-/// this attribute corresponds to lower-cased name of the predicate constant,
-/// e.g., "one" means "ordered not equal".  The string representation of the
-/// attribute is merely a syntactic sugar and is converted to an integer
-/// attribute by the parser.
-///
-///   %r1 = cmpf "oeq" %0, %1 : f32
-///   %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64>
-///   %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
-class CmpFOp
-    : public Op<CmpFOp, OpTrait::OperandsAreFloatLike,
-                OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
-                OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
-                OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  CmpFPredicate getPredicate() {
-    return (CmpFPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
-        .getInt();
-  }
-
-  static StringRef getOperationName() { return "std.cmpf"; }
-  static StringRef getPredicateAttrName() { return "predicate"; }
-  static CmpFPredicate getPredicateByName(StringRef name);
-
-  static void build(Builder *builder, OperationState *result, CmpFPredicate,
-                    Value *lhs, Value *rhs);
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-  OpFoldResult fold(ArrayRef<Attribute> operands);
-};
+#define GET_OP_CLASSES
+#include "mlir/StandardOps/Ops.h.inc"
 
 /// The "cond_br" operation represents a conditional branch operation in a
 /// function. The operation takes variable number of operands and produces
@@ -573,36 +486,6 @@ public:
                                           MLIRContext *context);
 };
 
-/// The "select" operation chooses one value based on a binary condition
-/// supplied as its first operand. If the value of the first operand is 1, the
-/// second operand is chosen, otherwise the third operand is chosen. The second
-/// and the third operand must have the same type. The operation applies
-/// elementwise to vectors and tensors.  The shape of all arguments must be
-/// identical. For example, the maximum operation is obtained by combining
-/// "select" with "cmpi" as follows.
-///
-///   %2 = cmpi "gt" %0, %1 : i32         // %2 is i1
-///   %3 = select %2, %0, %1 : i32
-///
-class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
-                           OpTrait::OneResult, OpTrait::HasNoSideEffect> {
-public:
-  using Op::Op;
-
-  static StringRef getOperationName() { return "std.select"; }
-  static void build(Builder *builder, OperationState *result, Value *condition,
-                    Value *trueValue, Value *falseValue);
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  LogicalResult verify();
-
-  Value *getCondition() { return getOperand(0); }
-  Value *getTrueValue() { return getOperand(1); }
-  Value *getFalseValue() { return getOperand(2); }
-
-  OpFoldResult fold(ArrayRef<Attribute> operands);
-};
-
 /// The "store" op writes an element to a memref specified by an index list.
 /// The arity of indices is the rank of the memref (i.e. if the memref being
 /// stored to is of rank 3, then 3 indices are required for the store following
index 058750f..817079a 100644 (file)
@@ -278,6 +278,97 @@ def CallIndirectOp : Std_Op<"call_indirect"> {
   let hasCanonicalizer = 1;
 }
 
+def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameOperandType, SameValueShape]> {
+  let summary = "integer comparison operation";
+  let description = [{
+    The "cmpi" operation compares its two operands according to the integer
+    comparison rules and the predicate specified by the respective attribute.
+    The predicate defines the type of comparison: (in)equality, (un)signed
+    less/greater than (or equal to).  The operands must have the same type, and
+    this type must be an integer type, a vector or a tensor thereof.  The result
+    is an i1, or a vector/tensor thereof having the same shape as the inputs.
+    Since integers are signless, the predicate also explicitly indicates
+    whether to interpret the operands as signed or unsigned integers for
+    less/greater than comparisons.  For the sake of readability by humans,
+    custom assembly form for the operation uses a string-typed attribute for
+    the predicate.  The value of this attribute corresponds to lower-cased name
+    of the predicate constant, e.g., "slt" means "signed less than".  The string
+    representation of the attribute is merely a syntactic sugar and is converted
+    to an integer attribute by the parser.
+
+      %r1 = cmpi "eq" %0, %1 : i32
+      %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
+      %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
+  }];
+
+  let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs);
+  let results = (outs BoolLike);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, CmpIPredicate predicate,"
+    "Value *lhs, Value *rhs", [{
+      ::buildCmpIOp(builder, result, predicate, lhs, rhs);
+  }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpIPredicate getPredicateByName(StringRef name);
+
+    CmpIPredicate getPredicate() {
+      return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
+          .getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
+def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameOperandType, SameValueShape]> {
+  let summary = "floating-point comparison operation";
+  let description = [{
+    The "cmpf" operation compares its two operands according to the float
+    comparison rules and the predicate specified by the respective attribute.
+    The predicate defines the type of comparison: (un)orderedness, (in)equality
+    and signed less/greater than (or equal to) as well as predicates that are
+    always true or false.  The operands must have the same type, and this type
+    must be a float type, or a vector or tensor thereof.  The result is an i1,
+    or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
+    the operands are always treated as signed. The u prefix indicates
+    *unordered* comparison, not unsigned comparison, so "une" means unordered or
+    not equal. For the sake of readability by humans, custom assembly form for
+    the operation uses a string-typed attribute for the predicate.  The value of
+    this attribute corresponds to lower-cased name of the predicate constant,
+    e.g., "one" means "ordered not equal".  The string representation of the
+    attribute is merely a syntactic sugar and is converted to an integer
+    attribute by the parser.
+
+      %r1 = cmpf "oeq" %0, %1 : f32
+      %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64>
+      %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
+  }];
+
+  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+  let results = (outs BoolLike);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, CmpFPredicate predicate,"
+    "Value *lhs, Value *rhs", [{
+      ::buildCmpFOp(builder, result, predicate, lhs, rhs);
+  }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpFPredicate getPredicateByName(StringRef name);
+
+    CmpFPredicate getPredicate() {
+      return (CmpFPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
+          .getInt();
+    }
+  }];
+
+  let hasFolder = 1;
+}
+
 def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
   let summary = "constant";
 
@@ -477,6 +568,40 @@ def ReturnOp : Std_Op<"return", [Terminator]> {
   >];
 }
 
+def SelectOp : Std_Op<"select", [NoSideEffect, SameValueShape]> {
+  let summary = "select operation";
+  let description = [{
+    The "select" operation chooses one value based on a binary condition
+    supplied as its first operand. If the value of the first operand is 1, the
+    second operand is chosen, otherwise the third operand is chosen. The second
+    and the third operand must have the same type. The operation applies
+    elementwise to vectors and tensors.  The shape of all arguments must be
+    identical. For example, the maximum operation is obtained by combining
+    "select" with "cmpi" as follows.
+
+      %2 = cmpi "gt" %0, %1 : i32         // %2 is i1
+      %3 = select %2, %0, %1 : i32
+  }];
+
+  let arguments = (ins BoolLike:$condition, AnyType:$true_value,
+                       AnyType:$false_value);
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *condition,"
+    "Value *trueValue, Value *falseValue", [{
+      result->addOperands({condition, trueValue, falseValue});
+      result->addTypes(trueValue->getType());
+  }]>];
+
+  let extraClassDeclaration = [{
+      Value *getCondition() { return condition(); }
+      Value *getTrueValue() { return true_value(); }
+      Value *getFalseValue() { return false_value(); }
+  }];
+
+  let hasFolder = 1;
+}
 def ShlISOp : IntArithmeticOp<"shlis"> {
   let summary = "signed integer shift left";
 }
index f6fc3ee..5d50db1 100644 (file)
@@ -782,6 +782,8 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
   // Either both or neither type should be shaped.
   if (!sType1)
     return success(!sType2);
+  if (!sType2)
+    return failure();
 
   if (!sType1.hasRank() || !sType2.hasRank())
     return success();
index fd9d57f..6b559b3 100644 (file)
@@ -81,8 +81,7 @@ template <typename T> static LogicalResult verifyCastOp(T op) {
 
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*name=*/"std", context) {
-  addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
-                SelectOp, StoreOp,
+  addOperations<CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp, StoreOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -597,33 +596,6 @@ static Type getI1SameShape(Builder *build, Type type) {
   return res;
 }
 
-static inline bool isI1(Type type) {
-  return type.isa<IntegerType>() && type.cast<IntegerType>().getWidth() == 1;
-}
-
-template <typename Ty>
-static inline bool implCheckI1SameShape(Ty pattern, Type type) {
-  auto specificType = type.dyn_cast<Ty>();
-  if (!specificType)
-    return true;
-  if (specificType.getShape() != pattern.getShape())
-    return true;
-  return !isI1(specificType.getElementType());
-}
-
-// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern"
-// and contains i1.
-static bool checkI1SameShape(Type pattern, Type type) {
-  if (pattern.isIntOrIndexOrFloat())
-    return !isI1(type);
-  if (auto patternTensorType = pattern.dyn_cast<TensorType>())
-    return implCheckI1SameShape(patternTensorType, type);
-  if (auto patternVectorType = pattern.dyn_cast<VectorType>())
-    return implCheckI1SameShape(patternVectorType, type);
-
-  llvm_unreachable("unsupported type");
-}
-
 //===----------------------------------------------------------------------===//
 // CmpIOp
 //===----------------------------------------------------------------------===//
@@ -665,21 +637,21 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
       .Default(CmpIPredicate::NumPredicates);
 }
 
-void CmpIOp::build(Builder *build, OperationState *result,
-                   CmpIPredicate predicate, Value *lhs, Value *rhs) {
+static void buildCmpIOp(Builder *build, OperationState *result,
+                        CmpIPredicate predicate, Value *lhs, Value *rhs) {
   result->addOperands({lhs, rhs});
   result->types.push_back(getI1SameShape(build, lhs->getType()));
   result->addAttribute(
-      getPredicateAttrName(),
+      CmpIOp::getPredicateAttrName(),
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
 }
 
-ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
   SmallVector<NamedAttribute, 4> attrs;
   Attribute predicateNameAttr;
   Type type;
-  if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(),
+  if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
                              attrs) ||
       parser->parseComma() || parser->parseOperandList(ops, 2) ||
       parser->parseOptionalAttributeDict(attrs) ||
@@ -693,7 +665,7 @@ ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
 
   // Rewrite string attribute to an enum value.
   StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
-  auto predicate = getPredicateByName(predicateName);
+  auto predicate = CmpIOp::getPredicateByName(predicateName);
   if (predicate == CmpIPredicate::NumPredicates)
     return parser->emitError(parser->getNameLoc())
            << "unknown comparison predicate \"" << predicateName << "\"";
@@ -711,36 +683,37 @@ ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
   return success();
 }
 
-void CmpIOp::print(OpAsmPrinter *p) {
+static void print(OpAsmPrinter *p, CmpIOp op) {
   *p << "cmpi ";
 
   auto predicateValue =
-      getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
+      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
   assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
          predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
          "unknown predicate index");
-  Builder b(getContext());
+  Builder b(op.getContext());
   auto predicateStringAttr =
       b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
   p->printAttribute(predicateStringAttr);
 
   *p << ", ";
-  p->printOperand(getOperand(0));
+  p->printOperand(op.lhs());
   *p << ", ";
-  p->printOperand(getOperand(1));
-  p->printOptionalAttrDict(getAttrs(),
-                           /*elidedAttrs=*/{getPredicateAttrName()});
-  *p << " : " << getOperand(0)->getType();
+  p->printOperand(op.rhs());
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
+  *p << " : " << op.lhs()->getType();
 }
 
-LogicalResult CmpIOp::verify() {
-  auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
+static LogicalResult verify(CmpIOp op) {
+  auto predicateAttr =
+      op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
   if (!predicateAttr)
-    return emitOpError("requires an integer attribute named 'predicate'");
+    return op.emitOpError("requires an integer attribute named 'predicate'");
   auto predicate = predicateAttr.getInt();
   if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
       predicate >= (int64_t)CmpIPredicate::NumPredicates)
-    return emitOpError("'predicate' attribute value out of range");
+    return op.emitOpError("'predicate' attribute value out of range");
 
   return success();
 }
@@ -841,21 +814,21 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
       .Default(CmpFPredicate::NumPredicates);
 }
 
-void CmpFOp::build(Builder *build, OperationState *result,
-                   CmpFPredicate predicate, Value *lhs, Value *rhs) {
+static void buildCmpFOp(Builder *build, OperationState *result,
+                        CmpFPredicate predicate, Value *lhs, Value *rhs) {
   result->addOperands({lhs, rhs});
   result->types.push_back(getI1SameShape(build, lhs->getType()));
   result->addAttribute(
-      getPredicateAttrName(),
+      CmpFOp::getPredicateAttrName(),
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
 }
 
-ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
   SmallVector<NamedAttribute, 4> attrs;
   Attribute predicateNameAttr;
   Type type;
-  if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(),
+  if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
                              attrs) ||
       parser->parseComma() || parser->parseOperandList(ops, 2) ||
       parser->parseOptionalAttributeDict(attrs) ||
@@ -869,7 +842,7 @@ ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
 
   // Rewrite string attribute to an enum value.
   StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
-  auto predicate = getPredicateByName(predicateName);
+  auto predicate = CmpFOp::getPredicateByName(predicateName);
   if (predicate == CmpFPredicate::NumPredicates)
     return parser->emitError(parser->getNameLoc(),
                              "unknown comparison predicate \"" + predicateName +
@@ -888,36 +861,37 @@ ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
   return success();
 }
 
-void CmpFOp::print(OpAsmPrinter *p) {
+static void print(OpAsmPrinter *p, CmpFOp op) {
   *p << "cmpf ";
 
   auto predicateValue =
-      getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
+      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
   assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
          predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
          "unknown predicate index");
-  Builder b(getContext());
+  Builder b(op.getContext());
   auto predicateStringAttr =
       b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
   p->printAttribute(predicateStringAttr);
 
   *p << ", ";
-  p->printOperand(getOperand(0));
+  p->printOperand(op.lhs());
   *p << ", ";
-  p->printOperand(getOperand(1));
-  p->printOptionalAttrDict(getAttrs(),
-                           /*elidedAttrs=*/{getPredicateAttrName()});
-  *p << " : " << getOperand(0)->getType();
+  p->printOperand(op.rhs());
+  p->printOptionalAttrDict(op.getAttrs(),
+                           /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
+  *p << " : " << op.lhs()->getType();
 }
 
-LogicalResult CmpFOp::verify() {
-  auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
+static LogicalResult verify(CmpFOp op) {
+  auto predicateAttr =
+      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
   if (!predicateAttr)
-    return emitOpError("requires an integer attribute named 'predicate'");
+    return op.emitOpError("requires an integer attribute named 'predicate'");
   auto predicate = predicateAttr.getInt();
   if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
       predicate >= (int64_t)CmpFPredicate::NumPredicates)
-    return emitOpError("'predicate' attribute value out of range");
+    return op.emitOpError("'predicate' attribute value out of range");
 
   return success();
 }
@@ -1952,13 +1926,7 @@ static LogicalResult verify(ReturnOp op) {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
-                     Value *trueValue, Value *falseValue) {
-  result->addOperands({condition, trueValue, falseValue});
-  result->addTypes(trueValue->getType());
-}
-
-ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> ops;
   SmallVector<NamedAttribute, 4> attrs;
   Type type;
@@ -1979,26 +1947,21 @@ ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) {
                  parser->addTypeToList(type, result->types));
 }
 
-void SelectOp::print(OpAsmPrinter *p) {
+static void print(OpAsmPrinter *p, SelectOp op) {
   *p << "select ";
-  p->printOperands(getOperation()->getOperands());
-  *p << " : " << getTrueValue()->getType();
-  p->printOptionalAttrDict(getAttrs());
+  p->printOperands(op.getOperands());
+  *p << " : " << op.getTrueValue()->getType();
+  p->printOptionalAttrDict(op.getAttrs());
 }
 
-LogicalResult SelectOp::verify() {
-  auto conditionType = getCondition()->getType();
-  auto trueType = getTrueValue()->getType();
-  auto falseType = getFalseValue()->getType();
+static LogicalResult verify(SelectOp op) {
+  auto trueType = op.getTrueValue()->getType();
+  auto falseType = op.getFalseValue()->getType();
 
   if (trueType != falseType)
-    return emitOpError(
+    return op.emitOpError(
         "requires 'true' and 'false' arguments to be of the same type");
 
-  if (checkI1SameShape(trueType, conditionType))
-    return emitOpError("requires the condition to have the same shape as "
-                       "arguments with elemental type i1");
-
   return success();
 }
 
index a975808..80c3787 100644 (file)
@@ -218,7 +218,7 @@ func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = cmpi "eq", %a, %b : f32 // expected-error {{op requires an integer or index type}}
+  %r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}}
 }
 
 // -----
@@ -226,7 +226,7 @@ func @func_with_ops(f32, f32) {
 // Result type must be boolean like.
 func @func_with_ops(i32, i32) {
 ^bb0(%a : i32, %b : i32):
-  %r = "std.cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op requires a bool result type}}
+  %r = "std.cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op result #0 must be bool-like}}
 }
 
 // -----
@@ -259,7 +259,7 @@ func @func_with_ops(i32, i32, i32) {
 
 func @func_with_ops(i32, i32, i32) {
 ^bb0(%cond : i32, %t : i32, %f : i32):
-  // expected-error@+1 {{elemental type i1}}
+  // expected-error@+1 {{op operand #0 must be bool-like}}
   %r = "std.select"(%cond, %t, %f) : (i32, i32, i32) -> i32
 }
 
@@ -275,7 +275,7 @@ func @func_with_ops(i1, i32, i64) {
 
 func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  // expected-error@+1 {{requires the same shape for all operands and results}}
   %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -283,7 +283,7 @@ func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
 
 func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 ^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
-  // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+  // expected-error@+1 {{ op requires the same shape for all operands and results}}
   %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
 }
 
@@ -291,7 +291,7 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 
 func @func_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  // expected-error@+1 {{requires the same shape for all operands and results}}
   %r = "std.select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
@@ -566,13 +566,13 @@ func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) {
 // -----
 
 func @cmpf_wrong_type(%a : i32, %b : i32) {
-  %r = cmpf "oeq", %a, %b : i32 // expected-error {{op requires a float type}}
+  %r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}}
 }
 
 // -----
 
 func @cmpf_generic_wrong_result_type(%a : f32, %b : f32) {
-  // expected-error@+1 {{op requires a bool result type}}
+  // expected-error@+1 {{result #0 must be bool-like}}
   %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f32) -> f32
 }