[mlir][LLVMIR] Update LLVMIR fastmath to use EnumAttr tblgen classes
authorJeremy Furtek <jfurtek@nvidia.com>
Mon, 17 Oct 2022 21:02:09 +0000 (14:02 -0700)
committerSlava Zakharin <szakharin@nvidia.com>
Mon, 17 Oct 2022 22:03:47 +0000 (15:03 -0700)
This diff updates the `fastmath` attribute in the LLVMIR dialect to use `tblgen`
classes that were developed after the initial LLVMIR `fastmath` implementation.
Using the `EnumAttr` `tblgen` classes brings the LLVMIR `fastmath` attribute in
line with other dialects, and eliminates some of the custom printing and parsing
code in the LLVMIR dialect.

Subsequent commits will further reduce the custom processing code for the LLVMIR
`fastmath` attribute by unifying printing/parsing functionality between the
LLVMIR and `arith` `fastmath` attributes. (The actual attributes will remain
separate, but the printing and parsing will be made generic, and will be usable
by other dialects/attributes.)

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135289

mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index abdb0b7..56b8e2d 100644 (file)
@@ -1,10 +1,5 @@
 add_subdirectory(Transforms)
 
-set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td)
-mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRLLVMAttrsIncGen)
-
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
 mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
@@ -12,6 +7,10 @@ mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
 mlir_tablegen(LLVMOpsDialect.cpp.inc -gen-dialect-defs)
 mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls
+              -attrdefs-dialect=llvm)
+mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs
+              -attrdefs-dialect=llvm)
 add_public_tablegen_target(MLIRLLVMOpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td)
index 2de2f7a..95d4a90 100644 (file)
@@ -15,17 +15,6 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 // All of the attributes will extend this class.
 class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;
 
