#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/tools.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
// Special value for * passed in device_type or gang clauses.
}
}
+/// Return the corresponding enum value for the mlir::acc::ReductionOperator
+/// from the parser representation.
+static mlir::acc::ReductionOperator
+getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
+ switch (op.v) {
+ case Fortran::parser::AccReductionOperator::Operator::Plus:
+ return mlir::acc::ReductionOperator::AccAdd;
+ case Fortran::parser::AccReductionOperator::Operator::Multiply:
+ return mlir::acc::ReductionOperator::AccMul;
+ case Fortran::parser::AccReductionOperator::Operator::Max:
+ return mlir::acc::ReductionOperator::AccMax;
+ case Fortran::parser::AccReductionOperator::Operator::Min:
+ return mlir::acc::ReductionOperator::AccMin;
+ case Fortran::parser::AccReductionOperator::Operator::Iand:
+ return mlir::acc::ReductionOperator::AccIand;
+ case Fortran::parser::AccReductionOperator::Operator::Ior:
+ return mlir::acc::ReductionOperator::AccIor;
+ case Fortran::parser::AccReductionOperator::Operator::Ieor:
+ return mlir::acc::ReductionOperator::AccXor;
+ case Fortran::parser::AccReductionOperator::Operator::And:
+ return mlir::acc::ReductionOperator::AccLand;
+ case Fortran::parser::AccReductionOperator::Operator::Or:
+ return mlir::acc::ReductionOperator::AccLor;
+ case Fortran::parser::AccReductionOperator::Operator::Eqv:
+ return mlir::acc::ReductionOperator::AccEqv;
+ case Fortran::parser::AccReductionOperator::Operator::Neqv:
+ return mlir::acc::ReductionOperator::AccNeqv;
+ }
+ llvm_unreachable("unexpected reduction operator");
+}
+
+static mlir::Value genReductionInitValue(mlir::OpBuilder &builder,
+ mlir::Location loc, mlir::Type ty,
+ mlir::acc::ReductionOperator op) {
+ if (op != mlir::acc::ReductionOperator::AccAdd)
+ TODO(loc, "reduction operator");
+
+ unsigned initValue = 0;
+
+ if (ty.isIntOrIndex())
+ return builder.create<mlir::arith::ConstantOp>(
+ loc, ty, builder.getIntegerAttr(ty, initValue));
+ if (mlir::isa<mlir::FloatType>(ty))
+ return builder.create<mlir::arith::ConstantOp>(
+ loc, ty, builder.getFloatAttr(ty, initValue));
+ TODO(loc, "reduction type");
+}
+
+static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::acc::ReductionOperator op, mlir::Type ty,
+ mlir::Value value1, mlir::Value value2) {
+ if (op == mlir::acc::ReductionOperator::AccAdd) {
+ if (ty.isIntOrIndex())
+ return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+ if (mlir::isa<mlir::FloatType>(ty))
+ return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+ TODO(loc, "reduction add type");
+ }
+ TODO(loc, "reduction operator");
+}
+
+mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
+ mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
+ mlir::Type ty, mlir::acc::ReductionOperator op) {
+ mlir::ModuleOp mod =
+ builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+ if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
+ return recipe;
+
+ auto crtPos = builder.saveInsertionPoint();
+ mlir::OpBuilder modBuilder(mod.getBodyRegion());
+ auto recipe =
+ modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
+ builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
+ {ty}, {loc});
+ builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
+ mlir::Value initValue = genReductionInitValue(builder, loc, ty, op);
+ builder.create<mlir::acc::YieldOp>(loc, initValue);
+
+ builder.createBlock(&recipe.getCombinerRegion(),
+ recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc});
+ builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
+ mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
+ mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
+ mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2);
+ builder.create<mlir::acc::YieldOp>(loc, combinedValue);
+ builder.restoreInsertionPoint(crtPos);
+ return recipe;
+}
+
+static void
+genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::StatementContext &stmtCtx,
+ llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
+ llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
+ const auto &op =
+ std::get<Fortran::parser::AccReductionOperator>(objectList.t);
+ mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
+ for (const auto &accObject : objects.v) {
+ llvm::SmallVector<mlir::Value> bounds;
+ std::stringstream asFortran;
+ mlir::Location operandLocation = genOperandLocation(converter, accObject);
+ mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
+ converter, builder, semanticsContext, stmtCtx, accObject,
+ operandLocation, asFortran, bounds);
+
+ if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType())))
+ TODO(operandLocation, "reduction with unsupported type");
+
+ mlir::Type ty = fir::unwrapRefType(baseAddr.getType());
+ std::string recipeName = fir::getTypeAsString(
+ ty, converter.getKindMap(),
+ ("reduction_" + stringifyReductionOperator(mlirOp)).str());
+ mlir::acc::ReductionRecipeOp recipe =
+ Fortran::lower::createOrGetReductionRecipe(builder, recipeName,
+ operandLocation, ty, mlirOp);
+ reductionRecipes.push_back(mlir::SymbolRefAttr::get(
+ builder.getContext(), recipe.getSymName().str()));
+ reductionOperands.push_back(baseAddr);
+ }
+}
+
static void
addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<int32_t> &operandSegments,
mlir::Value gangStatic;
llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
reductionOperands;
- llvm::SmallVector<mlir::Attribute> privatizations;
+ llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
bool hasGang = false, hasVector = false, hasWorker = false;
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
&clause.u)) {
genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx,
privateOperands, privatizations);
- } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
- // Reduction clause is left out for the moment as the clause will probably
- // end up having its own operation.
- TODO(clauseLocation, "OpenACC compute construct reduction lowering");
+ } else if (const auto *reductionClause =
+ std::get_if<Fortran::parser::AccClause::Reduction>(
+ &clause.u)) {
+ genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
+ reductionOperands, reductionRecipes);
}
}
loopOp.setPrivatizationsAttr(
mlir::ArrayAttr::get(builder.getContext(), privatizations));
+ if (!reductionRecipes.empty())
+ loopOp.setReductionRecipesAttr(
+ mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
+
// Lower clauses mapped to attributes
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
if (const auto *collapseClause =