/// and small in size.
//===----------------------------------------------------------------------===//
-#include "PassDetail.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+namespace fir {
+#define GEN_PASS_DEF_SIMPLIFYINTRINSICS
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
#define DEBUG_TYPE "flang-simplify-intrinsics"
namespace {
class SimplifyIntrinsicsPass
- : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+ : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
using FunctionTypeGeneratorTy =
- std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+ llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
using FunctionBodyGeneratorTy =
- std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+ llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+ using GenReductionBodyTy = llvm::function_ref<void(
+ fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
public:
/// Generate a new function implementing a simplified version
FunctionBodyGeneratorTy bodyGenerator);
void runOnOperation() override;
void getDependentDialects(mlir::DialectRegistry ®istry) const override;
+
+private:
+ /// Helper function to replace a reduction type of call with its
+ /// simplified form. The actual function is generated using a callback
+ /// function.
+ /// \p call is the call to be replaced
+ /// \p kindMap is used to create the FirOpBuilder
+ /// \p genBodyFunc is the callback that builds the replacement function
+ void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
+ GenReductionBodyTy genBodyFunc);
};
} // namespace
-/// Generate function type for the simplified version of FortranASum and
+/// Generate function type for the simplified version of RTNAME(Sum) and
/// similar functions with a fir.box<none> type returning \p elementType.
static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
const mlir::Type &elementType) {
{elementType});
}
-using BodyOpGeneratorTy =
- std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
- const mlir::Type &, mlir::Value, mlir::Value)>;
-using InitValGeneratorTy = std::function<mlir::Value(
+using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+ fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+ mlir::Value)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
/// Generate the reduction loop into \p funcOp.
/// the reduction value
/// \p genBody is called to fill in the actual reduction operation
/// for example add for SUM, MAX for MAXVAL, etc.
+/// \p rank is the rank of the input argument.
static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal,
- BodyOpGeneratorTy genBody) {
+ BodyOpGeneratorTy genBody, unsigned rank) {
auto loc = mlir::UnknownLoc::get(builder.getContext());
mlir::Type elementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
- fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+ fir::SequenceType::Shape flatShape(rank,
+ fir::SequenceType::getUnknownExtent());
mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
- auto dims =
- builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
- mlir::Value len = dims.getResult(1);
+ mlir::Value init = initVal(builder, loc, elementType);
+
+ llvm::SmallVector<mlir::Value, 15> bounds;
+
+ assert(rank > 0 && "rank cannot be zero");
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
- mlir::Value step = one;
- // We use C indexing here, so len-1 as loopcount
- mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
- mlir::Value init = initVal(builder, loc, elementType);
- auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
- /*unordered=*/false,
- /*finalCountValue=*/false, init);
- mlir::Value reductionVal = loop.getRegionIterArgs()[0];
+ // Compute all the upper bounds before the loop nest.
+ // It is not strictly necessary for performance, since the loop nest
+ // does not have any store operations and any LICM optimization
+ // should be able to optimize the redundancy.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+ auto dims =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+ mlir::Value len = dims.getResult(1);
+ // We use C indexing here, so len-1 as loopcount
+ mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+ bounds.push_back(loopCount);
+ }
- // Begin loop code
- mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(loop.getBody());
+ // Create a loop nest consisting of DoLoopOp operations.
+ // Collect the loops' induction variables into indices array,
+ // which will be used in the innermost loop to load the input
+ // array's element.
+ // The loops are generated such that the innermost loop processes
+ // the 0 dimension.
+ llvm::SmallVector<mlir::Value, 15> indices;
+ for (unsigned i = rank; 0 < i; --i) {
+ mlir::Value step = one;
+ mlir::Value loopCount = bounds[i - 1];
+ auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
+ /*unordered=*/false,
+ /*finalCountValue=*/false, init);
+ init = loop.getRegionIterArgs()[0];
+ indices.push_back(loop.getInductionVar());
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Reverse the indices such that they are ordered as:
+ // <dim-0-idx, dim-1-idx, ...>
+ std::reverse(indices.begin(), indices.end());
+ // We are in the innermost loop: generate the reduction body.
mlir::Type eleRefTy = builder.getRefType(elementType);
- mlir::Value index = loop.getInductionVar();
mlir::Value addr =
- builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+ builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
- reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
-
- builder.create<fir::ResultOp>(loc, reductionVal);
- // End of loop.
- builder.restoreInsertionPoint(loopEndPt);
+ mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
+
+ // Unwind the loop nest and insert ResultOp on each level
+ // to return the updated value of the reduction to the enclosing
+ // loops.
+ for (unsigned i = 0; i < rank; ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+ // Proceed to the outer loop.
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reductionVal = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
- mlir::Value resultVal = loop.getResult(0);
- builder.create<mlir::func::ReturnOp>(loc, resultVal);
+ // End of loop nest. The insertion point is after the outermost loop.
+ // Return the reduction value from the function.
+ builder.create<mlir::func::ReturnOp>(loc, reductionVal);
}
-/// Generate function body of the simplified version of FortranASum
+/// Generate function body of the simplified version of RTNAME(Sum)
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
-static void genFortranASumBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
- // function FortranASum<T>_simplified(arr)
+/// \p rank specifies the rank of the input argument.
+static void genRuntimeSumBody(fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp, unsigned rank) {
+ // function RTNAME(Sum)<T>x<rank>_simplified(arr)
// T, dimension(:) :: arr
// T sum = 0
// integer iter
// do iter = 0, extent(arr) - 1
// sum = sum + arr[iter]
// end do
- // FortranASum<T>_simplified = sum
- // end function FortranASum<T>_simplified
+ // RTNAME(Sum)<T>x<rank>_simplified = sum
+ // end function RTNAME(Sum)<T>x<rank>_simplified
auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
- return elementType.isa<mlir::FloatType>()
- ? builder.createRealConstant(loc, elementType,
- llvm::APFloat(0.0))
- : builder.createIntegerConstant(loc, elementType, 0);
+ if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(loc, elementType,
+ llvm::APFloat::getZero(sem));
+ }
+ return builder.createIntegerConstant(loc, elementType, 0);
};
auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
return {};
};
- genReductionLoop(builder, funcOp, zero, genBodyOp);
+ genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
}
-static void genFortranAMaxvalBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
+static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp, unsigned rank) {
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
llvm_unreachable("unsupported type");
return {};
};
- genReductionLoop(builder, funcOp, init, genBodyOp);
+ genReductionLoop(builder, funcOp, init, genBodyOp, rank);
}
-/// Generate function type for the simplified version of FortranADotProduct
+/// Generate function type for the simplified version of RTNAME(DotProduct)
/// operating on the given \p elementType.
-static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
- const mlir::Type &elementType) {
+static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
+ const mlir::Type &elementType) {
mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
{elementType});
}
-/// Generate function body of the simplified version of FortranADotProduct
+/// Generate function body of the simplified version of RTNAME(DotProduct)
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
/// \p arg1ElementTy and \p arg2ElementTy specify elements types
/// of the underlying array objects - they are used to generate proper
/// element accesses.
-static void genFortranADotBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp,
- mlir::Type arg1ElementTy,
- mlir::Type arg2ElementTy) {
- // function FortranADotProduct<T>_simplified(arr1, arr2)
+static void genRuntimeDotBody(fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp,
+ mlir::Type arg1ElementTy,
+ mlir::Type arg2ElementTy) {
+ // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
// T, dimension(:) :: arr1, arr2
// T product = 0
// integer iter
// do iter = 0, extent(arr1)
// product = product + arr1[iter] * arr2[iter]
// end do
- // FortranADotProduct<T>_simplified = product
- // end function FortranADotProduct<T>_simplified
+ // RTNAME(DotProduct)<T>_simplified = product
+ // end function RTNAME(DotProduct)<T>_simplified
auto loc = mlir::UnknownLoc::get(builder.getContext());
mlir::Type resultElementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
return false;
}
-static mlir::Value findShape(mlir::Value val) {
+static mlir::Value findBoxDef(mlir::Value val) {
if (auto op = expectConvertOp(val)) {
assert(op->getOperands().size() != 0);
if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
op->getOperand(0).getDefiningOp()))
- return box.getShape();
+ return box.getResult();
+ if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
+ op->getOperand(0).getDefiningOp()))
+ return box.getResult();
}
return {};
}
static unsigned getDimCount(mlir::Value val) {
- if (mlir::Value shapeVal = findShape(val)) {
- mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
- return fir::getRankOfShapeType(resType);
- }
+ // In order to find the dimensions count, we look for EmboxOp/ReboxOp
+ // and take the count from its *result* type. Note that in case
+ // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
+ // have different types.
+ // Actually, we can take the box type from the operand of
+ // the first ConvertOp that has non-opaque box type that we meet
+ // going through the ConvertOp chain.
+ if (mlir::Value emboxVal = findBoxDef(val))
+ if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
+ if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
+ return seqTy.getDimension();
return 0;
}
} while (true);
}
+void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
+ const fir::KindMapping &kindMap,
+ GenReductionBodyTy genBodyFunc) {
+ mlir::SymbolRefAttr callee = call.getCalleeAttr();
+ mlir::Operation::operand_range args = call.getArgs();
+ // args[1] and args[2] are source filename and line number, ignored.
+ const mlir::Value &dim = args[3];
+ const mlir::Value &mask = args[4];
+ // dim is zero when it is absent, which is an implementation
+ // detail in the runtime library.
+ bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
+ unsigned rank = getDimCount(args[0]);
+ if (dimAndMaskAbsent && rank > 0) {
+ mlir::Location loc = call.getLoc();
+ fir::FirOpBuilder builder(call, kindMap);
+
+ // Support only floating point and integer results now.
+ mlir::Type resultType = call.getResult(0).getType();
+ if (!resultType.isa<mlir::FloatType>() &&
+ !resultType.isa<mlir::IntegerType>())
+ return;
+
+ auto argType = getArgElementType(args[0]);
+ if (!argType)
+ return;
+ assert(*argType == resultType &&
+ "Argument/result types mismatch in reduction");
+
+ auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
+ return genNoneBoxType(builder, resultType);
+ };
+ auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp) {
+ genBodyFunc(builder, funcOp, rank);
+ };
+ // Mangle the function name with the rank value as "x<rank>".
+ std::string funcName =
+ (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+ mlir::Twine{rank})
+ .str();
+ mlir::func::FuncOp newFunc =
+ getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
+ auto newCall =
+ builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+ call->replaceAllUsesWith(newCall.getResults());
+ call->dropAllReferences();
+ call->erase();
+ }
+}
+
void SimplifyIntrinsicsPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::ModuleOp module = getOperation();
// RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
// int dim, const Descriptor *mask)
//
- if (funcName.startswith("_FortranASum")) {
- mlir::Operation::operand_range args = call.getArgs();
- // args[1] and args[2] are source filename and line number, ignored.
- const mlir::Value &dim = args[3];
- const mlir::Value &mask = args[4];
- // dim is zero when it is absent, which is an implementation
- // detail in the runtime library.
- bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
- unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
- mlir::Location loc = call.getLoc();
- mlir::Type type;
- fir::FirOpBuilder builder(op, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
- return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
- };
- mlir::func::FuncOp newFunc = getOrCreateFunction(
- builder, funcName, typeGenerator, genFortranASumBody);
- auto newCall = builder.create<fir::CallOp>(
- loc, newFunc, mlir::ValueRange{args[0]});
- call->replaceAllUsesWith(newCall.getResults());
- call->dropAllReferences();
- call->erase();
- }
-
+ if (funcName.startswith(RTNAME_STRING(Sum))) {
+ simplifyReduction(call, kindMap, genRuntimeSumBody);
return;
}
- if (funcName.startswith("_FortranADotProduct")) {
+ if (funcName.startswith(RTNAME_STRING(DotProduct))) {
LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
llvm::dbgs() << "\n");
return;
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genFortranADotType(builder, type);
+ return genRuntimeDotType(builder, type);
};
auto bodyGenerator = [&arg1Type,
&arg2Type](fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
- genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
+ genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
};
// Suffix the function name with the element types
llvm::dbgs() << "\n");
return;
}
- if (funcName.startswith("_FortranAMaxval")) {
- mlir::Operation::operand_range args = call.getArgs();
- const mlir::Value &dim = args[3];
- const mlir::Value &mask = args[4];
- // dim is zero when it is absent, which is an implementation
- // detail in the runtime library.
- bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
- unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
- mlir::Location loc = call.getLoc();
- mlir::Type type;
- fir::FirOpBuilder builder(op, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
- return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
- };
- mlir::func::FuncOp newFunc = getOrCreateFunction(
- builder, funcName, typeGenerator, genFortranAMaxvalBody);
- auto newCall = builder.create<fir::CallOp>(
- loc, newFunc, mlir::ValueRange{args[0]});
- call->replaceAllUsesWith(newCall.getResults());
- call->dropAllReferences();
- call->erase();
- return;
- }
+ if (funcName.startswith(RTNAME_STRING(Maxval))) {
+ simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
+ return;
}
}
}