/// Extends the MLIR OpBuilder to provide methods for building common FIR
/// patterns.
-class FirOpBuilder : public mlir::OpBuilder {
+class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
public:
explicit FirOpBuilder(mlir::Operation *op, const fir::KindMapping &kindMap)
- : OpBuilder{op}, kindMap{kindMap} {}
+ : OpBuilder{op, /*listener=*/this}, kindMap{kindMap} {}
explicit FirOpBuilder(mlir::OpBuilder &builder,
const fir::KindMapping &kindMap)
- : OpBuilder{builder}, kindMap{kindMap} {}
+ : OpBuilder{builder}, kindMap{kindMap} {
+ setListener(this);
+ }
+
+ // The listener self-reference has to be updated in case of copy-construction.
+ FirOpBuilder(const FirOpBuilder &other)
+ : OpBuilder{other}, kindMap{other.kindMap}, fastMathFlags{
+ other.fastMathFlags} {
+ setListener(this);
+ }
/// Get the current Region of the insertion point.
mlir::Region &getRegion() { return *getBlock()->getParent(); }
mlir::Value ub, mlir::Value step,
mlir::Type type);
+ /// Set default FastMathFlags value for all operations
+ /// supporting mlir::arith::FastMathAttr that will be created
+ /// by this builder.
+ void setFastMathFlags(mlir::arith::FastMathFlags flags) {
+ fastMathFlags = flags;
+ }
+
/// Dump the current function. (debug)
LLVM_DUMP_METHOD void dumpFunc();
private:
+ /// Set attributes (e.g. FastMathAttr) to \p op operation
+ /// based on the current attributes setting.
+ void setCommonAttributes(mlir::Operation *op) const;
+
+ /// FirOpBuilder hook for creating new operation.
+ void notifyOperationInserted(mlir::Operation *op) override {
+ setCommonAttributes(op);
+ }
+
const KindMapping &kindMap;
+
+ /// FastMathFlags that need to be set for operations that support
+ /// mlir::arith::FastMathAttr.
+ mlir::arith::FastMathFlags fastMathFlags{};
};
} // namespace fir
return create<mlir::arith::SelectOp>(loc, cmp, div, zero);
}
+void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
+ auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
+ if (!fmi)
+ return;
+ // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
+ // For now set the attribute by the name.
+ llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
+ if (fastMathFlags != mlir::arith::FastMathFlags::none)
+ op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
+ op->getContext(), fastMathFlags));
+}
+
//===--------------------------------------------------------------------===//
// ExtendedValue inquiry helper implementation
//===--------------------------------------------------------------------===//
EXPECT_TRUE(fir::isDerivedWithLenParameters(array));
}
}
+
+TEST_F(FIRBuilderTest, genArithFastMath) {
+ auto builder = getBuilder();
+ auto ctx = builder.getContext();
+ auto loc = builder.getUnknownLoc();
+
+ auto realTy = mlir::FloatType::getF32(ctx);
+ auto arg = builder.create<fir::UndefOp>(loc, realTy);
+
+ // Test that FastMathFlags is 'none' by default.
+ mlir::Operation *op1 = builder.create<mlir::arith::AddFOp>(loc, arg, arg);
+ auto op1_fmi =
+ mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op1);
+ EXPECT_TRUE(op1_fmi);
+ auto op1_fmf = op1_fmi.getFastMathFlagsAttr().getValue();
+ EXPECT_EQ(op1_fmf, arith::FastMathFlags::none);
+
+ // Test that the builder is copied properly.
+ fir::FirOpBuilder builder_copy(builder);
+
+ arith::FastMathFlags FMF1 =
+ arith::FastMathFlags::contract | arith::FastMathFlags::reassoc;
+ builder.setFastMathFlags(FMF1);
+ arith::FastMathFlags FMF2 =
+ arith::FastMathFlags::nnan | arith::FastMathFlags::ninf;
+ builder_copy.setFastMathFlags(FMF2);
+
+ // Modifying FastMathFlags for the copy must not affect the original builder.
+ mlir::Operation *op2 = builder.create<mlir::arith::AddFOp>(loc, arg, arg);
+ auto op2_fmi =
+ mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op2);
+ EXPECT_TRUE(op2_fmi);
+ auto op2_fmf = op2_fmi.getFastMathFlagsAttr().getValue();
+ EXPECT_EQ(op2_fmf, FMF1);
+
+ // Modifying FastMathFlags for the original builder must not affect the copy.
+ mlir::Operation *op3 =
+ builder_copy.create<mlir::arith::AddFOp>(loc, arg, arg);
+ auto op3_fmi =
+ mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op3);
+ EXPECT_TRUE(op3_fmi);
+ auto op3_fmf = op3_fmi.getFastMathFlagsAttr().getValue();
+ EXPECT_EQ(op3_fmf, FMF2);
+
+ // Test that the builder copy inherits FastMathFlags from the original.
+ fir::FirOpBuilder builder_copy2(builder);
+
+ mlir::Operation *op4 =
+ builder_copy2.create<mlir::arith::AddFOp>(loc, arg, arg);
+ auto op4_fmi =
+ mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op4);
+ EXPECT_TRUE(op4_fmi);
+ auto op4_fmf = op4_fmi.getFastMathFlagsAttr().getValue();
+ EXPECT_EQ(op4_fmf, FMF1);
+}