This diff updates the `fastmath` attribute in the LLVMIR dialect to use `tblgen`
classes that were developed after the initial LLVMIR `fastmath` implementation.
Using the `EnumAttr` `tblgen` classes brings the LLVMIR `fastmath` attribute in
line with other dialects, and eliminates some of the custom printing and parsing
code in the LLVMIR dialect.
Subsequent commits will further reduce the custom processing code for the LLVMIR
`fastmath` attribute by unifying printing/parsing functionality between the
LLVMIR and `arith` `fastmath` attributes. (The actual attributes will remain
separate, but the printing and parsing will be made generic, and will be usable
by other dialects/attributes.)
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D135289
add_subdirectory(Transforms)
-set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td)
-mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRLLVMAttrsIncGen)
-
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
mlir_tablegen(LLVMOpsDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls
+ -attrdefs-dialect=llvm)
+mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs
+ -attrdefs-dialect=llvm)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td)
// All of the attributes will extend this class.
class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;
-// The "FastMath" flags associated with floating point LLVM instructions.
-def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
- let mnemonic = "fastmath";
-
- // List of type parameters.
- let parameters = (ins
- "FastmathFlags":$flags
- );
- let hasCustomAssemblyFormat = 1;
-}
-
// Attribute definition for the LLVM Linkage enum.
def LinkageAttr : LLVM_Attr<"Linkage"> {
let mnemonic = "linkage";
#ifndef LLVMIR_OPS
#define LLVMIR_OPS
+include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+def FMFnone : I32BitEnumAttrCaseNone<"none">;
def FMFnnan : I32BitEnumAttrCaseBit<"nnan", 0>;
def FMFninf : I32BitEnumAttrCaseBit<"ninf", 1>;
def FMFnsz : I32BitEnumAttrCaseBit<"nsz", 2>;
def FastmathFlags : I32BitEnumAttr<
"FastmathFlags",
"LLVM fastmath flags",
- [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
- ]> {
+ [FMFnone, FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn,
+ FMFreassoc, FMFfast]> {
let separator = ", ";
let cppNamespace = "::mlir::LLVM";
+ let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}
-def LLVM_FMFAttr : DialectAttr<
- LLVM_Dialect,
- CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
- "LLVM fastmath flags"> {
- let storageType = "::mlir::LLVM::FMFAttr";
- let returnType = "::mlir::LLVM::FastmathFlags";
- let convertFromStorage = "$_self.getFlags()";
- let constBuilderCall =
- "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)";
+// The "FastMath" flags associated with floating point LLVM instructions.
+def LLVM_FastmathFlagsAttr :
+ EnumAttr<LLVM_Dialect, FastmathFlags, "fastmath"> {
+ let assemblyFormat = "`<` $value `>`";
}
def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
- dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ dag fmfArg = (
+ ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
}
LLVM_Op<mnemonic,
!listconcat([Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> {
- let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ let arguments = (
+ ins type:$operand,
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let results = (outs type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
let arguments = (ins FCmpPredicate:$predicate,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
- DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+ "{}">:$fastmathFlags);
let builders = [
OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
];
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>,
- DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+ "{}">:$fastmathFlags);
let results = (outs Optional<LLVM_Type>:$result);
let builders = [
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value sqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
SmallVector<NamedAttribute, 8> filteredAttrs(
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
if (attr.getName() == "fastmathFlags") {
- auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
+ auto defAttr =
+ FastmathFlagsAttr::get(attr.getValue().getContext(), {});
return defAttr != attr.getValue();
}
return true;
//===----------------------------------------------------------------------===//
void LLVMDialect::initialize() {
- addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
+ addAttributes<FastmathFlagsAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
// clang-format off
addTypes<LLVMVoidType,
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-void FMFAttr::print(AsmPrinter &printer) const {
- printer << "<";
- printer << stringifyFastmathFlags(this->getFlags());
- printer << ">";
-}
-
-Attribute FMFAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
-
- FastmathFlags flags = {};
- if (failed(parser.parseOptionalGreater())) {
- auto parseFlags = [&]() -> ParseResult {
- StringRef elemName;
- if (failed(parser.parseKeyword(&elemName)))
- return failure();
-
- auto elem = symbolizeFastmathFlags(elemName);
- if (!elem)
- return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
- << elemName;
-
- flags = flags | *elem;
- return success();
- };
- if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
- parser.parseGreater())
- return {};
- }
-
- return FMFAttr::get(parser.getContext(), flags);
-}
-
void LinkageAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
%7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
- %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
+ %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<none>} : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
// CHECK: {{.*}} = llvm.fneg %arg0 : f32
- %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
+ %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<none>} : f32
return
}
// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
- %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (f32) -> (f32)
+ %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<none>} : (f32) -> (f32)
%9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (f32) -> (f32)
%10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> (f32)
%11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (f32) -> (f32)