[mlir][llvm] Fastmath flags import from LLVM IR.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Fri, 16 Dec 2022 07:05:34 +0000 (08:05 +0100)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Fri, 16 Dec 2022 07:07:10 +0000 (08:07 +0100)
This revision adds support to import fastmath flags from LLVMIR. It
implement the import using a listener attached to the builder. The
listener gets notified if an operation is created and then checks if
there are fastmath flags to import from LLVM IR to the MLIR. The
listener based approach allows us to perform the import without changing
the mlirBuilders used to create the imported operations.

An alternative solution, could be to update the builders so that they
return the created operation using FailureOr<Operation*> instead of
LogicalResult. However, this solution implies an LLVM IR instruction
always maps to exatly one MLIR operation. While mostly true, there are
already exceptions to this such as the PHI instruciton. Additionally, an
mlirBuilder based solution also further complicates the builder
implementations, which led to the listener based solution.

Depends on D139405

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D139620

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/Target/LLVMIR/Import/fastmath.ll [new file with mode: 0644]

index 13f4502..0e478ce 100644 (file)
@@ -13,52 +13,55 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 // "intr." to avoid potential name clashes.
 
 class LLVM_UnaryIntrOpBase<string func, Type element,
-                           list<Trait> traits = [],
-                           dag addAttrs = (ins)> :
+                           list<Trait> traits = [], bit requiresFastmath = 0> :
     LLVM_OneResultIntrOp<func, [], [0],
-           !listconcat([Pure, SameOperandsAndResultType], traits)> {
-  dag args = (ins LLVM_ScalarOrVectorOf<element>:$in);
-  let arguments = !con(args, addAttrs);
+           !listconcat([Pure, SameOperandsAndResultType], traits),
+           requiresFastmath> {
+  dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$in);
   let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
       "functional-type(operands, results)";
 }
 
 class LLVM_UnaryIntrOpI<string func, list<Trait> traits = []> :
-    LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits>;
+    LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits> {
+  let arguments = commonArgs;
+}
 
 class LLVM_UnaryIntrOpF<string func, list<Trait> traits = []> :
-    LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat,
-           !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
-                       traits),
-           (ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                  "{}">:$fastmathFlags)>;
+    LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat, traits, /*requiresFastmath=*/1> {
+  dag fmfArg = (
+    ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+  let arguments = !con(commonArgs, fmfArg);
+}
 
 class LLVM_BinarySameArgsIntrOpBase<string func, Type element,
-                                    list<Trait> traits = [],
-                                    dag addAttrs = (ins)> :
+              list<Trait> traits = [], bit requiresFastmath = 0> :
     LLVM_OneResultIntrOp<func, [], [0],
-           !listconcat([Pure, SameOperandsAndResultType], traits)> {
-  dag args = (ins LLVM_ScalarOrVectorOf<element>:$a,
-                  LLVM_ScalarOrVectorOf<element>:$b);
-  let arguments = !con(args, addAttrs);
+           !listconcat([Pure, SameOperandsAndResultType], traits),
+           requiresFastmath> {
+  dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$a,
+                        LLVM_ScalarOrVectorOf<element>:$b);
   let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
       "functional-type(operands, results)";
 }
 
 class LLVM_BinarySameArgsIntrOpI<string func, list<Trait> traits = []> :
-    LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits>;
+    LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits> {
+  let arguments = commonArgs;
+}
 
 class LLVM_BinarySameArgsIntrOpF<string func, list<Trait> traits = []> :
-    LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat,
-           !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
-                       traits),
-           (ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                  "{}">:$fastmathFlags)>;
+    LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat, traits,
+                                  /*requiresFastmath=*/1> {
+  dag fmfArg = (
+    ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+  let arguments = !con(commonArgs, fmfArg);
+}
 
 class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0],
-           !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>,
-                        Pure, SameOperandsAndResultType], traits)> {
+           !listconcat([Pure, SameOperandsAndResultType], traits),
+           /*requiresFastmath=*/1> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnyFloat>:$a,
                        LLVM_ScalarOrVectorOf<AnyFloat>:$b,
                        LLVM_ScalarOrVectorOf<AnyFloat>:$c,
@@ -106,7 +109,8 @@ def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">;
 def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">;
 def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1],
