// "intr." to avoid potential name clashes.
class LLVM_UnaryIntrOpBase<string func, Type element,
- list<Trait> traits = [],
- dag addAttrs = (ins)> :
+ list<Trait> traits = [], bit requiresFastmath = 0> :
LLVM_OneResultIntrOp<func, [], [0],
- !listconcat([Pure, SameOperandsAndResultType], traits)> {
- dag args = (ins LLVM_ScalarOrVectorOf<element>:$in);
- let arguments = !con(args, addAttrs);
+ !listconcat([Pure, SameOperandsAndResultType], traits),
+ requiresFastmath> {
+ dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$in);
let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
"functional-type(operands, results)";
}
class LLVM_UnaryIntrOpI<string func, list<Trait> traits = []> :
- LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits>;
+ LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits> {
+ let arguments = commonArgs;
+}
class LLVM_UnaryIntrOpF<string func, list<Trait> traits = []> :
- LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat,
- !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
- traits),
- (ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
- "{}">:$fastmathFlags)>;
+ LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat, traits, /*requiresFastmath=*/1> {
+ dag fmfArg = (
+ ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+ let arguments = !con(commonArgs, fmfArg);
+}
class LLVM_BinarySameArgsIntrOpBase<string func, Type element,
- list<Trait> traits = [],
- dag addAttrs = (ins)> :
+ list<Trait> traits = [], bit requiresFastmath = 0> :
LLVM_OneResultIntrOp<func, [], [0],
- !listconcat([Pure, SameOperandsAndResultType], traits)> {
- dag args = (ins LLVM_ScalarOrVectorOf<element>:$a,
- LLVM_ScalarOrVectorOf<element>:$b);
- let arguments = !con(args, addAttrs);
+ !listconcat([Pure, SameOperandsAndResultType], traits),
+ requiresFastmath> {
+ dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$a,
+ LLVM_ScalarOrVectorOf<element>:$b);
let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
"functional-type(operands, results)";
}
class LLVM_BinarySameArgsIntrOpI<string func, list<Trait> traits = []> :
- LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits>;
+ LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits> {
+ let arguments = commonArgs;
+}
class LLVM_BinarySameArgsIntrOpF<string func, list<Trait> traits = []> :
- LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat,
- !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
- traits),
- (ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
- "{}">:$fastmathFlags)>;
+ LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat, traits,
+ /*requiresFastmath=*/1> {
+ dag fmfArg = (
+ ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+ let arguments = !con(commonArgs, fmfArg);
+}
class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
LLVM_OneResultIntrOp<func, [], [0],
- !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>,
- Pure, SameOperandsAndResultType], traits)> {
+ !listconcat([Pure, SameOperandsAndResultType], traits),
+ /*requiresFastmath=*/1> {
let arguments = (ins LLVM_ScalarOrVectorOf<AnyFloat>:$a,
LLVM_ScalarOrVectorOf<AnyFloat>:$b,
LLVM_ScalarOrVectorOf<AnyFloat>:$c,
def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">;
def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1],
- [DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure]> {
+ [DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure],
+ /*requiresFastmath=*/1> {
let arguments =
(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$val,
AnySignlessInteger:$power,
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
- bit requiresAccessGroup = 0, bit requiresAliasScope = 0>
- : LLVM_OpBase<dialect, opName, traits>,
+ bit requiresAccessGroup = 0, bit requiresAliasScope = 0,
+ bit requiresFastmath = 0>
+ : LLVM_OpBase<dialect, opName, !listconcat(
+ !if(!gt(requiresFastmath, 0),
+ [DeclareOpInterfaceMethods<FastmathFlagsInterface>],
+ []),
+ traits)>,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
string resultPattern = !if(!gt(numResults, 1),
LLVM_IntrPatterns.structResult,
return failure();
SmallVector<Type> resultTypes =
}] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
- Operation *op = $_builder.create<$_qualCppClassName>(
+ auto op = $_builder.create<$_qualCppClassName>(
$_location, resultTypes, *mlirOperands);
- }] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;");
+ }] # !if(!gt(requiresFastmath, 0),
+ "setFastmathFlagsAttr(inst, op);", "")
+ # !if(!gt(numResults, 0), "$res = op;", "(void)op;");
}
// Base class for LLVM intrinsic operations, should not be used directly. Places
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
int numResults, bit requiresAccessGroup = 0,
- bit requiresAliasScope = 0>
+ bit requiresAliasScope = 0, bit requiresFastmath = 0>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
- numResults, requiresAccessGroup, requiresAliasScope>;
+ numResults, requiresAccessGroup, requiresAliasScope,
+ requiresFastmath>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
// empty otherwise.
class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> overloadedOperands = [],
- list<Trait> traits = []>
- : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1>;
+ list<Trait> traits = [],
+ bit requiresFastmath = 0>
+ : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
+ /*requiresAccessGroup=*/0, /*requiresAliasScope=*/0,
+ requiresFastmath>;
def LLVM_OneResultOpBuilder :
OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
string llvmInstName = instName;
- string mlirBuilder = [{
- $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
- }];
}
class LLVM_IntArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, traits> {
let arguments = commonArgs;
+ string mlirBuilder = [{
+ $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ }];
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
dag fmfArg = (
ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ setFastmathFlagsAttr(inst, op);
+ $res = op;
+ }];
}
// Class for arithmetic unary operations.
let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
- $res = $_builder.create<$_qualCppClassName>($_location, $operand);
- }];
+ auto op = $_builder.create<$_qualCppClassName>($_location, $operand);
+ setFastmathFlagsAttr(inst, op);
+ $res = op;
+ }];
}
// Integer binary operations.
string llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
- // FIXME: Import fastmath flags.
string mlirBuilder = [{
auto *fCmpInst = cast<llvm::FCmpInst>(inst);
- $res = $_builder.create<$_qualCppClassName>(
+ auto op = $_builder.create<$_qualCppClassName>(
$_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
+ setFastmathFlagsAttr(inst, op);
+ $res = op;
}];
// Set the $predicate index to -1 to indicate there is no matching operand
// and decrement the following indices.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Error.h"
return blocks;
}
-// Handles importing globals and functions from an LLVM module.
namespace {
+/// Module import implementation class that provides methods to import globals
+/// and functions from an LLVM module into an MLIR module. It holds mappings
+/// between the original and translated globals, basic blocks, and values used
+/// during the translation. Additionally, it keeps track of the current constant
+/// insertion point since LLVM immediate values translate to MLIR operations
+/// that are introduced at the beginning of the region.
class Importer {
public:
Importer(MLIRContext *context, ModuleOp module)
constantInsertionOp = nullptr;
}
+ /// Sets the fastmath flags attribute for the imported operation `op` given
+ /// the original instruction `inst`. Asserts if the operation does not
+ /// implement the fastmath interface.
+ void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
/// Returns personality of `func` as a FlatSymbolRefAttr.
FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
/// Imports `bb` into `block`, which must be initially empty.
};
} // namespace
+void Importer::setFastmathFlagsAttr(llvm::Instruction *inst,
+ Operation *op) const {
+ auto iface = cast<FastmathFlagsInterface>(op);
+
+ // Even if the imported operation implements the fastmath interface, the
+ // original instruction may not have fastmath flags set. Exit if an
+ // instruction, such as a non floating-point function call, does not have
+ // fastmath flags.
+ if (!isa<llvm::FPMathOperator>(inst))
+ return;
+ llvm::FastMathFlags flags = inst->getFastMathFlags();
+
+ // Set the fastmath bits flag-by-flag.
+ FastmathFlags value = {};
+ value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
+ value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
+ value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
+ value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
+ value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
+ value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
+ value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
+ FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
+ iface->setAttr(iface.getFastmathAttrName(), attr);
+}
+
// We only need integers, floats, doubles, and vectors and tensors thereof for
// attributes. Scalar and vector types are converted to the standard
// equivalents. Array types are converted to ranked tensors; nested array types
} else {
callOp = builder.create<CallOp>(loc, types, operands);
}
+ setFastmathFlagsAttr(inst, callOp);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
return success();
LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
// FIXME: Support uses of SubtargetData.
- // FIXME: Add support for fast-math flags and call / operand attributes.
+ // FIXME: Add support for call / operand attributes.
// FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
// callbr, vaarg, landingpad, catchpad, cleanuppad instructions.
--- /dev/null
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @fastmath_inst
+define void @fastmath_inst(float %arg1, float %arg2) {
+ ; CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+ %1 = fadd nnan ninf float %arg1, %arg2
+ ; CHECK: llvm.fsub %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
+ %2 = fsub nsz float %arg1, %arg2
+ ; CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<arcp, contract>} : f32
+ %3 = fmul arcp contract float %arg1, %arg2
+ ; CHECK: llvm.fdiv %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<afn, reassoc>} : f32
+ %4 = fdiv afn reassoc float %arg1, %arg2
+ ; CHECK: llvm.fneg %{{.*}} {fastmathFlags = #llvm.fastmath<fast>} : f32
+ %5 = fneg fast float %arg1
+ ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @fastmath_fcmp
+define void @fastmath_fcmp(float %arg1, float %arg2) {
+ ; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
+ %1 = fcmp nsz oge float %arg1, %arg2
+ ret void
+}
+
+; // -----
+
+declare float @fn(float)
+
+; CHECK-LABEL: @fastmath_call
+define void @fastmath_call(float %arg1) {
+ ; CHECK: llvm.call @fn(%{{.*}}) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> f32
+ %1 = call ninf float @fn(float %arg1)
+ ret void
+}
+
+; // -----
+
+declare float @llvm.exp.f32(float)
+declare float @llvm.powi.f32.i32(float, i32)
+declare float @llvm.pow.f32(float, float)
+declare float @llvm.fmuladd.f32(float, float, float)
+
+; CHECK-LABEL: @fastmath_intr
+define void @fastmath_intr(float %arg1, i32 %arg2) {
+ ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf>} : (f32) -> f32
+ %1 = call nnan ninf float @llvm.exp.f32(float %arg1)
+ ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
+ %2 = call fast float @llvm.powi.f32.i32(float %arg1, i32 %arg2)
+ ; CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+ %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1)
+ ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
+ %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1)
+ ret void
+}