From b6ebeccf00c00484aedf5119bb2fc1a58f58e633 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Fri, 16 Dec 2022 08:05:34 +0100 Subject: [PATCH] [mlir][llvm] Fastmath flags import from LLVM IR. This revision adds support to import fastmath flags from LLVMIR. It implement the import using a listener attached to the builder. The listener gets notified if an operation is created and then checks if there are fastmath flags to import from LLVM IR to the MLIR. The listener based approach allows us to perform the import without changing the mlirBuilders used to create the imported operations. An alternative solution, could be to update the builders so that they return the created operation using FailureOr instead of LogicalResult. However, this solution implies an LLVM IR instruction always maps to exatly one MLIR operation. While mostly true, there are already exceptions to this such as the PHI instruciton. Additionally, an mlirBuilder based solution also further complicates the builder implementations, which led to the listener based solution. Depends on D139405 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D139620 --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 56 ++++++++++++---------- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 27 +++++++---- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 22 ++++++--- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 40 +++++++++++++++- mlir/test/Target/LLVMIR/Import/fastmath.ll | 56 ++++++++++++++++++++++ 5 files changed, 158 insertions(+), 43 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/fastmath.ll diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 13f4502..0e478ce 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -13,52 +13,55 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // "intr." to avoid potential name clashes. class LLVM_UnaryIntrOpBase traits = [], - dag addAttrs = (ins)> : + list traits = [], bit requiresFastmath = 0> : LLVM_OneResultIntrOp { - dag args = (ins LLVM_ScalarOrVectorOf:$in); - let arguments = !con(args, addAttrs); + !listconcat([Pure, SameOperandsAndResultType], traits), + requiresFastmath> { + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$in); let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " "functional-type(operands, results)"; } class LLVM_UnaryIntrOpI traits = []> : - LLVM_UnaryIntrOpBase; + LLVM_UnaryIntrOpBase { + let arguments = commonArgs; +} class LLVM_UnaryIntrOpF traits = []> : - LLVM_UnaryIntrOpBase], - traits), - (ins DefaultValuedAttr:$fastmathFlags)>; + LLVM_UnaryIntrOpBase { + dag fmfArg = ( + ins DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); +} class LLVM_BinarySameArgsIntrOpBase traits = [], - dag addAttrs = (ins)> : + list traits = [], bit requiresFastmath = 0> : LLVM_OneResultIntrOp { - dag args = (ins LLVM_ScalarOrVectorOf:$a, - LLVM_ScalarOrVectorOf:$b); - let arguments = !con(args, addAttrs); + !listconcat([Pure, SameOperandsAndResultType], traits), + requiresFastmath> { + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$a, + LLVM_ScalarOrVectorOf:$b); let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " "functional-type(operands, results)"; } class LLVM_BinarySameArgsIntrOpI traits = []> : - LLVM_BinarySameArgsIntrOpBase; + LLVM_BinarySameArgsIntrOpBase { + let arguments = commonArgs; +} class LLVM_BinarySameArgsIntrOpF traits = []> : - LLVM_BinarySameArgsIntrOpBase], - traits), - (ins DefaultValuedAttr:$fastmathFlags)>; + LLVM_BinarySameArgsIntrOpBase { + dag fmfArg = ( + ins DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); +} class LLVM_TernarySameArgsIntrOpF traits = []> : LLVM_OneResultIntrOp, - Pure, SameOperandsAndResultType], traits)> { + !listconcat([Pure, SameOperandsAndResultType], traits), + /*requiresFastmath=*/1> { let arguments = (ins LLVM_ScalarOrVectorOf:$a, LLVM_ScalarOrVectorOf:$b, LLVM_ScalarOrVectorOf:$c, @@ -106,7 +109,8 @@ def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">; def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">; def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">; def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1], - [DeclareOpInterfaceMethods, Pure]> { + [DeclareOpInterfaceMethods, Pure], + /*requiresFastmath=*/1> { let arguments = (ins LLVM_ScalarOrVectorOf:$val, AnySignlessInteger:$power, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 0d89e77..087ccbf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -345,8 +345,13 @@ def LLVM_IntrPatterns { class LLVM_IntrOpBase overloadedResults, list overloadedOperands, list traits, int numResults, - bit requiresAccessGroup = 0, bit requiresAliasScope = 0> - : LLVM_OpBase, + bit requiresAccessGroup = 0, bit requiresAliasScope = 0, + bit requiresFastmath = 0> + : LLVM_OpBase], + []), + traits)>, Results { string resultPattern = !if(!gt(numResults, 1), LLVM_IntrPatterns.structResult, @@ -378,9 +383,11 @@ class LLVM_IntrOpBase resultTypes = }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{ - Operation *op = $_builder.create<$_qualCppClassName>( + auto op = $_builder.create<$_qualCppClassName>( $_location, resultTypes, *mlirOperands); - }] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;"); + }] # !if(!gt(requiresFastmath, 0), + "setFastmathFlagsAttr(inst, op);", "") + # !if(!gt(numResults, 0), "$res = op;", "(void)op;"); } // Base class for LLVM intrinsic operations, should not be used directly. Places @@ -388,10 +395,11 @@ class LLVM_IntrOpBase overloadedResults, list overloadedOperands, list traits, int numResults, bit requiresAccessGroup = 0, - bit requiresAliasScope = 0> + bit requiresAliasScope = 0, bit requiresFastmath = 0> : LLVM_IntrOpBase; + numResults, requiresAccessGroup, requiresAliasScope, + requiresFastmath>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -419,8 +427,11 @@ class LLVM_ZeroResultIntrOp overloadedOperands = [], // empty otherwise. class LLVM_OneResultIntrOp overloadedResults = [], list overloadedOperands = [], - list traits = []> - : LLVM_IntrOp; + list traits = [], + bit requiresFastmath = 0> + : LLVM_IntrOp; def LLVM_OneResultOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 4abd193..9ee69bdd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -44,14 +44,14 @@ class LLVM_ArithmeticOpBase($_location, $lhs, $rhs); - }]; } class LLVM_IntArithmeticOp traits = []> : LLVM_ArithmeticOpBase { let arguments = commonArgs; + string mlirBuilder = [{ + $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + }]; } class LLVM_FloatArithmeticOp traits = []> : @@ -60,6 +60,11 @@ class LLVM_FloatArithmeticOp:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + setFastmathFlagsAttr(inst, op); + $res = op; + }]; } // Class for arithmetic unary operations. @@ -76,8 +81,10 @@ class LLVM_UnaryFloatArithmeticOp($_location, $operand); - }]; + auto op = $_builder.create<$_qualCppClassName>($_location, $operand); + setFastmathFlagsAttr(inst, op); + $res = op; + }]; } // Integer binary operations. @@ -146,11 +153,12 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [ string llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; - // FIXME: Import fastmath flags. string mlirBuilder = [{ auto *fCmpInst = cast(inst); - $res = $_builder.create<$_qualCppClassName>( + auto op = $_builder.create<$_qualCppClassName>( $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs); + setFastmathFlagsAttr(inst, op); + $res = op; }]; // Set the $predicate index to -1 to indicate there is no matching operand // and decrement the following indices. diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 12d6238..0bd00aa 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/Error.h" @@ -326,8 +327,13 @@ getTopologicallySortedBlocks(llvm::Function *func) { return blocks; } -// Handles importing globals and functions from an LLVM module. namespace { +/// Module import implementation class that provides methods to import globals +/// and functions from an LLVM module into an MLIR module. It holds mappings +/// between the original and translated globals, basic blocks, and values used +/// during the translation. Additionally, it keeps track of the current constant +/// insertion point since LLVM immediate values translate to MLIR operations +/// that are introduced at the beginning of the region. class Importer { public: Importer(MLIRContext *context, ModuleOp module) @@ -421,6 +427,10 @@ private: constantInsertionOp = nullptr; } + /// Sets the fastmath flags attribute for the imported operation `op` given + /// the original instruction `inst`. Asserts if the operation does not + /// implement the fastmath interface. + void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const; /// Returns personality of `func` as a FlatSymbolRefAttr. FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func); /// Imports `bb` into `block`, which must be initially empty. @@ -487,6 +497,31 @@ private: }; } // namespace +void Importer::setFastmathFlagsAttr(llvm::Instruction *inst, + Operation *op) const { + auto iface = cast(op); + + // Even if the imported operation implements the fastmath interface, the + // original instruction may not have fastmath flags set. Exit if an + // instruction, such as a non floating-point function call, does not have + // fastmath flags. + if (!isa(inst)) + return; + llvm::FastMathFlags flags = inst->getFastMathFlags(); + + // Set the fastmath bits flag-by-flag. + FastmathFlags value = {}; + value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs()); + value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs()); + value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros()); + value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal()); + value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract()); + value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc()); + value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc()); + FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value); + iface->setAttr(iface.getFastmathAttrName(), attr); +} + // We only need integers, floats, doubles, and vectors and tensors thereof for // attributes. Scalar and vector types are converted to the standard // equivalents. Array types are converted to ranked tensors; nested array types @@ -1032,6 +1067,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder, } else { callOp = builder.create(loc, types, operands); } + setFastmathFlagsAttr(inst, callOp); if (!callInst->getType()->isVoidTy()) mapValue(inst, callOp.getResult()); return success(); @@ -1116,7 +1152,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder, LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support uses of SubtargetData. - // FIXME: Add support for fast-math flags and call / operand attributes. + // FIXME: Add support for call / operand attributes. // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch, // callbr, vaarg, landingpad, catchpad, cleanuppad instructions. diff --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll new file mode 100644 index 0000000..c546775 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/fastmath.ll @@ -0,0 +1,56 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @fastmath_inst +define void @fastmath_inst(float %arg1, float %arg2) { + ; CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %1 = fadd nnan ninf float %arg1, %arg2 + ; CHECK: llvm.fsub %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %2 = fsub nsz float %arg1, %arg2 + ; CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %3 = fmul arcp contract float %arg1, %arg2 + ; CHECK: llvm.fdiv %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %4 = fdiv afn reassoc float %arg1, %arg2 + ; CHECK: llvm.fneg %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %5 = fneg fast float %arg1 + ret void +} + +; // ----- + +; CHECK-LABEL: @fastmath_fcmp +define void @fastmath_fcmp(float %arg1, float %arg2) { + ; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %1 = fcmp nsz oge float %arg1, %arg2 + ret void +} + +; // ----- + +declare float @fn(float) + +; CHECK-LABEL: @fastmath_call +define void @fastmath_call(float %arg1) { + ; CHECK: llvm.call @fn(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %1 = call ninf float @fn(float %arg1) + ret void +} + +; // ----- + +declare float @llvm.exp.f32(float) +declare float @llvm.powi.f32.i32(float, i32) +declare float @llvm.pow.f32(float, float) +declare float @llvm.fmuladd.f32(float, float, float) + +; CHECK-LABEL: @fastmath_intr +define void @fastmath_intr(float %arg1, i32 %arg2) { + ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %1 = call nnan ninf float @llvm.exp.f32(float %arg1) + ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, i32) -> f32 + %2 = call fast float @llvm.powi.f32.i32(float %arg1, i32 %arg2) + ; CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1) + ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, f32, f32) -> f32 + %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1) + ret void +} -- 2.7.4