From ca885b3c81015dd482a0c19445aeb3989004532d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 24 May 2019 18:01:20 -0700 Subject: [PATCH] Move the definitions of CmpIOp, CmpFOp, and SelectOp to the ODG framework. -- PiperOrigin-RevId: 249928953 --- mlir/include/mlir/IR/OpBase.td | 8 +++ mlir/include/mlir/StandardOps/Ops.h | 121 +------------------------------ mlir/include/mlir/StandardOps/Ops.td | 125 ++++++++++++++++++++++++++++++++ mlir/lib/IR/Operation.cpp | 2 + mlir/lib/StandardOps/Ops.cpp | 133 +++++++++++++---------------------- mlir/test/IR/invalid-ops.mlir | 16 ++--- 6 files changed, 193 insertions(+), 212 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 780d92f..99be357 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -459,6 +459,12 @@ class NestedTupleOf allowedTypes> : // Common type constraints //===----------------------------------------------------------------------===// +// Type constraint for bool-like types: bools, vectors of bools, tensors of +// bools. +def BoolLike : TypeConstraint.predicate, + TensorOf<[I1]>.predicate]>, + "bool-like">; + // Type constraint for integer-like types: integers, indices, vectors of // integers, tensors of integers. def IntegerLike : TypeConstraint; def FloatLikeResults : NativeOpTrait<"ResultsAreFloatLike">; // Op has no side effect. def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; +// Op has the same operand type. +def SameOperandType : NativeOpTrait<"SameTypeOperands">; // Op has same operand and result shape. def SameValueShape : NativeOpTrait<"SameOperandsAndResultShape">; // Op has the same operand and result type. diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 6db6fe0..18008f2 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -37,9 +37,6 @@ public: StandardOpsDialect(MLIRContext *context); }; -#define GET_OP_CLASSES -#include "mlir/StandardOps/Ops.h.inc" - /// The predicate indicates the type of the comparison to perform: /// (in)equality; (un)signed less/greater than (or equal to). enum class CmpIPredicate { @@ -61,49 +58,6 @@ enum class CmpIPredicate { NumPredicates }; -/// The "cmpi" operation compares its two operands according to the integer -/// comparison rules and the predicate specified by the respective attribute. -/// The predicate defines the type of comparison: (in)equality, (un)signed -/// less/greater than (or equal to). The operands must have the same type, and -/// this type must be an integer type, a vector or a tensor thereof. The result -/// is an i1, or a vector/tensor thereof having the same shape as the inputs. -/// Since integers are signless, the predicate also explicitly indicates -/// whether to interpret the operands as signed or unsigned integers for -/// less/greater than comparisons. For the sake of readability by humans, -/// custom assembly form for the operation uses a string-typed attribute for -/// the predicate. The value of this attribute corresponds to lower-cased name -/// of the predicate constant, e.g., "slt" means "signed less than". The string -/// representation of the attribute is merely a syntactic sugar and is converted -/// to an integer attribute by the parser. -/// -/// %r1 = cmpi "eq" %0, %1 : i32 -/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> -/// %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 -class CmpIOp - : public Op::Impl, - OpTrait::OneResult, OpTrait::ResultsAreBoolLike, - OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { -public: - using Op::Op; - - CmpIPredicate getPredicate() { - return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) - .getInt(); - } - - static StringRef getOperationName() { return "std.cmpi"; } - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpIPredicate getPredicateByName(StringRef name); - - static void build(Builder *builder, OperationState *result, CmpIPredicate, - Value *lhs, Value *rhs); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - OpFoldResult fold(ArrayRef operands); -}; - /// The predicate indicates the type of the comparison to perform: /// (un)orderedness, (in)equality and signed less/greater than (or equal to) as /// well as predicates that are always true or false. @@ -135,49 +89,8 @@ enum class CmpFPredicate { NumPredicates }; -/// The "cmpf" operation compares its two operands according to the float -/// comparison rules and the predicate specified by the respective attribute. -/// The predicate defines the type of comparison: (un)orderedness, (in)equality -/// and signed less/greater than (or equal to) as well as predicates that are -/// always true or false. The operands must have the same type, and this type -/// must be a float type, or a vector or tensor thereof. The result is an i1, -/// or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, -/// the operands are always treated as signed. The u prefix indicates -/// *unordered* comparison, not unsigned comparison, so "une" means unordered or -/// not equal. For the sake of readability by humans, custom assembly form for -/// the operation uses a string-typed attribute for the predicate. The value of -/// this attribute corresponds to lower-cased name of the predicate constant, -/// e.g., "one" means "ordered not equal". The string representation of the -/// attribute is merely a syntactic sugar and is converted to an integer -/// attribute by the parser. -/// -/// %r1 = cmpf "oeq" %0, %1 : f32 -/// %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> -/// %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 -class CmpFOp - : public Op::Impl, - OpTrait::OneResult, OpTrait::ResultsAreBoolLike, - OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { -public: - using Op::Op; - - CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) - .getInt(); - } - - static StringRef getOperationName() { return "std.cmpf"; } - static StringRef getPredicateAttrName() { return "predicate"; } - static CmpFPredicate getPredicateByName(StringRef name); - - static void build(Builder *builder, OperationState *result, CmpFPredicate, - Value *lhs, Value *rhs); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - OpFoldResult fold(ArrayRef operands); -}; +#define GET_OP_CLASSES +#include "mlir/StandardOps/Ops.h.inc" /// The "cond_br" operation represents a conditional branch operation in a /// function. The operation takes variable number of operands and produces @@ -573,36 +486,6 @@ public: MLIRContext *context); }; -/// The "select" operation chooses one value based on a binary condition -/// supplied as its first operand. If the value of the first operand is 1, the -/// second operand is chosen, otherwise the third operand is chosen. The second -/// and the third operand must have the same type. The operation applies -/// elementwise to vectors and tensors. The shape of all arguments must be -/// identical. For example, the maximum operation is obtained by combining -/// "select" with "cmpi" as follows. -/// -/// %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 -/// %3 = select %2, %0, %1 : i32 -/// -class SelectOp : public Op::Impl, - OpTrait::OneResult, OpTrait::HasNoSideEffect> { -public: - using Op::Op; - - static StringRef getOperationName() { return "std.select"; } - static void build(Builder *builder, OperationState *result, Value *condition, - Value *trueValue, Value *falseValue); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - LogicalResult verify(); - - Value *getCondition() { return getOperand(0); } - Value *getTrueValue() { return getOperand(1); } - Value *getFalseValue() { return getOperand(2); } - - OpFoldResult fold(ArrayRef operands); -}; - /// The "store" op writes an element to a memref specified by an index list. /// The arity of indices is the rank of the memref (i.e. if the memref being /// stored to is of rank 3, then 3 indices are required for the store following diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 058750f..817079a 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -278,6 +278,97 @@ def CallIndirectOp : Std_Op<"call_indirect"> { let hasCanonicalizer = 1; } +def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameOperandType, SameValueShape]> { + let summary = "integer comparison operation"; + let description = [{ + The "cmpi" operation compares its two operands according to the integer + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (in)equality, (un)signed + less/greater than (or equal to). The operands must have the same type, and + this type must be an integer type, a vector or a tensor thereof. The result + is an i1, or a vector/tensor thereof having the same shape as the inputs. + Since integers are signless, the predicate also explicitly indicates + whether to interpret the operands as signed or unsigned integers for + less/greater than comparisons. For the sake of readability by humans, + custom assembly form for the operation uses a string-typed attribute for + the predicate. The value of this attribute corresponds to lower-cased name + of the predicate constant, e.g., "slt" means "signed less than". The string + representation of the attribute is merely a syntactic sugar and is converted + to an integer attribute by the parser. + + %r1 = cmpi "eq" %0, %1 : i32 + %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64> + %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1 + }]; + + let arguments = (ins IntegerLike:$lhs, IntegerLike:$rhs); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, CmpIPredicate predicate," + "Value *lhs, Value *rhs", [{ + ::buildCmpIOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpIPredicate getPredicateByName(StringRef name); + + CmpIPredicate getPredicate() { + return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let hasFolder = 1; +} + +def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameOperandType, SameValueShape]> { + let summary = "floating-point comparison operation"; + let description = [{ + The "cmpf" operation compares its two operands according to the float + comparison rules and the predicate specified by the respective attribute. + The predicate defines the type of comparison: (un)orderedness, (in)equality + and signed less/greater than (or equal to) as well as predicates that are + always true or false. The operands must have the same type, and this type + must be a float type, or a vector or tensor thereof. The result is an i1, + or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, + the operands are always treated as signed. The u prefix indicates + *unordered* comparison, not unsigned comparison, so "une" means unordered or + not equal. For the sake of readability by humans, custom assembly form for + the operation uses a string-typed attribute for the predicate. The value of + this attribute corresponds to lower-cased name of the predicate constant, + e.g., "one" means "ordered not equal". The string representation of the + attribute is merely a syntactic sugar and is converted to an integer + attribute by the parser. + + %r1 = cmpf "oeq" %0, %1 : f32 + %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> + %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 + }]; + + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + let results = (outs BoolLike); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, CmpFPredicate predicate," + "Value *lhs, Value *rhs", [{ + ::buildCmpFOp(builder, result, predicate, lhs, rhs); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getPredicateAttrName() { return "predicate"; } + static CmpFPredicate getPredicateByName(StringRef name); + + CmpFPredicate getPredicate() { + return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) + .getInt(); + } + }]; + + let hasFolder = 1; +} + def ConstantOp : Std_Op<"constant", [NoSideEffect]> { let summary = "constant"; @@ -477,6 +568,40 @@ def ReturnOp : Std_Op<"return", [Terminator]> { >]; } +def SelectOp : Std_Op<"select", [NoSideEffect, SameValueShape]> { + let summary = "select operation"; + let description = [{ + The "select" operation chooses one value based on a binary condition + supplied as its first operand. If the value of the first operand is 1, the + second operand is chosen, otherwise the third operand is chosen. The second + and the third operand must have the same type. The operation applies + elementwise to vectors and tensors. The shape of all arguments must be + identical. For example, the maximum operation is obtained by combining + "select" with "cmpi" as follows. + + %2 = cmpi "gt" %0, %1 : i32 // %2 is i1 + %3 = select %2, %0, %1 : i32 + }]; + + let arguments = (ins BoolLike:$condition, AnyType:$true_value, + AnyType:$false_value); + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *condition," + "Value *trueValue, Value *falseValue", [{ + result->addOperands({condition, trueValue, falseValue}); + result->addTypes(trueValue->getType()); + }]>]; + + let extraClassDeclaration = [{ + Value *getCondition() { return condition(); } + Value *getTrueValue() { return true_value(); } + Value *getFalseValue() { return false_value(); } + }]; + + let hasFolder = 1; +} def ShlISOp : IntArithmeticOp<"shlis"> { let summary = "signed integer shift left"; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index f6fc3ee..5d50db1 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -782,6 +782,8 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { // Either both or neither type should be shaped. if (!sType1) return success(!sType2); + if (!sType2) + return failure(); if (!sType1.hasRank() || !sType2.hasRank()) return success(); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index fd9d57f..6b559b3 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -81,8 +81,7 @@ template static LogicalResult verifyCastOp(T op) { StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*name=*/"std", context) { - addOperations(); @@ -597,33 +596,6 @@ static Type getI1SameShape(Builder *build, Type type) { return res; } -static inline bool isI1(Type type) { - return type.isa() && type.cast().getWidth() == 1; -} - -template -static inline bool implCheckI1SameShape(Ty pattern, Type type) { - auto specificType = type.dyn_cast(); - if (!specificType) - return true; - if (specificType.getShape() != pattern.getShape()) - return true; - return !isI1(specificType.getElementType()); -} - -// Checks if "type" has the same shape (scalar, vector or tensor) as "pattern" -// and contains i1. -static bool checkI1SameShape(Type pattern, Type type) { - if (pattern.isIntOrIndexOrFloat()) - return !isI1(type); - if (auto patternTensorType = pattern.dyn_cast()) - return implCheckI1SameShape(patternTensorType, type); - if (auto patternVectorType = pattern.dyn_cast()) - return implCheckI1SameShape(patternVectorType, type); - - llvm_unreachable("unsupported type"); -} - //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// @@ -665,21 +637,21 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { .Default(CmpIPredicate::NumPredicates); } -void CmpIOp::build(Builder *build, OperationState *result, - CmpIPredicate predicate, Value *lhs, Value *rhs) { +static void buildCmpIOp(Builder *build, OperationState *result, + CmpIPredicate predicate, Value *lhs, Value *rhs) { result->addOperands({lhs, rhs}); result->types.push_back(getI1SameShape(build, lhs->getType())); result->addAttribute( - getPredicateAttrName(), + CmpIOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast(predicate))); } -ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; Type type; - if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(), + if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), attrs) || parser->parseComma() || parser->parseOperandList(ops, 2) || parser->parseOptionalAttributeDict(attrs) || @@ -693,7 +665,7 @@ ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) { // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = getPredicateByName(predicateName); + auto predicate = CmpIOp::getPredicateByName(predicateName); if (predicate == CmpIPredicate::NumPredicates) return parser->emitError(parser->getNameLoc()) << "unknown comparison predicate \"" << predicateName << "\""; @@ -711,36 +683,37 @@ ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) { return success(); } -void CmpIOp::print(OpAsmPrinter *p) { +static void print(OpAsmPrinter *p, CmpIOp op) { *p << "cmpi "; auto predicateValue = - getAttrOfType(getPredicateAttrName()).getInt(); + op.getAttrOfType(CmpIOp::getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && predicateValue < static_cast(CmpIPredicate::NumPredicates) && "unknown predicate index"); - Builder b(getContext()); + Builder b(op.getContext()); auto predicateStringAttr = b.getStringAttr(getCmpIPredicateNames()[predicateValue]); p->printAttribute(predicateStringAttr); *p << ", "; - p->printOperand(getOperand(0)); + p->printOperand(op.lhs()); *p << ", "; - p->printOperand(getOperand(1)); - p->printOptionalAttrDict(getAttrs(), - /*elidedAttrs=*/{getPredicateAttrName()}); - *p << " : " << getOperand(0)->getType(); + p->printOperand(op.rhs()); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); + *p << " : " << op.lhs()->getType(); } -LogicalResult CmpIOp::verify() { - auto predicateAttr = getAttrOfType(getPredicateAttrName()); +static LogicalResult verify(CmpIOp op) { + auto predicateAttr = + op.getAttrOfType(CmpIOp::getPredicateAttrName()); if (!predicateAttr) - return emitOpError("requires an integer attribute named 'predicate'"); + return op.emitOpError("requires an integer attribute named 'predicate'"); auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpIPredicate::FirstValidValue || predicate >= (int64_t)CmpIPredicate::NumPredicates) - return emitOpError("'predicate' attribute value out of range"); + return op.emitOpError("'predicate' attribute value out of range"); return success(); } @@ -841,21 +814,21 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { .Default(CmpFPredicate::NumPredicates); } -void CmpFOp::build(Builder *build, OperationState *result, - CmpFPredicate predicate, Value *lhs, Value *rhs) { +static void buildCmpFOp(Builder *build, OperationState *result, + CmpFPredicate predicate, Value *lhs, Value *rhs) { result->addOperands({lhs, rhs}); result->types.push_back(getI1SameShape(build, lhs->getType())); result->addAttribute( - getPredicateAttrName(), + CmpFOp::getPredicateAttrName(), build->getI64IntegerAttr(static_cast(predicate))); } -ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; Type type; - if (parser->parseAttribute(predicateNameAttr, getPredicateAttrName(), + if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), attrs) || parser->parseComma() || parser->parseOperandList(ops, 2) || parser->parseOptionalAttributeDict(attrs) || @@ -869,7 +842,7 @@ ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) { // Rewrite string attribute to an enum value. StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = getPredicateByName(predicateName); + auto predicate = CmpFOp::getPredicateByName(predicateName); if (predicate == CmpFPredicate::NumPredicates) return parser->emitError(parser->getNameLoc(), "unknown comparison predicate \"" + predicateName + @@ -888,36 +861,37 @@ ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) { return success(); } -void CmpFOp::print(OpAsmPrinter *p) { +static void print(OpAsmPrinter *p, CmpFOp op) { *p << "cmpf "; auto predicateValue = - getAttrOfType(getPredicateAttrName()).getInt(); + op.getAttrOfType(CmpFOp::getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast(CmpFPredicate::FirstValidValue) && predicateValue < static_cast(CmpFPredicate::NumPredicates) && "unknown predicate index"); - Builder b(getContext()); + Builder b(op.getContext()); auto predicateStringAttr = b.getStringAttr(getCmpFPredicateNames()[predicateValue]); p->printAttribute(predicateStringAttr); *p << ", "; - p->printOperand(getOperand(0)); + p->printOperand(op.lhs()); *p << ", "; - p->printOperand(getOperand(1)); - p->printOptionalAttrDict(getAttrs(), - /*elidedAttrs=*/{getPredicateAttrName()}); - *p << " : " << getOperand(0)->getType(); + p->printOperand(op.rhs()); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); + *p << " : " << op.lhs()->getType(); } -LogicalResult CmpFOp::verify() { - auto predicateAttr = getAttrOfType(getPredicateAttrName()); +static LogicalResult verify(CmpFOp op) { + auto predicateAttr = + op.getAttrOfType(CmpFOp::getPredicateAttrName()); if (!predicateAttr) - return emitOpError("requires an integer attribute named 'predicate'"); + return op.emitOpError("requires an integer attribute named 'predicate'"); auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpFPredicate::FirstValidValue || predicate >= (int64_t)CmpFPredicate::NumPredicates) - return emitOpError("'predicate' attribute value out of range"); + return op.emitOpError("'predicate' attribute value out of range"); return success(); } @@ -1952,13 +1926,7 @@ static LogicalResult verify(ReturnOp op) { // SelectOp //===----------------------------------------------------------------------===// -void SelectOp::build(Builder *builder, OperationState *result, Value *condition, - Value *trueValue, Value *falseValue) { - result->addOperands({condition, trueValue, falseValue}); - result->addTypes(trueValue->getType()); -} - -ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Type type; @@ -1979,26 +1947,21 @@ ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(type, result->types)); } -void SelectOp::print(OpAsmPrinter *p) { +static void print(OpAsmPrinter *p, SelectOp op) { *p << "select "; - p->printOperands(getOperation()->getOperands()); - *p << " : " << getTrueValue()->getType(); - p->printOptionalAttrDict(getAttrs()); + p->printOperands(op.getOperands()); + *p << " : " << op.getTrueValue()->getType(); + p->printOptionalAttrDict(op.getAttrs()); } -LogicalResult SelectOp::verify() { - auto conditionType = getCondition()->getType(); - auto trueType = getTrueValue()->getType(); - auto falseType = getFalseValue()->getType(); +static LogicalResult verify(SelectOp op) { + auto trueType = op.getTrueValue()->getType(); + auto falseType = op.getFalseValue()->getType(); if (trueType != falseType) - return emitOpError( + return op.emitOpError( "requires 'true' and 'false' arguments to be of the same type"); - if (checkI1SameShape(trueType, conditionType)) - return emitOpError("requires the condition to have the same shape as " - "arguments with elemental type i1"); - return success(); } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index a975808..80c3787 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -218,7 +218,7 @@ func @func_with_ops(i32, i32) { // Integer comparisons are not recognized for float types. func @func_with_ops(f32, f32) { ^bb0(%a : f32, %b : f32): - %r = cmpi "eq", %a, %b : f32 // expected-error {{op requires an integer or index type}} + %r = cmpi "eq", %a, %b : f32 // expected-error {{operand #0 must be integer-like}} } // ----- @@ -226,7 +226,7 @@ func @func_with_ops(f32, f32) { // Result type must be boolean like. func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): - %r = "std.cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op requires a bool result type}} + %r = "std.cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op result #0 must be bool-like}} } // ----- @@ -259,7 +259,7 @@ func @func_with_ops(i32, i32, i32) { func @func_with_ops(i32, i32, i32) { ^bb0(%cond : i32, %t : i32, %f : i32): - // expected-error@+1 {{elemental type i1}} + // expected-error@+1 {{op operand #0 must be bool-like}} %r = "std.select"(%cond, %t, %f) : (i32, i32, i32) -> i32 } @@ -275,7 +275,7 @@ func @func_with_ops(i1, i32, i64) { func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) { ^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>): - // expected-error@+1 {{requires the condition to have the same shape as arguments}} + // expected-error@+1 {{requires the same shape for all operands and results}} %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } @@ -283,7 +283,7 @@ func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) { func @func_with_ops(i1, tensor<42xi32>, tensor) { ^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor): - // expected-error@+1 {{'true' and 'false' arguments to be of the same type}} + // expected-error@+1 {{ op requires the same shape for all operands and results}} %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor) -> tensor<42xi32> } @@ -291,7 +291,7 @@ func @func_with_ops(i1, tensor<42xi32>, tensor) { func @func_with_ops(tensor, tensor<42xi32>, tensor<42xi32>) { ^bb0(%cond : tensor, %t : tensor<42xi32>, %f : tensor<42xi32>): - // expected-error@+1 {{requires the condition to have the same shape as arguments}} + // expected-error@+1 {{requires the same shape for all operands and results}} %r = "std.select"(%cond, %t, %f) : (tensor, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } @@ -566,13 +566,13 @@ func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) { // ----- func @cmpf_wrong_type(%a : i32, %b : i32) { - %r = cmpf "oeq", %a, %b : i32 // expected-error {{op requires a float type}} + %r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}} } // ----- func @cmpf_generic_wrong_result_type(%a : f32, %b : f32) { - // expected-error@+1 {{op requires a bool result type}} + // expected-error@+1 {{result #0 must be bool-like}} %r = "std.cmpf"(%a, %b) {predicate: 0} : (f32, f32) -> f32 } -- 2.7.4