-    [DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure]> {
+    [DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure],
+    /*requiresFastmath=*/1> {
   let arguments =
       (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$val,
            AnySignlessInteger:$power,
index 0d89e77..087ccbf 100644 (file)
@@ -345,8 +345,13 @@ def LLVM_IntrPatterns {
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
                       list<Trait> traits, int numResults,
-                      bit requiresAccessGroup = 0, bit requiresAliasScope = 0>
-    : LLVM_OpBase<dialect, opName, traits>,
+                      bit requiresAccessGroup = 0, bit requiresAliasScope = 0,
+                      bit requiresFastmath = 0>
+    : LLVM_OpBase<dialect, opName, !listconcat(
+        !if(!gt(requiresFastmath, 0),
+            [DeclareOpInterfaceMethods<FastmathFlagsInterface>],
+            []),
+        traits)>,
       Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
   string resultPattern = !if(!gt(numResults, 1),
                              LLVM_IntrPatterns.structResult,
@@ -378,9 +383,11 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
       return failure();
     SmallVector<Type> resultTypes =
     }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
-    Operation *op = $_builder.create<$_qualCppClassName>(
+    auto op = $_builder.create<$_qualCppClassName>(
       $_location, resultTypes, *mlirOperands);
-    }] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;");
+    }] # !if(!gt(requiresFastmath, 0),
+      "setFastmathFlagsAttr(inst, op);", "")
+    # !if(!gt(numResults, 0), "$res = op;", "(void)op;");
 }
 
 // Base class for LLVM intrinsic operations, should not be used directly. Places
@@ -388,10 +395,11 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<Trait> traits,
                   int numResults, bit requiresAccessGroup = 0,
-                  bit requiresAliasScope = 0>
+                  bit requiresAliasScope = 0, bit requiresFastmath = 0>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
                       overloadedResults, overloadedOperands, traits,
-                      numResults, requiresAccessGroup, requiresAliasScope>;
+                      numResults, requiresAccessGroup, requiresAliasScope,
+                      requiresFastmath>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -419,8 +427,11 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
 // empty otherwise.
 class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
-                           list<Trait> traits = []>
-    : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1>;
+                           list<Trait> traits = [],
+                           bit requiresFastmath = 0>
+    : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+                  /*requiresAccessGroup=*/0, /*requiresAliasScope=*/0,
+                  requiresFastmath>;
 
 def LLVM_OneResultOpBuilder :
   OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,
index 4abd193..9ee69bd 100644 (file)
@@ -44,14 +44,14 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
   let builders = [LLVM_OneResultOpBuilder];
   let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
   string llvmInstName = instName;
-  string mlirBuilder = [{
-    $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
-  }];
 }
 class LLVM_IntArithmeticOp<string mnemonic, string instName,
                            list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, traits> {
   let arguments = commonArgs;
+  string mlirBuilder = [{
+    $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+  }];
 }
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
@@ -60,6 +60,11 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName,
   dag fmfArg = (
     ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
   let arguments = !con(commonArgs, fmfArg);
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    setFastmathFlagsAttr(inst, op);
+    $res = op;
+  }];
 }
 
 // Class for arithmetic unary operations.
@@ -76,8 +81,10 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
   let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
   string llvmInstName = instName;
   string mlirBuilder = [{
-    $res = $_builder.create<$_qualCppClassName>($_location, $operand);
-  }];
+    auto op = $_builder.create<$_qualCppClassName>($_location, $operand);
+    setFastmathFlagsAttr(inst, op);
+    $res = op;
+   }];
 }
 
 // Integer binary operations.
@@ -146,11 +153,12 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
   string llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
