From cd73af92315ecf25ed47f4991806a054ddfca5ea Mon Sep 17 00:00:00 2001 From: Kiran Chandramohan Date: Tue, 8 Jun 2021 16:48:57 +0100 Subject: [PATCH] [MLIR] Remove LLVM_AnyInteger type constraint LLVM Dialect uses builtin-integer types. The existing LLVM_AnyInteger type constraint is a dupe of AnyInteger. This patch removes LLVM_AnyInteger and replaces all usage with AnyInteger. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D103839 --- mlir/include/mlir/Dialect/AMX/AMX.td | 40 ++++++++++++------------ mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 5 --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 42 +++++++++++++------------- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 5 ++- mlir/test/Dialect/LLVMIR/invalid.mlir | 4 +-- 5 files changed, 45 insertions(+), 51 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 24052ed..85611af 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -239,7 +239,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] // def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>, - Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>; + Arguments<(ins AnyInteger, AnyInteger)>; // // Tile memory operations. Parameters define the tile size, @@ -248,12 +248,12 @@ def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>, // def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>; + Arguments<(ins AnyInteger, + AnyInteger, LLVM_AnyPointer, AnyInteger)>; def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>; // // Tile multiplication operations (series of dot products). Parameters @@ -263,32 +263,32 @@ def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>, // Dot product of bf16 tiles into f32 tile. def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, + AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; // Dot product of i8 tiles into i32 tile (with sign/sign extension). def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, + AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; // Dot product of i8 tiles into i32 tile (with sign/zero extension). def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, + AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; // Dot product of i8 tiles into i32 tile (with zero/sign extension). def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, + AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; // Dot product of i8 tiles into i32 tile (with zero/zero extension). def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>, - Arguments<(ins LLVM_AnyInteger, - LLVM_AnyInteger, - LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + Arguments<(ins AnyInteger, + AnyInteger, + AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; #endif // AMX diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 8c83dbc..716260f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -62,11 +62,6 @@ def LLVM_TokenType : Type< "LLVM token type">, BuildableType<"::mlir::LLVM::LLVMTokenType::get($_builder.getContext())">; -// Type constraint accepting LLVM integer types. -def LLVM_AnyInteger : Type< - CPred<"$_self.isa<::mlir::IntegerType>()">, - "LLVM integer type">; - // Type constraint accepting LLVM primitive types, i.e. all types except void // and function. def LLVM_PrimitiveType : Type< diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 64271d6..e1a32e6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -129,7 +129,7 @@ class LLVM_ArithmeticOpBase traits = []> : - LLVM_ArithmeticOpBase { + LLVM_ArithmeticOpBase { let arguments = commonArgs; } class LLVM_FloatArithmeticOp { let arguments = (ins ICmpPredicate:$predicate, - AnyTypeOf<[LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf]>:$lhs, - AnyTypeOf<[LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf]>:$rhs); + AnyTypeOf<[LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf]>:$lhs, + AnyTypeOf<[LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf]>:$rhs); let results = (outs LLVM_ScalarOrVectorOf:$res); let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); @@ -290,7 +290,7 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase { // Memory-related operations. def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase { - let arguments = (ins LLVM_AnyInteger:$arraySize, + let arguments = (ins AnyInteger:$arraySize, OptionalAttr:$alignment); let results = (outs LLVM_AnyPointer:$res); string llvmBuilder = [{ @@ -318,7 +318,7 @@ def LLVM_GEPOp "$res = builder.CreateGEP(" " $base->getType()->getPointerElementType(), $base, $indices);"> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, - Variadic>:$indices); + Variadic>:$indices); let results = (outs LLVM_ScalarOrVectorOf:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = [{ @@ -389,32 +389,32 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", - LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf>; def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt", - LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt", - LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc", - LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf>; def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP", - LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP", - LLVM_ScalarOrVectorOf, + LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI", LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf>; def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI", LLVM_ScalarOrVectorOf, - LLVM_ScalarOrVectorOf>; + LLVM_ScalarOrVectorOf>; def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; @@ -514,7 +514,7 @@ def LLVM_CallOp : LLVM_Op<"call", let printer = [{ printCallOp(p, *this); }]; } def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { - let arguments = (ins LLVM_AnyVector:$vector, LLVM_AnyInteger:$position); + let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); let results = (outs LLVM_Type:$res); string llvmBuilder = [{ $res = builder.CreateExtractElement($vector, $position); @@ -537,7 +537,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { } def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value, - LLVM_AnyInteger:$position); + AnyInteger:$position); let results = (outs LLVM_AnyVector:$res); string llvmBuilder = [{ $res = builder.CreateInsertElement($vector, $value, $position); @@ -1616,7 +1616,7 @@ def AtomicOrdering : I64EnumAttr< let cppNamespace = "::mlir::LLVM"; } -def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>; +def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, AnyInteger]>; // FIXME: Need to add alignment attribute to MLIR atomicrmw operation. def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> { @@ -1634,7 +1634,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> { let verifier = "return ::verify(*this);"; } -def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>; +def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>; def LLVM_AtomicCmpXchgResultType : Type().getBody().size() == 2">, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 6c1f5c0..087d10d 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -28,9 +28,8 @@ def OpenMP_Dialect : Dialect { class OpenMP_Op traits = []> : Op; -// Type which can be constraint accepting standard integers, indices and -// LLVM integer types. -def IntLikeType : AnyTypeOf<[AnyInteger, Index, LLVM_AnyInteger]>; +// Type which can be constraint accepting standard integers and indices. +def IntLikeType : AnyTypeOf<[AnyInteger, Index]>; //===----------------------------------------------------------------------===// // 2.6 parallel Construct diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index d01a195..a28218b 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -539,7 +539,7 @@ func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>, // ----- func @atomicrmw_expected_ptr(%f32 : f32) { - // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}} + // expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or integer}} %0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (f32, f32) -> f32 llvm.return } @@ -587,7 +587,7 @@ func @atomicrmw_expected_int(%f32_ptr : !llvm.ptr, %f32 : f32) { // ----- func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr, %f32 : f32) { - // expected-error@+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}} + // expected-error@+1 {{op operand #0 must be LLVM pointer to integer or LLVM pointer type}} %0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (f32, f32, f32) -> !llvm.struct<(f32, i1)> llvm.return } -- 2.7.4