class SimplifyIntrinsicsPass
: public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
using FunctionTypeGeneratorTy =
- std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+ llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
using FunctionBodyGeneratorTy =
- std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+ llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+ using GenReductionBodyTy = llvm::function_ref<void(
+ fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
public:
/// Generate a new function implementing a simplified version
FunctionBodyGeneratorTy bodyGenerator);
void runOnOperation() override;
void getDependentDialects(mlir::DialectRegistry ®istry) const override;
+
+private:
+ /// Helper function to replace a reduction type of call with its
+ /// simplified form. The actual function is generated using a callback
+ /// function.
+ /// \p call is the call to be replaced
+ /// \p kindMap is used to create FIROpBuilder
+ /// \p genBodyFunc is the callback that builds the replacement function
+ void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
+ GenReductionBodyTy genBodyFunc);
};
} // namespace
{elementType});
}
-using BodyOpGeneratorTy =
- std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
- const mlir::Type &, mlir::Value, mlir::Value)>;
-using InitValGeneratorTy = std::function<mlir::Value(
+using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+ fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+ mlir::Value)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
/// Generate the reduction loop into \p funcOp.
} while (true);
}
+void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
+ const fir::KindMapping &kindMap,
+ GenReductionBodyTy genBodyFunc) {
+ mlir::SymbolRefAttr callee = call.getCalleeAttr();
+ mlir::StringRef funcName = callee.getLeafReference().getValue();
+ mlir::Operation::operand_range args = call.getArgs();
+ // args[1] and args[2] are source filename and line number, ignored.
+ const mlir::Value &dim = args[3];
+ const mlir::Value &mask = args[4];
+ // dim is zero when it is absent, which is an implementation
+ // detail in the runtime library.
+ bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
+ unsigned rank = getDimCount(args[0]);
+ if (dimAndMaskAbsent && rank == 1) {
+ mlir::Location loc = call.getLoc();
+ mlir::Type type;
+ fir::FirOpBuilder builder(call, kindMap);
+ if (funcName.endswith("Integer4")) {
+ type = mlir::IntegerType::get(builder.getContext(), 32);
+ } else if (funcName.endswith("Real8")) {
+ type = mlir::FloatType::getF64(builder.getContext());
+ } else {
+ return;
+ }
+ auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
+ return genNoneBoxType(builder, type);
+ };
+ mlir::func::FuncOp newFunc =
+ getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+ auto newCall =
+ builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+ call->replaceAllUsesWith(newCall.getResults());
+ call->dropAllReferences();
+ call->erase();
+ }
+}
+
void SimplifyIntrinsicsPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::ModuleOp module = getOperation();
// int dim, const Descriptor *mask)
//
if (funcName.startswith("_FortranASum")) {
- mlir::Operation::operand_range args = call.getArgs();
- // args[1] and args[2] are source filename and line number, ignored.
- const mlir::Value &dim = args[3];
- const mlir::Value &mask = args[4];
- // dim is zero when it is absent, which is an implementation
- // detail in the runtime library.
- bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
- unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
- mlir::Location loc = call.getLoc();
- mlir::Type type;
- fir::FirOpBuilder builder(op, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
- return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
- };
- mlir::func::FuncOp newFunc = getOrCreateFunction(
- builder, funcName, typeGenerator, genFortranASumBody);
- auto newCall = builder.create<fir::CallOp>(
- loc, newFunc, mlir::ValueRange{args[0]});
- call->replaceAllUsesWith(newCall.getResults());
- call->dropAllReferences();
- call->erase();
- }
-
+ simplifyReduction(call, kindMap, genFortranASumBody);
return;
}
if (funcName.startswith("_FortranADotProduct")) {
return;
}
if (funcName.startswith("_FortranAMaxval")) {
- mlir::Operation::operand_range args = call.getArgs();
- const mlir::Value &dim = args[3];
- const mlir::Value &mask = args[4];
- // dim is zero when it is absent, which is an implementation
- // detail in the runtime library.
- bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
- unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
- mlir::Location loc = call.getLoc();
- mlir::Type type;
- fir::FirOpBuilder builder(op, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
- return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
- };
- mlir::func::FuncOp newFunc = getOrCreateFunction(
- builder, funcName, typeGenerator, genFortranAMaxvalBody);
- auto newCall = builder.create<fir::CallOp>(
- loc, newFunc, mlir::ValueRange{args[0]});
- call->replaceAllUsesWith(newCall.getResults());
- call->dropAllReferences();
- call->erase();
- return;
- }
+ simplifyReduction(call, kindMap, genFortranAMaxvalBody);
+ return;
}
}
}