From 80dcc907a8a2de9c0e24d1d40625c54c828d508a Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Tue, 9 Aug 2022 16:37:44 -0700 Subject: [PATCH] [NFC] Restructured SimplifyIntrinsicsPass::getOrCreateFunction. I would like to add DOT_PRODUCT support in this pass, so this restructuring is the first step to allow some code reuse inside getOrCreateFunction(). Differential Revision: https://reviews.llvm.org/D131530 --- .../Optimizer/Transforms/SimplifyIntrinsics.cpp | 129 +++++++++++++-------- 1 file changed, 83 insertions(+), 46 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp index cc30694..ff8f4cf 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -40,35 +40,41 @@ namespace { class SimplifyIntrinsicsPass : public fir::SimplifyIntrinsicsBase { + using FunctionTypeGeneratorTy = + std::function; + using FunctionBodyGeneratorTy = + std::function; + public: - mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc, - fir::FirOpBuilder &builder, - const mlir::Type &type, - const mlir::StringRef &basename); + /// Generate a new function implementing a simplified version + /// of a Fortran runtime function defined by \p basename name. + /// \p typeGenerator is a callback that generates the new function's type. + /// \p bodyGenerator is a callback that generates the new function's body. + /// The new function is created in the \p builder's Module. + mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder, + const mlir::StringRef &basename, + FunctionTypeGeneratorTy typeGenerator, + FunctionBodyGeneratorTy bodyGenerator); void runOnOperation() override; }; } // namespace -mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( - const mlir::Location &loc, fir::FirOpBuilder &builder, - const mlir::Type &type, const mlir::StringRef &baseName) { - // In future, the idea is that instead of building the function inside - // this function, this does the base creation, and calls a callback - // function (e.g. a lambda function) that fills in the actual content. - // For now, check that it's the ONLY the SUM runtime call. - assert(baseName.startswith("_FortranASum")); - - std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); - mlir::ModuleOp module = builder.getModule(); - // If we already have a function, just return it. - mlir::func::FuncOp newFunc = - fir::FirOpBuilder::getNamedFunction(module, replacementName); - if (newFunc) - return newFunc; +/// Generate function type for the simplified version of FortranASum +/// operating on the given \p elementType. +static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder, + const mlir::Type &elementType) { + mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); + return mlir::FunctionType::get(builder.getContext(), {boxType}, + {elementType}); +} - // Need to build the function! - // Basic idea: +/// Generate function body of the simplified version of FortranASum +/// with signature provided by \p funcOp. The caller is responsible +/// for saving/restoring the original insertion point of \p builder. +/// \p funcOp is expected to be empty on entry to this function. +static void genFortranASumBody(fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { // function FortranASum_simplified(arr) // T, dimension(:) :: arr // T sum = 0 @@ -78,35 +84,25 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( // end do // FortranASum_simplified = sum // end function FortranASum_simplified - mlir::Type boxType = fir::BoxType::get(builder.getNoneType()); - mlir::FunctionType fType = - mlir::FunctionType::get(builder.getContext(), {boxType}, {type}); - newFunc = - fir::FirOpBuilder::createFunction(loc, module, replacementName, fType); - auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; - auto linkage = - mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); - newFunc->setAttr("llvm.linkage", linkage); - - // Save the position of the original call. - mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); - builder.setInsertionPointToEnd(newFunc.addEntryBlock()); + auto loc = mlir::UnknownLoc::get(builder.getContext()); + mlir::Type elementType = funcOp.getResultTypes()[0]; + builder.setInsertionPointToEnd(funcOp.addEntryBlock()); mlir::IndexType idxTy = builder.getIndexType(); - mlir::Value zero = type.isa() - ? builder.createRealConstant(loc, type, 0.0) - : builder.createIntegerConstant(loc, type, 0); - mlir::Value sum = builder.create(loc, type); + mlir::Value zero = elementType.isa() + ? builder.createRealConstant(loc, elementType, 0.0) + : builder.createIntegerConstant(loc, elementType, 0); + mlir::Value sum = builder.create(loc, elementType); builder.create(loc, zero, sum); - mlir::Block::BlockArgListType args = newFunc.front().getArguments(); + mlir::Block::BlockArgListType args = funcOp.front().getArguments(); mlir::Value arg = args[0]; mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()}; - mlir::Type arrTy = fir::SequenceType::get(flatShape, type); + mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); mlir::Type boxArrTy = fir::BoxType::get(arrTy); mlir::Value array = builder.create(loc, boxArrTy, arg); auto dims = @@ -123,7 +119,7 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); builder.setInsertionPointToStart(loop.getBody()); - mlir::Type eleRefTy = builder.getRefType(type); + mlir::Type eleRefTy = builder.getRefType(elementType); mlir::Value index = loop.getInductionVar(); mlir::Value addr = builder.create(loc, eleRefTy, array, index); @@ -131,9 +127,9 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( mlir::Value sumVal = builder.create(loc, sum); mlir::Value res; - if (type.isa()) + if (elementType.isa()) res = builder.create(loc, elem, sumVal); - else if (type.isa()) + else if (elementType.isa()) res = builder.create(loc, elem, sumVal); else TODO(loc, "Unsupported type"); @@ -144,6 +140,44 @@ mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( mlir::Value resultVal = builder.create(loc, sum); builder.create(loc, resultVal); +} + +mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction( + fir::FirOpBuilder &builder, const mlir::StringRef &baseName, + FunctionTypeGeneratorTy typeGenerator, + FunctionBodyGeneratorTy bodyGenerator) { + // WARNING: if the function generated here changes its signature + // or behavior (the body code), we should probably embed some + // versioning information into its name, otherwise libraries + // statically linked with older versions of Flang may stop + // working with object files created with newer Flang. + // We can also avoid this by using internal linkage, but + // this may increase the size of final executable/shared library. + std::string replacementName = mlir::Twine{baseName, "_simplified"}.str(); + mlir::ModuleOp module = builder.getModule(); + // If we already have a function, just return it. + mlir::func::FuncOp newFunc = + fir::FirOpBuilder::getNamedFunction(module, replacementName); + mlir::FunctionType fType = typeGenerator(builder); + if (newFunc) { + assert(newFunc.getFunctionType() == fType && + "type mismatch for simplified function"); + return newFunc; + } + + // Need to build the function! + auto loc = mlir::UnknownLoc::get(builder.getContext()); + newFunc = + fir::FirOpBuilder::createFunction(loc, module, replacementName, fType); + auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR; + auto linkage = + mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); + newFunc->setAttr("llvm.linkage", linkage); + + // Save the position of the original call. + mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint(); + + bodyGenerator(builder, newFunc); // Now back to where we were adding code earlier... builder.restoreInsertionPoint(insertPt); @@ -218,8 +252,11 @@ void SimplifyIntrinsicsPass::runOnOperation() { } else { return; } - mlir::func::FuncOp newFunc = - getOrCreateFunction(loc, builder, type, funcName); + auto typeGenerator = [&type](fir::FirOpBuilder &builder) { + return genFortranASumType(builder, type); + }; + mlir::func::FuncOp newFunc = getOrCreateFunction( + builder, funcName, typeGenerator, genFortranASumBody); auto newCall = builder.create( loc, newFunc, mlir::ValueRange{args[0]}); call->replaceAllUsesWith(newCall.getResults()); -- 2.7.4