-  // FIXME: Import fastmath flags.
   string mlirBuilder = [{
     auto *fCmpInst = cast<llvm::FCmpInst>(inst);
-    $res = $_builder.create<$_qualCppClassName>(
+    auto op = $_builder.create<$_qualCppClassName>(
       $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
+    setFastmathFlagsAttr(inst, op);
+    $res = op;
   }];
   // Set the $predicate index to -1 to indicate there is no matching operand
   // and decrement the following indices.
index 12d6238..0bd00aa 100644 (file)
@@ -38,6 +38,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/Error.h"
@@ -326,8 +327,13 @@ getTopologicallySortedBlocks(llvm::Function *func) {
   return blocks;
 }
 
-// Handles importing globals and functions from an LLVM module.
 namespace {
+/// Module import implementation class that provides methods to import globals
+/// and functions from an LLVM module into an MLIR module. It holds mappings
+/// between the original and translated globals, basic blocks, and values used
+/// during the translation. Additionally, it keeps track of the current constant
+/// insertion point since LLVM immediate values translate to MLIR operations
+/// that are introduced at the beginning of the region.
 class Importer {
 public:
   Importer(MLIRContext *context, ModuleOp module)
@@ -421,6 +427,10 @@ private:
     constantInsertionOp = nullptr;
   }
 
+  /// Sets the fastmath flags attribute for the imported operation `op` given
+  /// the original instruction `inst`. Asserts if the operation does not
+  /// implement the fastmath interface.
+  void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
   /// Returns personality of `func` as a FlatSymbolRefAttr.
   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
   /// Imports `bb` into `block`, which must be initially empty.
@@ -487,6 +497,31 @@ private:
 };
 } // namespace
 
+void Importer::setFastmathFlagsAttr(llvm::Instruction *inst,
+                                    Operation *op) const {
+  auto iface = cast<FastmathFlagsInterface>(op);
+
+  // Even if the imported operation implements the fastmath interface, the
+  // original instruction may not have fastmath flags set. Exit if an
+  // instruction, such as a non floating-point function call, does not have
+  // fastmath flags.
+  if (!isa<llvm::FPMathOperator>(inst))
+    return;
+  llvm::FastMathFlags flags = inst->getFastMathFlags();
+
+  // Set the fastmath bits flag-by-flag.
+  FastmathFlags value = {};
+  value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
+  value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
+  value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
+  value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
+  value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
+  value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
+  value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
+  FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
+  iface->setAttr(iface.getFastmathAttrName(), attr);
+}
+
 // We only need integers, floats, doubles, and vectors and tensors thereof for
 // attributes. Scalar and vector types are converted to the standard
 // equivalents. Array types are converted to ranked tensors; nested array types
@@ -1032,6 +1067,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder,
     } else {
       callOp = builder.create<CallOp>(loc, types, operands);
     }
+    setFastmathFlagsAttr(inst, callOp);
     if (!callInst->getType()->isVoidTy())
       mapValue(inst, callOp.getResult());
     return success();
@@ -1116,7 +1152,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder,
 
 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
   // FIXME: Support uses of SubtargetData.
-  // FIXME: Add support for fast-math flags and call / operand attributes.
+  // FIXME: Add support for call / operand attributes.
   // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
   // callbr, vaarg, landingpad, catchpad, cleanuppad instructions.
 
diff --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll
new file mode 100644 (file)
index 0000000..c546775
--- /dev/null
@@ -0,0 +1,56 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @fastmath_inst
+define void @fastmath_inst(float %arg1, float %arg2) {
+  ; CHECK: llvm.fadd %{{.*}}, %{{.*}}  {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+  %1 = fadd nnan ninf float %arg1, %arg2
+  ; CHECK: llvm.fsub %{{.*}}, %{{.*}}  {fastmathFlags = #llvm.fastmath<nsz>} : f32
+  %2 = fsub nsz float %arg1, %arg2
+  ; CHECK: llvm.fmul %{{.*}}, %{{.*}}  {fastmathFlags = #llvm.fastmath<arcp, contract>} : f32
+  %3 = fmul arcp contract float %arg1, %arg2
+  ; CHECK: llvm.fdiv %{{.*}}, %{{.*}}  {fastmathFlags = #llvm.fastmath<afn, reassoc>} : f32
+  %4 = fdiv afn reassoc float %arg1, %arg2
+  ; CHECK: llvm.fneg %{{.*}}  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %5 = fneg fast float %arg1
+  ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @fastmath_fcmp
+define void @fastmath_fcmp(float %arg1, float %arg2) {
+  ; CHECK:  llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
+  %1 = fcmp nsz oge float %arg1, %arg2
+  ret void
+}
+
+; // -----
+
+declare float @fn(float)
+
+; CHECK-LABEL: @fastmath_call
+define void @fastmath_call(float %arg1) {
+  ; CHECK:  llvm.call @fn(%{{.*}}) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> f32
+  %1 = call ninf float @fn(float %arg1)
+  ret void
+}
+
+; // -----
+
+declare float @llvm.exp.f32(float)
+declare float @llvm.powi.f32.i32(float, i32)
+declare float @llvm.pow.f32(float, float)
+declare float @llvm.fmuladd.f32(float, float, float)
+
+; CHECK-LABEL: @fastmath_intr
+define void @fastmath_intr(float %arg1, i32 %arg2) {
+  ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf>} : (f32) -> f32
+  %1 = call nnan ninf float @llvm.exp.f32(float %arg1)
+  ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
+  %2 = call fast float @llvm.powi.f32.i32(float %arg1, i32 %arg2)
+  ; CHECK:  llvm.intr.pow(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+  %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1)
+  ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
+  %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1)
+  ret void
+}