let cppNamespace = "::fir";
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 0;
+ let dependentDialects = [
+ // Arith dialect provides FastMathFlagsAttr
+ // supported by some FIR operations.
+ "arith::ArithDialect"
+ ];
}
#endif // FORTRAN_DIALECT_FIR_DIALECT
#ifndef FORTRAN_DIALECT_FIR_OPS
#define FORTRAN_DIALECT_FIR_OPS
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "flang/Optimizer/Dialect/FIRDialect.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
// Procedure call operations
//===----------------------------------------------------------------------===//
-def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
+def fir_CallOp : fir_Op<"call",
+ [CallOpInterface, DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "call a procedure";
let description = [{
let arguments = (ins
OptionalAttr<SymbolRefAttr>:$callee,
- Variadic<AnyType>:$args
+ Variadic<AnyType>:$args,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath
);
let results = (outs Variadic<AnyType>);
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Semantics/runtime-type-info.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
llvm::SmallVector<mlir::Type> resultTys;
for (auto r : call.getResults())
resultTys.push_back(convertType(r.getType()));
+ // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
+ mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
+ attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- call, resultTys, adaptor.getOperands(), call->getAttrs());
+ call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
return mlir::success();
}
};
else
p << getOperand(0);
p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';
- p.printOptionalAttrDict((*this)->getAttrs(),
- {fir::CallOp::getCalleeAttrNameStr()});
+
+ // Print 'fastmath<...>' (if it has non-default value) before
+ // any other attributes.
+ mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
+ if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) {
+ p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic();
+ p.printStrippedAttrOrType(fmfAttr);
+ }
+
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ {fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()});
auto resultTypes{getResultTypes()};
llvm::SmallVector<mlir::Type> argTypes(
llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
return mlir::failure();
mlir::Type type;
- if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
- parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
+ return mlir::failure();
+
+ // Parse 'fastmath<...>', if present.
+ mlir::arith::FastMathFlagsAttr fmfAttr;
+ llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
+ if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
+ if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
+ fmfAttrName, attrs))
+ return mlir::failure();
+
+ if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
parser.parseType(type))
return mlir::failure();
--- /dev/null
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// CHECK-LABEL: @test_callop
+func.func @test_callop(%arg0 : f32) {
+ // CHECK: fir.call @callee() : () -> ()
+ fir.call @callee() fastmath<none> : () -> ()
+ // CHECK: fir.call @callee() : () -> ()
+ fir.call @callee() {fastmath = #arith.fastmath<none>} : () -> ()
+ // CHECK: fir.call @callee() fastmath<ninf,contract> : () -> ()
+ fir.call @callee() fastmath<ninf,contract> : () -> ()
+ // CHECK: fir.call @callee() fastmath<nnan,afn> : () -> ()
+ fir.call @callee() {fastmath = #arith.fastmath<nnan,afn>} : () -> ()
+ // CHECK: fir.call @callee() fastmath<fast> : () -> ()
+ fir.call @callee() fastmath<fast> : () -> ()
+ // CHECK: fir.call @callee() fastmath<fast> : () -> ()
+ fir.call @callee() {fastmath = #arith.fastmath<fast>} : () -> ()
+ return
+}
+
+func.func private @callee()