From aec9e20a3e9a4f25a5b1e07816c95f970300d918 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 4 Sep 2020 10:00:09 +0200 Subject: [PATCH] [mlir] introduce type constraints for operands of LLVM dialect operations Historically, the operations in the MLIR's LLVM dialect only checked that the operand are of LLVM dialect type without more detailed constraints. This was due to LLVM dialect types wrapping LLVM IR types and having clunky verification methods. With the new first-class modeling, it is possible to define type constraints similarly to other dialects and use them to enforce some correctness rules in verifiers instead of having LLVM assert during translation to LLVM IR. This hardening discovered several issues where MLIR was producing LLVM dialect operations that cannot exist in LLVM IR. Depends On D85900 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D85901 --- mlir/include/mlir/Dialect/GPU/GPUOps.td | 3 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 126 ++++++++++++++---- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 171 +++++++++++++++---------- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2 - mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 3 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 4 +- 6 files changed, 215 insertions(+), 94 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 288031c..0ae6267 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -21,7 +21,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Type constraint accepting standard integers, indices and wrapped LLVM integer // types. def IntLikeOrLLVMInt : TypeConstraint< - Or<[AnySignlessInteger.predicate, Index.predicate, LLVMInt.predicate]>, + Or<[AnySignlessInteger.predicate, Index.predicate, + LLVM_AnyInteger.predicate]>, "integer, index or LLVM dialect equivalent">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 1f0eb6a..10755a4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -17,6 +17,10 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +//===----------------------------------------------------------------------===// +// LLVM Dialect. +//===----------------------------------------------------------------------===// + def LLVM_Dialect : Dialect { let name = "llvm"; let cppNamespace = "LLVM"; @@ -38,34 +42,108 @@ def LLVM_Dialect : Dialect { }]; } -// LLVM IR type wrapped in MLIR. +//===----------------------------------------------------------------------===// +// LLVM dialect type constraints. +//===----------------------------------------------------------------------===// + +// LLVM dialect type. def LLVM_Type : DialectType()">, "LLVM dialect type">; -// Type constraint accepting only wrapped LLVM integer types. -def LLVMInt : TypeConstraint< - And<[LLVM_Type.predicate, - CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, - "LLVM dialect integer">; +// Type constraint accepting LLVM integer types. +def LLVM_AnyInteger : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMIntegerType>()">, + "LLVM integer type">; + +// Type constraints accepting LLVM integer type of a specific width. +class LLVM_IntBase : + Type().getBitWidth() == " + # width>]>, + "LLVM " # width # "-bit integer type">, + BuildableType< + "::mlir::LLVM::LLVMIntegerType::get($_builder.getContext(), " + # width # ")">; + +def LLVM_i1 : LLVM_IntBase<1>; +def LLVM_i8 : LLVM_IntBase<8>; +def LLVM_i32 : LLVM_IntBase<32>; -def LLVMIntBase : TypeConstraint< +// Type constraint accepting LLVM primitive types, i.e. all types except void +// and function. +def LLVM_PrimitiveType : Type< And<[LLVM_Type.predicate, - CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>, - "LLVM dialect integer">; - -// Integer type of a specific width. -class LLVMI - : Type().isIntegerTy(" # width # ")">]>, - "LLVM dialect " # width # "-bit integer">, - BuildableType< - "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext()," - # width # ")">; - -def LLVMI1 : LLVMI<1>; + CPred<"!$_self.isa<::mlir::LLVM::LLVMVoidType, " + "::mlir::LLVM::LLVMFunctionType>()">]>, + "primitive LLVM type">; + +// Type constraint accepting any LLVM floating point type. +def LLVM_AnyFloat : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMBFloatType, " + "::mlir::LLVM::LLVMHalfType, " + "::mlir::LLVM::LLVMFloatType, " + "::mlir::LLVM::LLVMDoubleType>()">, + "floating point LLVM type">; + +// Type constraint accepting any LLVM pointer type. +def LLVM_AnyPointer : Type()">, + "LLVM pointer type">; + +// Type constraint accepting LLVM pointer type with an additional constraint +// on the element type. +class LLVM_PointerTo : Type< + And<[LLVM_AnyPointer.predicate, + SubstLeaves< + "$_self", + "$_self.cast<::mlir::LLVM::LLVMPointerType>().getElementType()", + pointee.predicate>]>, + "LLVM pointer to " # pointee.description>; + +// Type constraint accepting any LLVM structure type. +def LLVM_AnyStruct : Type()">, + "LLVM structure type">; + +// Type constraint accepting opaque LLVM structure type. +def LLVM_OpaqueStruct : Type< + And<[LLVM_AnyStruct.predicate, + CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>; + +// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a +// type that has size (not void, function or opaque struct type). +def LLVM_LoadableType : Type< + And<[LLVM_PrimitiveType.predicate, Neg]>, + "LLVM type with size">; + +// Type constraint accepting any LLVM aggregate type, i.e. structure or array. +def LLVM_AnyAggregate : Type< + CPred<"$_self.isa<::mlir::LLVM::LLVMStructType, " + "::mlir::LLVM::LLVMArrayType>()">, + "LLVM aggregate type">; + +// Type constraint accepting any LLVM non-aggregate type, i.e. not structure or +// array. +def LLVM_AnyNonAggregate : Type, + "LLVM non-aggregate type">; + +// Type constraint accepting any LLVM vector type. +def LLVM_AnyVector : Type()">, + "LLVM vector type">; + +// Type constraint accepting an LLVM vector type with an additional constraint +// on the vector element type. +class LLVM_VectorOf : Type< + And<[LLVM_AnyVector.predicate, + SubstLeaves< + "$_self", + "$_self.cast<::mlir::LLVM::LLVMVectorType>().getElementType()", + element.predicate>]>, + "LLVM vector of " # element.description>; + +// Type constraint accepting a constrained type, or a vector of such types. +class LLVM_ScalarOrVectorOf : + AnyTypeOf<[element, LLVM_VectorOf]>; // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder // used to translate to LLVM IR proper. @@ -85,6 +163,10 @@ class LLVM_OpBase traits = []> : string llvmBuilder = ""; } +//===----------------------------------------------------------------------===// +// Base classes for LLVM dialect operations. +//===----------------------------------------------------------------------===// + // Base class for LLVM operations. All operations get an "llvm." prefix in // their name automatically. LLVM operations have either zero or one result, // this class is specialized below for both cases and should not be used diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b1dd7b1..b5bf4ac 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -87,39 +87,50 @@ class LLVM_TerminatorOp traits = []> : LLVM_Op; // Class for arithmetic binary operations. -class LLVM_ArithmeticOp traits = []> : +class LLVM_ArithmeticOpBase traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$lhs, LLVM_Type:$rhs)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)>, LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let parser = + [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; } -class LLVM_UnaryArithmeticOp traits = []> : +class LLVM_IntArithmeticOp traits = []> : + LLVM_ArithmeticOpBase; +class LLVM_FloatArithmeticOp traits = []> : + LLVM_ArithmeticOpBase; + +// Class for arithmetic unary operations. +class LLVM_UnaryArithmeticOp traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$operand)>, + Arguments<(ins type:$operand)>, LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let parser = + [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }]; } // Integer binary operations. -def LLVM_AddOp : LLVM_ArithmeticOp<"add", "CreateAdd", [Commutative]>; -def LLVM_SubOp : LLVM_ArithmeticOp<"sub", "CreateSub">; -def LLVM_MulOp : LLVM_ArithmeticOp<"mul", "CreateMul", [Commutative]>; -def LLVM_UDivOp : LLVM_ArithmeticOp<"udiv", "CreateUDiv">; -def LLVM_SDivOp : LLVM_ArithmeticOp<"sdiv", "CreateSDiv">; -def LLVM_URemOp : LLVM_ArithmeticOp<"urem", "CreateURem">; -def LLVM_SRemOp : LLVM_ArithmeticOp<"srem", "CreateSRem">; -def LLVM_AndOp : LLVM_ArithmeticOp<"and", "CreateAnd">; -def LLVM_OrOp : LLVM_ArithmeticOp<"or", "CreateOr">; -def LLVM_XOrOp : LLVM_ArithmeticOp<"xor", "CreateXor">; -def LLVM_ShlOp : LLVM_ArithmeticOp<"shl", "CreateShl">; -def LLVM_LShrOp : LLVM_ArithmeticOp<"lshr", "CreateLShr">; -def LLVM_AShrOp : LLVM_ArithmeticOp<"ashr", "CreateAShr">; +def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "CreateAdd", [Commutative]>; +def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "CreateSub">; +def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "CreateMul", [Commutative]>; +def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "CreateUDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "CreateSDiv">; +def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "CreateURem">; +def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "CreateSRem">; +def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "CreateAnd">; +def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "CreateOr">; +def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "CreateXor">; +def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "CreateShl">; +def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "CreateLShr">; +def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "CreateAShr">; // Predicate for integer comparisons. def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; @@ -143,8 +154,9 @@ def ICmpPredicate : I64EnumAttr< // Other integer operations. def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, - Arguments<(ins ICmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { + Arguments<(ins ICmpPredicate:$predicate, + LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)> { let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -189,8 +201,9 @@ def FCmpPredicate : I64EnumAttr< // Other integer operations. def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, - Arguments<(ins FCmpPredicate:$predicate, LLVM_Type:$lhs, - LLVM_Type:$rhs)> { + Arguments<(ins FCmpPredicate:$predicate, + LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs)> { let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; @@ -205,12 +218,13 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>, } // Floating point binary operations. -def LLVM_FAddOp : LLVM_ArithmeticOp<"fadd", "CreateFAdd">; -def LLVM_FSubOp : LLVM_ArithmeticOp<"fsub", "CreateFSub">; -def LLVM_FMulOp : LLVM_ArithmeticOp<"fmul", "CreateFMul">; -def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">; -def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; -def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">; +def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd">; +def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">; +def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">; +def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">; +def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">; +def LLVM_FNegOp : LLVM_UnaryArithmeticOp, + "fneg", "CreateFNeg">; // Common code definition that is used to verify and set the alignment attribute // of LLVM ops that accept such an attribute. @@ -241,7 +255,8 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase { def LLVM_AllocaOp : MemoryOpWithAlignmentBase, LLVM_OneResultOp<"alloca">, - Arguments<(ins LLVM_Type:$arraySize, OptionalAttr:$alignment)> { + Arguments<(ins LLVM_AnyInteger:$arraySize, + OptionalAttr:$alignment)> { string llvmBuilder = [{ auto *inst = builder.CreateAlloca( $_resultType->getPointerElementType(), $arraySize); @@ -259,8 +274,11 @@ def LLVM_AllocaOp : let parser = [{ return parseAllocaOp(parser, result); }]; let printer = [{ printAllocaOp(p, *this); }]; } + def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$base, Variadic:$indices)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$base, + Variadic>:$indices)>, LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> { let assemblyFormat = [{ $base `[` $indices `]` attr-dict `:` functional-type(operands, results) @@ -269,7 +287,7 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, def LLVM_LoadOp : MemoryOpWithAlignmentAndAttributes, LLVM_OneResultOp<"load">, - Arguments<(ins LLVM_Type:$addr, + Arguments<(ins LLVM_PointerTo:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal)> { @@ -296,8 +314,8 @@ def LLVM_LoadOp : def LLVM_StoreOp : MemoryOpWithAlignmentAndAttributes, LLVM_ZeroResultOp<"store">, - Arguments<(ins LLVM_Type:$value, - LLVM_Type:$addr, + Arguments<(ins LLVM_LoadableType:$value, + LLVM_PointerTo:$addr, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal)> { @@ -314,28 +332,41 @@ def LLVM_StoreOp : } // Casts. -class LLVM_CastOp traits = []> : LLVM_OneResultOp, - Arguments<(ins LLVM_Type:$arg)>, + Arguments<(ins type:$arg)>, LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> { let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; } -def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast">; -def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast">; -def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr">; -def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt">; -def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt">; -def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt">; -def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">; -def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP">; -def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP">; -def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI">; -def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI">; -def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">; -def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; +def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", + LLVM_AnyNonAggregate>; +def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", + LLVM_ScalarOrVectorOf>; +def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", + LLVM_ScalarOrVectorOf>; +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", + LLVM_ScalarOrVectorOf>; +def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt", + LLVM_ScalarOrVectorOf>; +def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt", + LLVM_ScalarOrVectorOf>; +def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc", + LLVM_ScalarOrVectorOf>; +def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP", + LLVM_ScalarOrVectorOf>; +def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP", + LLVM_ScalarOrVectorOf>; +def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI", + LLVM_ScalarOrVectorOf>; +def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI", + LLVM_ScalarOrVectorOf>; +def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt", + LLVM_ScalarOrVectorOf>; +def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc", + LLVM_ScalarOrVectorOf>; // Call-related operations. def LLVM_InvokeOp : LLVM_Op<"invoke", [ @@ -404,8 +435,8 @@ def LLVM_CallOp : LLVM_Op<"call">, let printer = [{ printCallOp(p, *this); }]; } def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, - LLVM_Type:$position)> { + Arguments<(ins LLVM_AnyVector:$vector, + LLVM_AnyInteger:$position)> { string llvmBuilder = [{ $res = builder.CreateExtractElement($vector, $position); }]; @@ -416,8 +447,8 @@ def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, let printer = [{ printExtractElementOp(p, *this); }]; } def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, - ArrayAttr:$position)> { + Arguments<(ins LLVM_AnyAggregate:$container, + ArrayAttr:$position)> { string llvmBuilder = [{ $res = builder.CreateExtractValue($container, extractPosition($position)); }]; @@ -425,8 +456,9 @@ def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, let printer = [{ printExtractValueOp(p, *this); }]; } def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, - LLVM_Type:$position)> { + Arguments<(ins LLVM_AnyVector:$vector, + LLVM_PrimitiveType:$value, + LLVM_AnyInteger:$position)> { string llvmBuilder = [{ $res = builder.CreateInsertElement($vector, $value, $position); }]; @@ -434,8 +466,9 @@ def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, let printer = [{ printInsertElementOp(p, *this); }]; } def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, - ArrayAttr:$position)> { + Arguments<(ins LLVM_AnyAggregate:$container, + LLVM_PrimitiveType:$value, + ArrayAttr:$position)> { string llvmBuilder = [{ $res = builder.CreateInsertValue($container, $value, extractPosition($position)); @@ -451,7 +484,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, } def LLVM_ShuffleVectorOp : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, - Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, ArrayAttr:$mask)> { + Arguments<(ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask)> { string llvmBuilder = [{ SmallVector position = extractPosition($mask); SmallVector mask(position.begin(), position.end()); @@ -478,8 +511,9 @@ def LLVM_ShuffleVectorOp def LLVM_SelectOp : LLVM_OneResultOp<"select", [NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>, - Arguments<(ins LLVM_Type:$condition, LLVM_Type:$trueValue, - LLVM_Type:$falseValue)>, + Arguments<(ins LLVM_ScalarOrVectorOf:$condition, + LLVM_Type:$trueValue, + LLVM_Type:$falseValue)>, LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let builders = [OpBuilder< @@ -508,7 +542,7 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br", def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { - let arguments = (ins LLVMI1:$condition, + let arguments = (ins LLVM_i1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, OptionalAttr:$branch_weights); @@ -1090,9 +1124,11 @@ def AtomicOrdering : I64EnumAttr< let cppNamespace = "::mlir::LLVM"; } +def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>; + def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">, - Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val, - AtomicOrdering:$ordering)>, + Arguments<(ins AtomicBinOp:$bin_op, LLVM_PointerTo:$ptr, + LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering)>, Results<(outs LLVM_Type:$res)> { let llvmBuilder = [{ $res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val, @@ -1103,8 +1139,11 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">, let verifier = "return ::verify(*this);"; } +def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>; + def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">, - Arguments<(ins LLVM_Type:$ptr, LLVM_Type:$cmp, LLVM_Type:$val, + Arguments<(ins LLVM_PointerTo:$ptr, + LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val, AtomicOrdering:$success_ordering, AtomicOrdering:$failure_ordering)>, Results<(outs LLVM_Type:$res)> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 96d8459..63bd10c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1533,8 +1533,6 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser, static LogicalResult verify(AtomicRMWOp op) { auto ptrType = op.ptr().getType().cast(); - if (!ptrType.isPointerTy()) - return op.emitOpError("expected LLVM IR pointer type for operand #0"); auto valType = op.val().getType().cast(); if (valType != ptrType.getPointerElementTy()) return op.emitOpError("expected LLVM IR element type for operand #0 to " diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index e27650b..a89287b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -440,7 +440,8 @@ LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { bool LLVMStructType::isPacked() { return getImpl()->isPacked(); } bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); } bool LLVMStructType::isOpaque() { - return getImpl()->isOpaque() || !getImpl()->isInitialized(); + return getImpl()->isIdentified() && + (getImpl()->isOpaque() || !getImpl()->isInitialized()); } bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 1f8b160..c19795e 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -394,7 +394,7 @@ func @nvvm_invalid_mma_7(%a0 : !llvm.vec<2 x half>, %a1 : !llvm.vec<2 x half>, // CHECK-LABEL: @atomicrmw_expected_ptr func @atomicrmw_expected_ptr(%f32 : !llvm.float) { - // expected-error@+1 {{expected LLVM IR pointer type for operand #0}} + // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}} %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (!llvm.float, !llvm.float) -> !llvm.float llvm.return } @@ -448,7 +448,7 @@ func @atomicrmw_expected_int(%f32_ptr : !llvm.ptr, %f32 : !llvm.float) { // CHECK-LABEL: @cmpxchg_expected_ptr func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr, %f32 : !llvm.float) { - // expected-error@+1 {{expected LLVM IR pointer type for operand #0}} + // expected-error@+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}} %0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.struct<(float, i1)> llvm.return } -- 2.7.4