-// The "FastMath" flags associated with floating point LLVM instructions.
-def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
-  let mnemonic = "fastmath";
-
-  // List of type parameters.
-  let parameters = (ins
-    "FastmathFlags":$flags
-  );
-  let hasCustomAssemblyFormat = 1;
-}
-
 // Attribute definition for the LLVM Linkage enum.
 def LinkageAttr : LLVM_Attr<"Linkage"> {
   let mnemonic = "linkage";
index b25d151..b4429fc 100644 (file)
@@ -13,7 +13,9 @@
 #ifndef LLVMIR_OPS
 #define LLVMIR_OPS
 
+include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
@@ -21,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+def FMFnone     : I32BitEnumAttrCaseNone<"none">;
 def FMFnnan     : I32BitEnumAttrCaseBit<"nnan", 0>;
 def FMFninf     : I32BitEnumAttrCaseBit<"ninf", 1>;
 def FMFnsz      : I32BitEnumAttrCaseBit<"nsz", 2>;
@@ -34,22 +37,18 @@ def FMFfast     : I32BitEnumAttrCaseGroup<"fast",
 def FastmathFlags : I32BitEnumAttr<
     "FastmathFlags",
     "LLVM fastmath flags",
-    [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
-    ]> {
+    [FMFnone, FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn,
+     FMFreassoc, FMFfast]> {
   let separator = ", ";
   let cppNamespace = "::mlir::LLVM";
+  let genSpecializedAttr = 0;
   let printBitEnumPrimaryGroups = 1;
 }
 
-def LLVM_FMFAttr : DialectAttr<
-    LLVM_Dialect,
-    CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
-    "LLVM fastmath flags"> {
-  let storageType = "::mlir::LLVM::FMFAttr";
-  let returnType = "::mlir::LLVM::FastmathFlags";
-  let convertFromStorage = "$_self.getFlags()";
-  let constBuilderCall =
-          "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)";
+// The "FastMath" flags associated with floating point LLVM instructions.
+def LLVM_FastmathFlagsAttr :
+    EnumAttr<LLVM_Dialect, FastmathFlags, "fastmath"> {
+  let assemblyFormat = "`<` $value `>`";
 }
 
 def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
@@ -229,7 +228,8 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
     !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
-  dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+  dag fmfArg = (
+    ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
   let arguments = !con(commonArgs, fmfArg);
 }
 
@@ -239,7 +239,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
     LLVM_Op<mnemonic,
            !listconcat([Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
     LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> {
-  let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+  let arguments = (
+    ins type:$operand,
+    DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
   let results = (outs type:$res);
   let builders = [LLVM_OneResultOpBuilder];
   let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
@@ -354,7 +356,8 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
-                   DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+                                     "{}">:$fastmathFlags);
   let builders = [
     OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
   ];
@@ -747,7 +750,8 @@ def LLVM_CallOp : LLVM_Op<"call",
 
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                        Variadic<LLVM_Type>,
-                       DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+                       DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+                                         "{}">:$fastmathFlags);
   let results = (outs Optional<LLVM_Type>:$result);
 
   let builders = [
index a07c434..f37d47d 100644 (file)
@@ -72,7 +72,7 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
     Value real = complexStruct.real(rewriter, op.getLoc());
     Value imag = complexStruct.imaginary(rewriter, op.getLoc());
 
-    auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
     Value sqNorm = rewriter.create<LLVM::FAddOp>(
         loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
         rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
@@ -180,7 +180,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
     Value real =
         rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =
@@ -208,7 +208,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
     Value rhsRe = arg.rhs.real();
     Value rhsIm = arg.rhs.imag();
     Value lhsRe = arg.lhs.real();
@@ -253,7 +253,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
-    auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
     Value rhsRe = arg.rhs.real();
     Value rhsIm = arg.rhs.imag();
     Value lhsRe = arg.lhs.real();
@@ -290,7 +290,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to substract complex numbers.
-    auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+    auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
     Value real =
         rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =
index 3e9c235..0ed3294 100644 (file)
@@ -54,7 +54,8 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute, 8> filteredAttrs(
       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
         if (attr.getName() == "fastmathFlags") {
-          auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
+          auto defAttr =
+              FastmathFlagsAttr::get(attr.getValue().getContext(), {});
           return defAttr != attr.getValue();
         }
         return true;
@@ -2563,7 +2564,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 void LLVMDialect::initialize() {
-  addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
+  addAttributes<FastmathFlagsAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
 
   // clang-format off
   addTypes<LLVMVoidType,
@@ -2809,39 +2810,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
 
-void FMFAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  printer << stringifyFastmathFlags(this->getFlags());
-  printer << ">";
-}
-
-Attribute FMFAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  FastmathFlags flags = {};
-  if (failed(parser.parseOptionalGreater())) {
-    auto parseFlags = [&]() -> ParseResult {
-      StringRef elemName;
-      if (failed(parser.parseKeyword(&elemName)))
-        return failure();
-
-      auto elem = symbolizeFastmathFlags(elemName);
-      if (!elem)
-        return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
-               << elemName;
-
-      flags = flags | *elem;
-      return success();
-    };
-    if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
-        parser.parseGreater())
-      return {};
-  }
-
-  return FMFAttr::get(parser.getContext(), flags);
-}
-
 void LinkageAttr::print(AsmPrinter &printer) const {
   printer << "<";
   if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
index ad26fd7..5f565bb 100644 (file)
@@ -477,12 +477,12 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
   %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
 
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
-  %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
+  %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<none>} : f32
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
   %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
 
 // CHECK: {{.*}} = llvm.fneg %arg0 : f32
-  %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
+  %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<none>} : f32
   return
 }
 
index f7b0211..5153874 100644 (file)
@@ -1666,7 +1666,7 @@ llvm.func @fastmathFlags(%arg0: f32) {
 // CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
 // CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
 // CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
-  %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (f32) -> (f32)
+  %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<none>} : (f32) -> (f32)
   %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (f32) -> (f32)
   %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> (f32)
   %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (f32) -> (f32)