From 59ceb7dd9a02f9c6a1342b3c282d1ddfa028ce34 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 1 Jun 2023 22:14:42 +0900 Subject: [PATCH] [flang][openacc] Initial reduction clause lowering Add initial support to lower reduction clause to its representation in MLIR. This patch adds support for addition of integer and real scalar types. Other operators and types will be added with follow up patches. Reviewed By: razvanlupusoru Differential Revision: https://reviews.llvm.org/D151564 --- flang/include/flang/Lower/OpenACC.h | 11 ++- flang/lib/Lower/OpenACC.cpp | 142 +++++++++++++++++++++++++++-- flang/test/Lower/OpenACC/acc-reduction.f90 | 51 +++++++++++ mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 6 +- 4 files changed, 198 insertions(+), 12 deletions(-) create mode 100644 flang/test/Lower/OpenACC/acc-reduction.f90 diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index cd876ea..7546e84 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -13,6 +13,8 @@ #ifndef FORTRAN_LOWER_OPENACC_H #define FORTRAN_LOWER_OPENACC_H +#include "mlir/Dialect/OpenACC/OpenACC.h" + namespace llvm { class StringRef; } @@ -21,9 +23,6 @@ namespace mlir { class Location; class Type; class OpBuilder; -namespace acc { -class PrivateRecipeOp; -} } // namespace mlir namespace Fortran { @@ -57,6 +56,12 @@ mlir::acc::PrivateRecipeOp createOrGetPrivateRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location, mlir::Type); +/// Get a acc.reduction.recipe op for the given type or create it if it does not +/// exist yet. +mlir::acc::ReductionRecipeOp +createOrGetReductionRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location, + mlir::Type, mlir::acc::ReductionOperator); + } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 306b799..f3329876 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -22,7 +22,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/tools.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" #include "llvm/Frontend/OpenACC/ACC.h.inc" // Special value for * passed in device_type or gang clauses. @@ -526,6 +525,132 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList, } } +/// Return the corresponding enum value for the mlir::acc::ReductionOperator +/// from the parser representation. +static mlir::acc::ReductionOperator +getReductionOperator(const Fortran::parser::AccReductionOperator &op) { + switch (op.v) { + case Fortran::parser::AccReductionOperator::Operator::Plus: + return mlir::acc::ReductionOperator::AccAdd; + case Fortran::parser::AccReductionOperator::Operator::Multiply: + return mlir::acc::ReductionOperator::AccMul; + case Fortran::parser::AccReductionOperator::Operator::Max: + return mlir::acc::ReductionOperator::AccMax; + case Fortran::parser::AccReductionOperator::Operator::Min: + return mlir::acc::ReductionOperator::AccMin; + case Fortran::parser::AccReductionOperator::Operator::Iand: + return mlir::acc::ReductionOperator::AccIand; + case Fortran::parser::AccReductionOperator::Operator::Ior: + return mlir::acc::ReductionOperator::AccIor; + case Fortran::parser::AccReductionOperator::Operator::Ieor: + return mlir::acc::ReductionOperator::AccXor; + case Fortran::parser::AccReductionOperator::Operator::And: + return mlir::acc::ReductionOperator::AccLand; + case Fortran::parser::AccReductionOperator::Operator::Or: + return mlir::acc::ReductionOperator::AccLor; + case Fortran::parser::AccReductionOperator::Operator::Eqv: + return mlir::acc::ReductionOperator::AccEqv; + case Fortran::parser::AccReductionOperator::Operator::Neqv: + return mlir::acc::ReductionOperator::AccNeqv; + } + llvm_unreachable("unexpected reduction operator"); +} + +static mlir::Value genReductionInitValue(mlir::OpBuilder &builder, + mlir::Location loc, mlir::Type ty, + mlir::acc::ReductionOperator op) { + if (op != mlir::acc::ReductionOperator::AccAdd) + TODO(loc, "reduction operator"); + + unsigned initValue = 0; + + if (ty.isIntOrIndex()) + return builder.create( + loc, ty, builder.getIntegerAttr(ty, initValue)); + if (mlir::isa(ty)) + return builder.create( + loc, ty, builder.getFloatAttr(ty, initValue)); + TODO(loc, "reduction type"); +} + +static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc, + mlir::acc::ReductionOperator op, mlir::Type ty, + mlir::Value value1, mlir::Value value2) { + if (op == mlir::acc::ReductionOperator::AccAdd) { + if (ty.isIntOrIndex()) + return builder.create(loc, value1, value2); + if (mlir::isa(ty)) + return builder.create(loc, value1, value2); + TODO(loc, "reduction add type"); + } + TODO(loc, "reduction operator"); +} + +mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( + mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, + mlir::Type ty, mlir::acc::ReductionOperator op) { + mlir::ModuleOp mod = + builder.getBlock()->getParent()->getParentOfType(); + if (auto recipe = mod.lookupSymbol(recipeName)) + return recipe; + + auto crtPos = builder.saveInsertionPoint(); + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + auto recipe = + modBuilder.create(loc, recipeName, ty, op); + builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(), + {ty}, {loc}); + builder.setInsertionPointToEnd(&recipe.getInitRegion().back()); + mlir::Value initValue = genReductionInitValue(builder, loc, ty, op); + builder.create(loc, initValue); + + builder.createBlock(&recipe.getCombinerRegion(), + recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc}); + builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); + mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0); + mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1); + mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2); + builder.create(loc, combinedValue); + builder.restoreInsertionPoint(crtPos); + return recipe; +} + +static void +genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &reductionOperands, + llvm::SmallVector &reductionRecipes) { + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + const auto &objects = std::get(objectList.t); + const auto &op = + std::get(objectList.t); + mlir::acc::ReductionOperator mlirOp = getReductionOperator(op); + for (const auto &accObject : objects.v) { + llvm::SmallVector bounds; + std::stringstream asFortran; + mlir::Location operandLocation = genOperandLocation(converter, accObject); + mlir::Value baseAddr = gatherDataOperandAddrAndBounds( + converter, builder, semanticsContext, stmtCtx, accObject, + operandLocation, asFortran, bounds); + + if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType()))) + TODO(operandLocation, "reduction with unsupported type"); + + mlir::Type ty = fir::unwrapRefType(baseAddr.getType()); + std::string recipeName = fir::getTypeAsString( + ty, converter.getKindMap(), + ("reduction_" + stringifyReductionOperator(mlirOp)).str()); + mlir::acc::ReductionRecipeOp recipe = + Fortran::lower::createOrGetReductionRecipe(builder, recipeName, + operandLocation, ty, mlirOp); + reductionRecipes.push_back(mlir::SymbolRefAttr::get( + builder.getContext(), recipe.getSymName().str())); + reductionOperands.push_back(baseAddr); + } +} + static void addOperands(llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &operandSegments, @@ -666,7 +791,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, mlir::Value gangStatic; llvm::SmallVector tileOperands, privateOperands, reductionOperands; - llvm::SmallVector privatizations; + llvm::SmallVector privatizations, reductionRecipes; bool hasGang = false, hasVector = false, hasWorker = false; for (const Fortran::parser::AccClause &clause : accClauseList.v) { @@ -735,10 +860,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, &clause.u)) { genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx, privateOperands, privatizations); - } else if (std::get_if(&clause.u)) { - // Reduction clause is left out for the moment as the clause will probably - // end up having its own operation. - TODO(clauseLocation, "OpenACC compute construct reduction lowering"); + } else if (const auto *reductionClause = + std::get_if( + &clause.u)) { + genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, + reductionOperands, reductionRecipes); } } @@ -767,6 +893,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter, loopOp.setPrivatizationsAttr( mlir::ArrayAttr::get(builder.getContext(), privatizations)); + if (!reductionRecipes.empty()) + loopOp.setReductionRecipesAttr( + mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); + // Lower clauses mapped to attributes for (const Fortran::parser::AccClause &clause : accClauseList.v) { if (const auto *collapseClause = diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90 new file mode 100644 index 0000000..4c95b40 --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-reduction.f90 @@ -0,0 +1,51 @@ +! This test checks lowering of OpenACC reduction clause. + +! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s + +! CHECK-LABEL: acc.reduction.recipe @reduction_add_f32 : f32 reduction_operator init { +! CHECK: ^bb0(%{{.*}}: f32): +! CHECK: %[[INIT:.*]] = arith.constant 0.000000e+00 : f32 +! CHECK: acc.yield %[[INIT]] : f32 +! CHECK: } combiner { +! CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32): +! CHECK: %[[COMBINED:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}} : f32 +! CHECK: acc.yield %[[COMBINED]] : f32 +! CHECK: } + +! CHECK-LABEL: acc.reduction.recipe @reduction_add_i32 : i32 reduction_operator init { +! CHECK: ^bb0(%{{.*}}: i32): +! CHECK: %[[INIT:.*]] = arith.constant 0 : i32 +! CHECK: acc.yield %[[INIT]] : i32 +! CHECK: } combiner { +! CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +! CHECK: %[[COMBINED:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +! CHECK: acc.yield %[[COMBINED]] : i32 +! CHECK: } + +subroutine acc_reduction_add_int(a, b) + integer :: a(100) + integer :: i, b + + !$acc loop reduction(+:b) + do i = 1, 100 + b = b + a(i) + end do +end subroutine + +! CHECK-LABEL: func.func @_QPacc_reduction_add_int( +! CHECK-SAME: %{{.*}}: !fir.ref> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref {fir.bindc_name = "b"}) +! CHECK: acc.loop reduction(@reduction_add_i32 -> %[[B]] : !fir.ref) { + +subroutine acc_reduction_add_float(a, b) + real :: a(100), b + integer :: i + + !$acc loop reduction(+:b) + do i = 1, 100 + b = b + a(i) + end do +end subroutine + +! CHECK-LABEL: func.func @_QPacc_reduction_add_float( +! CHECK-SAME: %{{.*}}: !fir.ref> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref {fir.bindc_name = "b"}) +! CHECK: acc.loop reduction(@reduction_add_f32 -> %[[B]] : !fir.ref) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index fab270a..da5a2856 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -498,7 +498,7 @@ template static LogicalResult checkSymOperandList(Operation *op, std::optional attributes, mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName) { + llvm::StringRef symbolName, bool checkOperandType = true) { if (!operands.empty()) { if (!attributes || attributes->size() != operands.size()) return op->emitOpError() @@ -527,7 +527,7 @@ checkSymOperandList(Operation *op, std::optional attributes, << "expected symbol reference " << symbolRef << " to point to a " << operandName << " declaration"; - if (decl.getType() && decl.getType() != varType) + if (checkOperandType && decl.getType() && decl.getType() != varType) return op->emitOpError() << "expected " << operandName << " (" << varType << ") to be the same type as " << operandName << " declaration (" << decl.getType() << ")"; @@ -751,7 +751,7 @@ LogicalResult acc::LoopOp::verify() { if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions"))) + "reductions", false))) return failure(); // Check non-empty body(). -- 2.7.4