From: Jeremy Furtek Date: Mon, 17 Oct 2022 21:02:09 +0000 (-0700) Subject: [mlir][LLVMIR] Update LLVMIR fastmath to use EnumAttr tblgen classes X-Git-Tag: upstream/17.0.6~30323 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dd38f899803465dd2765d1601b3989df3bd53863;p=platform%2Fupstream%2Fllvm.git [mlir][LLVMIR] Update LLVMIR fastmath to use EnumAttr tblgen classes This diff updates the `fastmath` attribute in the LLVMIR dialect to use `tblgen` classes that were developed after the initial LLVMIR `fastmath` implementation. Using the `EnumAttr` `tblgen` classes brings the LLVMIR `fastmath` attribute in line with other dialects, and eliminates some of the custom printing and parsing code in the LLVMIR dialect. Subsequent commits will further reduce the custom processing code for the LLVMIR `fastmath` attribute by unifying printing/parsing functionality between the LLVMIR and `arith` `fastmath` attributes. (The actual attributes will remain separate, but the printing and parsing will be made generic, and will be usable by other dialects/attributes.) Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D135289 --- diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index abdb0b7..56b8e2d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -1,10 +1,5 @@ add_subdirectory(Transforms) -set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td) -mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRLLVMAttrsIncGen) - set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMOps.h.inc -gen-op-decls) mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs) @@ -12,6 +7,10 @@ mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls) mlir_tablegen(LLVMOpsDialect.cpp.inc -gen-dialect-defs) mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls + -attrdefs-dialect=llvm) +mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=llvm) add_public_tablegen_target(MLIRLLVMOpsIncGen) set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 2de2f7a..95d4a90 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -15,17 +15,6 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" // All of the attributes will extend this class. class LLVM_Attr : AttrDef; -// The "FastMath" flags associated with floating point LLVM instructions. -def FastmathFlagsAttr : LLVM_Attr<"FMF"> { - let mnemonic = "fastmath"; - - // List of type parameters. - let parameters = (ins - "FastmathFlags":$flags - ); - let hasCustomAssemblyFormat = 1; -} - // Attribute definition for the LLVM Linkage enum. def LinkageAttr : LLVM_Attr<"Linkage"> { let mnemonic = "linkage"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b25d151..b4429fc6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -13,7 +13,9 @@ #ifndef LLVMIR_OPS #define LLVMIR_OPS +include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/EnumAttr.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" @@ -21,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def FMFnone : I32BitEnumAttrCaseNone<"none">; def FMFnnan : I32BitEnumAttrCaseBit<"nnan", 0>; def FMFninf : I32BitEnumAttrCaseBit<"ninf", 1>; def FMFnsz : I32BitEnumAttrCaseBit<"nsz", 2>; @@ -34,22 +37,18 @@ def FMFfast : I32BitEnumAttrCaseGroup<"fast", def FastmathFlags : I32BitEnumAttr< "FastmathFlags", "LLVM fastmath flags", - [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast - ]> { + [FMFnone, FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, + FMFreassoc, FMFfast]> { let separator = ", "; let cppNamespace = "::mlir::LLVM"; + let genSpecializedAttr = 0; let printBitEnumPrimaryGroups = 1; } -def LLVM_FMFAttr : DialectAttr< - LLVM_Dialect, - CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">, - "LLVM fastmath flags"> { - let storageType = "::mlir::LLVM::FMFAttr"; - let returnType = "::mlir::LLVM::FastmathFlags"; - let convertFromStorage = "$_self.getFlags()"; - let constBuilderCall = - "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)"; +// The "FastMath" flags associated with floating point LLVM instructions. +def LLVM_FastmathFlagsAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; } def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>; @@ -229,7 +228,8 @@ class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase], traits)> { - dag fmfArg = (ins DefaultValuedAttr:$fastmathFlags); + dag fmfArg = ( + ins DefaultValuedAttr:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); } @@ -239,7 +239,9 @@ class LLVM_UnaryFloatArithmeticOp], traits)>, LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> { - let arguments = (ins type:$operand, DefaultValuedAttr:$fastmathFlags); + let arguments = ( + ins type:$operand, + DefaultValuedAttr:$fastmathFlags); let results = (outs type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$operand custom(attr-dict) `:` type($res)"; @@ -354,7 +356,8 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [ let arguments = (ins FCmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs, - DefaultValuedAttr:$fastmathFlags); + DefaultValuedAttr:$fastmathFlags); let builders = [ OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)> ]; @@ -747,7 +750,8 @@ def LLVM_CallOp : LLVM_Op<"call", let arguments = (ins OptionalAttr:$callee, Variadic, - DefaultValuedAttr:$fastmathFlags); + DefaultValuedAttr:$fastmathFlags); let results = (outs Optional:$result); let builders = [ diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index a07c434..f37d47d 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -72,7 +72,7 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern { Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value sqNorm = rewriter.create( loc, rewriter.create(loc, real, real, fmf), rewriter.create(loc, imag, imag, fmf), fmf); @@ -180,7 +180,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = @@ -208,7 +208,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -253,7 +253,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -290,7 +290,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. - auto fmf = LLVM::FMFAttr::get(op.getContext(), {}); + auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3e9c235..0ed3294 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -54,7 +54,8 @@ static auto processFMFAttr(ArrayRef attrs) { SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { if (attr.getName() == "fastmathFlags") { - auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); + auto defAttr = + FastmathFlagsAttr::get(attr.getValue().getContext(), {}); return defAttr != attr.getValue(); } return true; @@ -2563,7 +2564,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { - addAttributes(); + addAttributes(); // clang-format off addTypeshasTrait(); } -void FMFAttr::print(AsmPrinter &printer) const { - printer << "<"; - printer << stringifyFastmathFlags(this->getFlags()); - printer << ">"; -} - -Attribute FMFAttr::parse(AsmParser &parser, Type type) { - if (failed(parser.parseLess())) - return {}; - - FastmathFlags flags = {}; - if (failed(parser.parseOptionalGreater())) { - auto parseFlags = [&]() -> ParseResult { - StringRef elemName; - if (failed(parser.parseKeyword(&elemName))) - return failure(); - - auto elem = symbolizeFastmathFlags(elemName); - if (!elem) - return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") - << elemName; - - flags = flags | *elem; - return success(); - }; - if (failed(parser.parseCommaSeparatedList(parseFlags)) || - parser.parseGreater()) - return {}; - } - - return FMFAttr::get(parser.getContext(), flags); -} - void LinkageAttr::print(AsmPrinter &printer) const { printer << "<"; if (static_cast(getLinkage()) <= getMaxEnumValForLinkage()) diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index ad26fd7..5f565bb 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -477,12 +477,12 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (i32) -> !llvm.struct<(i32, f64, i32)> // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 - %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32 + %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fneg %arg0 : f32 - %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32 + %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 return } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index f7b0211..5153874 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1666,7 +1666,7 @@ llvm.func @fastmathFlags(%arg0: f32) { // CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}}) // CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}}) // CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}}) - %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (f32) -> (f32) + %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32) %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> (f32)