#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
-#define RTNAME_STRINGIFY2(x) #x
-#define RTNAME_STRINGIFY(x) RTNAME_STRINGIFY2(x)
-#define RTNAME_STRING(x) RTNAME_STRINGIFY(RTNAME(x))
-
#define DEBUG_TYPE "flang-simplify-intrinsics"
namespace {
using FunctionBodyGeneratorTy =
llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
using GenReductionBodyTy = llvm::function_ref<void(
- fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
+ fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
public:
/// Generate a new function implementing a simplified version
/// the reduction value
/// \p genBody is called to fill in the actual reduciton operation
/// for example add for SUM, MAX for MAXVAL, etc.
+/// \p rank is the rank of the input argument.
static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal,
- BodyOpGeneratorTy genBody) {
+ BodyOpGeneratorTy genBody, unsigned rank) {
auto loc = mlir::UnknownLoc::get(builder.getContext());
mlir::Type elementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
- fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+ fir::SequenceType::Shape flatShape(rank,
+ fir::SequenceType::getUnknownExtent());
mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
- auto dims =
- builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
- mlir::Value len = dims.getResult(1);
+ mlir::Value init = initVal(builder, loc, elementType);
+
+ llvm::SmallVector<mlir::Value, 15> bounds;
+
+ assert(rank > 0 && "rank cannot be zero");
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
- mlir::Value step = one;
- // We use C indexing here, so len-1 as loopcount
- mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
- mlir::Value init = initVal(builder, loc, elementType);
- auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
- /*unordered=*/false,
- /*finalCountValue=*/false, init);
- mlir::Value reductionVal = loop.getRegionIterArgs()[0];
+ // Compute all the upper bounds before the loop nest.
+ // It is not strictly necessary for performance, since the loop nest
+ // does not have any store operations and any LICM optimization
+ // should be able to optimize the redundancy.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+ auto dims =
+ builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+ mlir::Value len = dims.getResult(1);
+ // We use C indexing here, so len-1 as loopcount
+ mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+ bounds.push_back(loopCount);
+ }
- // Begin loop code
- mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(loop.getBody());
+ // Create a loop nest consisting of DoLoopOp operations.
+ // Collect the loops' induction variables into indices array,
+ // which will be used in the innermost loop to load the input
+ // array's element.
+ // The loops are generated such that the innermost loop processes
+ // the 0 dimension.
+ llvm::SmallVector<mlir::Value, 15> indices;
+ for (unsigned i = rank; 0 < i; --i) {
+ mlir::Value step = one;
+ mlir::Value loopCount = bounds[i - 1];
+ auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
+ /*unordered=*/false,
+ /*finalCountValue=*/false, init);
+ init = loop.getRegionIterArgs()[0];
+ indices.push_back(loop.getInductionVar());
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Reverse the indices such that they are ordered as:
+ // <dim-0-idx, dim-1-idx, ...>
+ std::reverse(indices.begin(), indices.end());
+ // We are in the innermost loop: generate the reduction body.
mlir::Type eleRefTy = builder.getRefType(elementType);
- mlir::Value index = loop.getInductionVar();
mlir::Value addr =
- builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+ builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
- reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
-
- builder.create<fir::ResultOp>(loc, reductionVal);
- // End of loop.
- builder.restoreInsertionPoint(loopEndPt);
+ mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
+
+ // Unwind the loop nest and insert ResultOp on each level
+ // to return the updated value of the reduction to the enclosing
+ // loops.
+ for (unsigned i = 0; i < rank; ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+ // Proceed to the outer loop.
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reductionVal = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
- mlir::Value resultVal = loop.getResult(0);
- builder.create<mlir::func::ReturnOp>(loc, resultVal);
+ // End of loop nest. The insertion point is after the outermost loop.
+ // Return the reduction value from the function.
+ builder.create<mlir::func::ReturnOp>(loc, reductionVal);
}
/// Generate function body of the simplified version of RTNAME(Sum)
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
+/// \p rank specifies the rank of the input argument.
static void genRuntimeSumBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
- // function RTNAME(Sum)<T>_simplified(arr)
+ mlir::func::FuncOp &funcOp, unsigned rank) {
+ // function RTNAME(Sum)<T>x<rank>_simplified(arr)
// T, dimension(:) :: arr
// T sum = 0
// integer iter
// do iter = 0, extent(arr)
// sum = sum + arr[iter]
// end do
- // RTNAME(Sum)<T>_simplified = sum
- // end function RTNAME(Sum)<T>_simplified
+ // RTNAME(Sum)<T>x<rank>_simplified = sum
+ // end function RTNAME(Sum)<T>x<rank>_simplified
auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
- return elementType.isa<mlir::FloatType>()
- ? builder.createRealConstant(loc, elementType,
- llvm::APFloat(0.0))
- : builder.createIntegerConstant(loc, elementType, 0);
+ if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(loc, elementType,
+ llvm::APFloat::getZero(sem));
+ }
+ return builder.createIntegerConstant(loc, elementType, 0);
};
auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
return {};
};
- genReductionLoop(builder, funcOp, zero, genBodyOp);
+ genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
}
static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
+ mlir::func::FuncOp &funcOp, unsigned rank) {
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
llvm_unreachable("unsupported type");
return {};
};
- genReductionLoop(builder, funcOp, init, genBodyOp);
+ genReductionLoop(builder, funcOp, init, genBodyOp, rank);
}
/// Generate function type for the simplified version of RTNAME(DotProduct)
return false;
}
-static mlir::Value findShape(mlir::Value val) {
+static mlir::Value findBoxDef(mlir::Value val) {
if (auto op = expectConvertOp(val)) {
assert(op->getOperands().size() != 0);
if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
op->getOperand(0).getDefiningOp()))
- return box.getShape();
+ return box.getResult();
+ if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
+ op->getOperand(0).getDefiningOp()))
+ return box.getResult();
}
return {};
}
static unsigned getDimCount(mlir::Value val) {
- if (mlir::Value shapeVal = findShape(val)) {
- mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
- return fir::getRankOfShapeType(resType);
- }
+ // In order to find the dimensions count, we look for EmboxOp/ReboxOp
+ // and take the count from its *result* type. Note that in case
+ // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
+ // have different types.
+ // Actually, we can take the box type from the operand of
+ // the first ConvertOp that has non-opaque box type that we meet
+ // going through the ConvertOp chain.
+ if (mlir::Value emboxVal = findBoxDef(val))
+ if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
+ if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
+ return seqTy.getDimension();
return 0;
}
const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc) {
mlir::SymbolRefAttr callee = call.getCalleeAttr();
- mlir::StringRef funcName = callee.getLeafReference().getValue();
mlir::Operation::operand_range args = call.getArgs();
// args[1] and args[2] are source filename and line number, ignored.
const mlir::Value &dim = args[3];
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
- if (dimAndMaskAbsent && rank == 1) {
+ if (dimAndMaskAbsent && rank > 0) {
mlir::Location loc = call.getLoc();
- mlir::Type type;
fir::FirOpBuilder builder(call, kindMap);
- if (funcName.endswith("Integer4")) {
- type = mlir::IntegerType::get(builder.getContext(), 32);
- } else if (funcName.endswith("Real8")) {
- type = mlir::FloatType::getF64(builder.getContext());
- } else {
+
+ // Support only floating point and integer results now.
+ mlir::Type resultType = call.getResult(0).getType();
+ if (!resultType.isa<mlir::FloatType>() &&
+ !resultType.isa<mlir::IntegerType>())
return;
- }
- auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
- return genNoneBoxType(builder, type);
+
+ auto argType = getArgElementType(args[0]);
+ if (!argType)
+ return;
+ assert(*argType == resultType &&
+ "Argument/result types mismatch in reduction");
+
+ auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
+ return genNoneBoxType(builder, resultType);
+ };
+ auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp) {
+ genBodyFunc(builder, funcOp, rank);
};
+ // Mangle the function name with the rank value as "x<rank>".
+ std::string funcName =
+ (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+ mlir::Twine{rank})
+ .str();
mlir::func::FuncOp newFunc =
- getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+ getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
auto newCall =
builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());