From c1d58c2b0023cd41f0da128f5190fa887d8f6c69 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 7 Jan 2021 13:56:37 +0100 Subject: [PATCH] [mlir] Add fastmath flags support to some LLVM dialect ops Add fastmath enum, attributes to some llvm dialect ops, `FastmathFlagsInterface` op interface, and `translateModuleToLLVMIR` support. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D92485 --- mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt | 2 + mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 16 +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 77 +++++++++--- .../mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td | 30 +++++ mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 3 +- .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 13 +- mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1 + mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 135 ++++++++++++++++++++- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 27 +++++ mlir/test/Dialect/LLVMIR/roundtrip.mlir | 32 +++++ mlir/test/Target/llvmir.mlir | 44 +++++++ 11 files changed, 354 insertions(+), 26 deletions(-) create mode 100644 mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 6166f36..29cef3f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -10,6 +10,8 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen) add_mlir_doc(LLVMOps -gen-op-doc LLVMOps Dialects/) +add_mlir_interface(LLVMOpsInterfaces) + set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 630bad4..22ff151 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -29,6 +29,7 @@ #include "llvm/IR/Type.h" #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc" +#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc" namespace llvm { class Type; @@ -46,8 +47,23 @@ class LLVMDialect; namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; +struct BitmaskEnumStorage; } // namespace detail +/// An attribute that specifies LLVM instruction fastmath flags. +class FMFAttr : public Attribute::AttrBase { +public: + using Base::Base; + + static FMFAttr get(FastmathFlags flags, MLIRContext *context); + + FastmathFlags getFlags() const; + + void print(DialectAsmPrinter &p) const; + static Attribute parse(DialectAsmParser &parser); +}; + } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 428ca67..53c4254 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -14,10 +14,39 @@ #define LLVMIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def FMFnnan : BitEnumAttrCase<"nnan", 0x1>; +def FMFninf : BitEnumAttrCase<"ninf", 0x2>; +def FMFnsz : BitEnumAttrCase<"nsz", 0x4>; +def FMFarcp : BitEnumAttrCase<"arcp", 0x8>; +def FMFcontract : BitEnumAttrCase<"contract", 0x10>; +def FMFafn : BitEnumAttrCase<"afn", 0x20>; +def FMFreassoc : BitEnumAttrCase<"reassoc", 0x40>; +def FMFfast : BitEnumAttrCase<"fast", 0x80>; + +def FastmathFlags : BitEnumAttr< + "FastmathFlags", + "LLVM fastmath flags", + [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast + ]> { + let cppNamespace = "::mlir::LLVM"; +} + +def LLVM_FMFAttr : DialectAttr< + LLVM_Dialect, + CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">, + "LLVM fastmath flags"> { + let storageType = "::mlir::LLVM::FMFAttr"; + let returnType = "::mlir::LLVM::FastmathFlags"; + let convertFromStorage = "$_self.getFlags()"; + let constBuilderCall = + "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())"; +} + class LLVM_Builder { string llvmBuilder = builder; } @@ -77,29 +106,35 @@ class LLVM_ArithmeticOpBase, LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { - let arguments = (ins LLVM_ScalarOrVectorOf:$lhs, - LLVM_ScalarOrVectorOf:$rhs); + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$lhs, + LLVM_ScalarOrVectorOf:$rhs); let results = (outs LLVM_ScalarOrVectorOf:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)"; + let assemblyFormat = "$lhs `,` $rhs custom(attr-dict) `:` type($res)"; } class LLVM_IntArithmeticOp traits = []> : - LLVM_ArithmeticOpBase; + LLVM_ArithmeticOpBase { + let arguments = commonArgs; +} class LLVM_FloatArithmeticOp traits = []> : - LLVM_ArithmeticOpBase; + LLVM_ArithmeticOpBase], traits)> { + dag fmfArg = (ins DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); +} // Class for arithmetic unary operations. -class LLVM_UnaryArithmeticOp traits = []> : LLVM_Op, + !listconcat([NoSideEffect, SameOperandsAndResultType, DeclareOpInterfaceMethods], traits)>, LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> { - let arguments = (ins type:$operand); + let arguments = (ins type:$operand, DefaultValuedAttr:$fastmathFlags); let results = (outs type:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$operand attr-dict `:` type($res)"; + let assemblyFormat = "$operand custom(attr-dict) `:` type($res)"; } // Integer binary operations. @@ -185,20 +220,24 @@ def FCmpPredicate : I64EnumAttr< let cppNamespace = "::mlir::LLVM"; } -// Other integer operations. -def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> { +// Other floating-point operations. +def LLVM_FCmpOp : LLVM_Op<"fcmp", [ + NoSideEffect, DeclareOpInterfaceMethods]> { let arguments = (ins FCmpPredicate:$predicate, LLVM_ScalarOrVectorOf:$lhs, - LLVM_ScalarOrVectorOf:$rhs); + LLVM_ScalarOrVectorOf:$rhs, + DefaultValuedAttr:$fastmathFlags); let results = (outs LLVM_ScalarOrVectorOf:$res); let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; let builders = [ - OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), + OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs, + CArg<"FastmathFlags", "{}">:$fmf), [{ build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs, + ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext())); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printFCmpOp(p, *this); }]; @@ -210,8 +249,8 @@ def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">; def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">; def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">; def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">; -def LLVM_FNegOp : LLVM_UnaryArithmeticOp, - "fneg", "CreateFNeg">; +def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp< + LLVM_ScalarOrVectorOf, "fneg", "CreateFNeg">; // Common code definition that is used to verify and set the alignment attribute // of LLVM ops that accept such an attribute. @@ -405,7 +444,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { let printer = [{ printLandingpadOp(p, *this); }]; } -def LLVM_CallOp : LLVM_Op<"call"> { +def LLVM_CallOp : LLVM_Op<"call", + [DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ @@ -436,7 +476,8 @@ def LLVM_CallOp : LLVM_Op<"call"> { ``` }]; let arguments = (ins OptionalAttr:$callee, - Variadic); + Variadic, + DefaultValuedAttr:$fastmathFlags); let results = (outs Variadic); let builders = [ OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td new file mode 100644 index 0000000..d31ae81 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -0,0 +1,30 @@ +//===-- LLVMOpsInterfaces.td - LLVM op interfaces ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM IR interfaces definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_OPS_INTERFACES +#define LLVM_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { + let description = [{ + Access to op fastmath flags. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">, + ]; +} + +#endif // LLVM_OPS_INTERFACES diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 2ebb24b..78927fbc 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -828,7 +828,8 @@ public: rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), - operation.operand1(), operation.operand2()); + operation.operand1(), operation.operand2(), + LLVM::FMFAttr::get({}, operation.getContext())); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 39680a2..5e27088 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1836,10 +1836,11 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); + rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); + rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -1863,10 +1864,11 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. + auto fmf = LLVM::FMFAttr::get({}, op.getContext()); Value real = - rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); + rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = - rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); + rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -3155,11 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { CmpFOpAdaptor transformed(operands); + auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext()); rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), - transformed.lhs(), transformed.rhs()); + transformed.lhs(), transformed.rhs(), fmf); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index 91fb02d..c2f88d0 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLLVMIR DEPENDS MLIRLLVMOpsIncGen + MLIRLLVMOpsInterfacesIncGen MLIROpenMPOpsIncGen intrinsics_gen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 0a9b616..b7f7789 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -36,6 +36,51 @@ static constexpr const char kVolatileAttrName[] = "volatile_"; static constexpr const char kNonTemporalAttrName[] = "nontemporal"; #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" +#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" + +namespace mlir { +namespace LLVM { +namespace detail { +struct BitmaskEnumStorage : public AttributeStorage { + using KeyTy = uint64_t; + + BitmaskEnumStorage(KeyTy val) : value(val) {} + + bool operator==(const KeyTy &key) const { return value == key; } + + static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + BitmaskEnumStorage(key); + } + + KeyTy value = 0; +}; +} // namespace detail +} // namespace LLVM +} // namespace mlir + +static auto processFMFAttr(ArrayRef attrs) { + SmallVector filteredAttrs( + llvm::make_filter_range(attrs, [&](NamedAttribute attr) { + if (attr.first == "fastmathFlags") { + auto defAttr = FMFAttr::get({}, attr.second.getContext()); + return defAttr != attr.second; + } + return true; + })); + return filteredAttrs; +} + +static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, + NamedAttrList &result) { + return parser.parseOptionalAttrDict(result); +} + +static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs) { + printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); +} //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::CmpOp. @@ -50,7 +95,7 @@ static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) << "\" " << op.getOperand(0) << ", " << op.getOperand(1); - p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"predicate"}); p << " : " << op.lhs().getType(); } @@ -771,7 +816,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) { auto args = op.getOperands().drop_front(isDirect ? 0 : 1); p << '(' << args << ')'; - p.printOptionalAttrDict(op.getAttrs(), {"callee"}); + p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. p << " : " @@ -2041,6 +2086,8 @@ static LogicalResult verify(FenceOp &op) { //===----------------------------------------------------------------------===// void LLVMDialect::initialize() { + addAttributes(); + // clang-format off addTypeshasTrait() && op->hasTrait(); } + +FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) { + return Base::get(context, static_cast(flags)); +} + +FastmathFlags FMFAttr::getFlags() const { + return static_cast(getImpl()->value); +} + +static constexpr const FastmathFlags FastmathFlagsList[] = { + // clang-format off + FastmathFlags::nnan, + FastmathFlags::ninf, + FastmathFlags::nsz, + FastmathFlags::arcp, + FastmathFlags::contract, + FastmathFlags::afn, + FastmathFlags::reassoc, + FastmathFlags::fast, + // clang-format on +}; + +void FMFAttr::print(DialectAsmPrinter &printer) const { + printer << "fastmath<"; + auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) { + return bitEnumContains(getFlags(), flag); + }); + llvm::interleaveComma(flags, printer, + [&](auto flag) { printer << stringifyEnum(flag); }); + printer << ">"; +} + +Attribute FMFAttr::parse(DialectAsmParser &parser) { + if (failed(parser.parseLess())) + return {}; + + FastmathFlags flags = {}; + if (failed(parser.parseOptionalGreater())) { + do { + StringRef elemName; + if (failed(parser.parseKeyword(&elemName))) + return {}; + + auto elem = symbolizeFastmathFlags(elemName); + if (!elem) { + parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") + << elemName; + return {}; + } + + flags = flags | *elem; + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseGreater())) + return {}; + } + + return FMFAttr::get(flags, parser.getBuilder().getContext()); +} + +Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + if (type) { + parser.emitError(parser.getNameLoc(), "unexpected type"); + return {}; + } + StringRef attrKind; + if (parser.parseKeyword(&attrKind)) + return {}; + + if (attrKind == "fastmath") + return FMFAttr::parse(parser); + + parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ") + << attrKind; + return {}; +} + +void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { + if (auto fmf = attr.dyn_cast()) + fmf.print(os); + else + llvm_unreachable("Unknown attribute type"); +} diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 5ffb11e..7700867 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -666,6 +666,29 @@ ModuleTranslation::convertOmpOperation(Operation &opInst, }); } +static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { + using llvmFMF = llvm::FastMathFlags; + using FuncT = void (llvmFMF::*)(bool); + const std::pair handlers[] = { + // clang-format off + {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, + {FastmathFlags::ninf, &llvmFMF::setNoInfs}, + {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, + {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, + {FastmathFlags::contract, &llvmFMF::setAllowContract}, + {FastmathFlags::afn, &llvmFMF::setApproxFunc}, + {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, + {FastmathFlags::fast, &llvmFMF::setFast}, + // clang-format on + }; + llvm::FastMathFlags ret; + auto fmf = op.fastmathFlags(); + for (auto it : handlers) + if (bitEnumContains(fmf, it.first)) + (ret.*(it.second))(true); + return ret; +} + /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a @@ -680,6 +703,10 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, return position; }; + llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); + if (auto fmf = dyn_cast(opInst)) + builder.setFastMathFlags(getFastmathFlags(fmf)); + #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" // Emit function calls. If the "callee" attribute is present, this is a diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index fc9ff68..05d8381 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -387,3 +387,35 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) { llvm.return } + +// CHECK-LABEL: @fastmathFlags +func @fastmathFlags(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32) { +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float +// CHECK: {{.*}} = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %1 = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %2 = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)> + %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)> + +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : !llvm.float + %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : !llvm.float +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = llvm.fneg %arg0 : !llvm.float + %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : !llvm.float + return +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir index 099b8c9..921c3e8 100644 --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1360,6 +1360,50 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) { // ----- +llvm.func @fastmathFlagsFunc(!llvm.float) -> !llvm.float + +// CHECK-LABEL: @fastmathFlags +llvm.func @fastmathFlags(%arg0: !llvm.float) { +// CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = fdiv nnan ninf float {{.*}}, {{.*}} +// CHECK: {{.*}} = frem nnan ninf float {{.*}}, {{.*}} + %0 = llvm.fadd %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %1 = llvm.fsub %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %2 = llvm.fmul %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %3 = llvm.fdiv %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + %4 = llvm.frem %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = fcmp nnan ninf oeq {{.*}}, {{.*}} + %5 = llvm.fcmp "oeq" %arg0, %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = fneg nnan ninf float {{.*}} + %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : !llvm.float + +// CHECK: {{.*}} = call float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call nnan float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call ninf float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call nsz float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call arcp float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call contract float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}}) +// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}}) + %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (!llvm.float) -> (!llvm.float) + %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %12 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %13 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %14 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %15 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + %16 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath} : (!llvm.float) -> (!llvm.float) + llvm.return +} + +// ----- + // CHECK-LABEL: @switch_args llvm.func @switch_args(%arg0: !llvm.i32) { %0 = llvm.mlir.constant(5 : i32) : !llvm.i32 -- 2.7.4