[flang][openacc] Initial reduction clause lowering
authorValentin Clement <clementval@gmail.com>
Thu, 1 Jun 2023 13:14:42 +0000 (22:14 +0900)
committerValentin Clement <clementval@gmail.com>
Thu, 1 Jun 2023 13:15:28 +0000 (22:15 +0900)
Add initial support to lower reduction clause to its representation in MLIR.

This patch adds support for addition of integer and real scalar types. Other
operators and types will be added with follow up patches.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D151564

flang/include/flang/Lower/OpenACC.h
flang/lib/Lower/OpenACC.cpp
flang/test/Lower/OpenACC/acc-reduction.f90 [new file with mode: 0644]
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

index cd876ea..7546e84 100644 (file)
@@ -13,6 +13,8 @@
 #ifndef FORTRAN_LOWER_OPENACC_H
 #define FORTRAN_LOWER_OPENACC_H
 
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+
 namespace llvm {
 class StringRef;
 }
@@ -21,9 +23,6 @@ namespace mlir {
 class Location;
 class Type;
 class OpBuilder;
-namespace acc {
-class PrivateRecipeOp;
-}
 } // namespace mlir
 
 namespace Fortran {
@@ -57,6 +56,12 @@ mlir::acc::PrivateRecipeOp createOrGetPrivateRecipe(mlir::OpBuilder &,
                                                     llvm::StringRef,
                                                     mlir::Location, mlir::Type);
 
+/// Get a acc.reduction.recipe op for the given type or create it if it does not
+/// exist yet.
+mlir::acc::ReductionRecipeOp
+createOrGetReductionRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location,
+                           mlir::Type, mlir::acc::ReductionOperator);
+
 } // namespace lower
 } // namespace Fortran
 
index 306b799..f332987 100644 (file)
@@ -22,7 +22,6 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/tools.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 // Special value for * passed in device_type or gang clauses.
@@ -526,6 +525,132 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
   }
 }
 
+/// Return the corresponding enum value for the mlir::acc::ReductionOperator
+/// from the parser representation.
+static mlir::acc::ReductionOperator
+getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
+  switch (op.v) {
+  case Fortran::parser::AccReductionOperator::Operator::Plus:
+    return mlir::acc::ReductionOperator::AccAdd;
+  case Fortran::parser::AccReductionOperator::Operator::Multiply:
+    return mlir::acc::ReductionOperator::AccMul;
+  case Fortran::parser::AccReductionOperator::Operator::Max:
+    return mlir::acc::ReductionOperator::AccMax;
+  case Fortran::parser::AccReductionOperator::Operator::Min:
+    return mlir::acc::ReductionOperator::AccMin;
+  case Fortran::parser::AccReductionOperator::Operator::Iand:
+    return mlir::acc::ReductionOperator::AccIand;
+  case Fortran::parser::AccReductionOperator::Operator::Ior:
+    return mlir::acc::ReductionOperator::AccIor;
+  case Fortran::parser::AccReductionOperator::Operator::Ieor:
+    return mlir::acc::ReductionOperator::AccXor;
+  case Fortran::parser::AccReductionOperator::Operator::And:
+    return mlir::acc::ReductionOperator::AccLand;
+  case Fortran::parser::AccReductionOperator::Operator::Or:
+    return mlir::acc::ReductionOperator::AccLor;
+  case Fortran::parser::AccReductionOperator::Operator::Eqv:
+    return mlir::acc::ReductionOperator::AccEqv;
+  case Fortran::parser::AccReductionOperator::Operator::Neqv:
+    return mlir::acc::ReductionOperator::AccNeqv;
+  }
+  llvm_unreachable("unexpected reduction operator");
+}
+
+static mlir::Value genReductionInitValue(mlir::OpBuilder &builder,
+                                         mlir::Location loc, mlir::Type ty,
+                                         mlir::acc::ReductionOperator op) {
+  if (op != mlir::acc::ReductionOperator::AccAdd)
+    TODO(loc, "reduction operator");
+
+  unsigned initValue = 0;
+
+  if (ty.isIntOrIndex())
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, ty, builder.getIntegerAttr(ty, initValue));
+  if (mlir::isa<mlir::FloatType>(ty))
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, ty, builder.getFloatAttr(ty, initValue));
+  TODO(loc, "reduction type");
+}
+
+static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc,
+                               mlir::acc::ReductionOperator op, mlir::Type ty,
+                               mlir::Value value1, mlir::Value value2) {
+  if (op == mlir::acc::ReductionOperator::AccAdd) {
+    if (ty.isIntOrIndex())
+      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+    if (mlir::isa<mlir::FloatType>(ty))
+      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+    TODO(loc, "reduction add type");
+  }
+  TODO(loc, "reduction operator");
+}
+
+mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
+    mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
+    mlir::Type ty, mlir::acc::ReductionOperator op) {
+  mlir::ModuleOp mod =
+      builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+  if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
+    return recipe;
+
+  auto crtPos = builder.saveInsertionPoint();
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  auto recipe =
+      modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
+  builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
+                      {ty}, {loc});
+  builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
+  mlir::Value initValue = genReductionInitValue(builder, loc, ty, op);
+  builder.create<mlir::acc::YieldOp>(loc, initValue);
+
+  builder.createBlock(&recipe.getCombinerRegion(),
+                      recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc});
+  builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
+  mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
+  mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
+  mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2);
+  builder.create<mlir::acc::YieldOp>(loc, combinedValue);
+  builder.restoreInsertionPoint(crtPos);
+  return recipe;
+}
+
+static void
+genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
+              Fortran::lower::AbstractConverter &converter,
+              Fortran::semantics::SemanticsContext &semanticsContext,
+              Fortran::lower::StatementContext &stmtCtx,
+              llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
+              llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
+  const auto &op =
+      std::get<Fortran::parser::AccReductionOperator>(objectList.t);
+  mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
+  for (const auto &accObject : objects.v) {
+    llvm::SmallVector<mlir::Value> bounds;
+    std::stringstream asFortran;
+    mlir::Location operandLocation = genOperandLocation(converter, accObject);
+    mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
+        converter, builder, semanticsContext, stmtCtx, accObject,
+        operandLocation, asFortran, bounds);
+
+    if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType())))
+      TODO(operandLocation, "reduction with unsupported type");
+
+    mlir::Type ty = fir::unwrapRefType(baseAddr.getType());
+    std::string recipeName = fir::getTypeAsString(
+        ty, converter.getKindMap(),
+        ("reduction_" + stringifyReductionOperator(mlirOp)).str());
+    mlir::acc::ReductionRecipeOp recipe =
+        Fortran::lower::createOrGetReductionRecipe(builder, recipeName,
+                                                   operandLocation, ty, mlirOp);
+    reductionRecipes.push_back(mlir::SymbolRefAttr::get(
+        builder.getContext(), recipe.getSymName().str()));
+    reductionOperands.push_back(baseAddr);
+  }
+}
+
 static void
 addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
             llvm::SmallVectorImpl<int32_t> &operandSegments,
@@ -666,7 +791,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value gangStatic;
   llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
       reductionOperands;
-  llvm::SmallVector<mlir::Attribute> privatizations;
+  llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
   bool hasGang = false, hasVector = false, hasWorker = false;
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -735,10 +860,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx,
                         privateOperands, privatizations);
-    } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
-      // Reduction clause is left out for the moment as the clause will probably
-      // end up having its own operation.
-      TODO(clauseLocation, "OpenACC compute construct reduction lowering");
+    } else if (const auto *reductionClause =
+                   std::get_if<Fortran::parser::AccClause::Reduction>(
+                       &clause.u)) {
+      genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
+                    reductionOperands, reductionRecipes);
     }
   }
 
@@ -767,6 +893,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
     loopOp.setPrivatizationsAttr(
         mlir::ArrayAttr::get(builder.getContext(), privatizations));
 
+  if (!reductionRecipes.empty())
+    loopOp.setReductionRecipesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
+
   // Lower clauses mapped to attributes
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     if (const auto *collapseClause =
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
new file mode 100644 (file)
index 0000000..4c95b40
--- /dev/null
@@ -0,0 +1,51 @@
+! This test checks lowering of OpenACC reduction clause.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_f32 : f32 reduction_operator <add> init {
+! CHECK: ^bb0(%{{.*}}: f32):
+! CHECK:   %[[INIT:.*]] = arith.constant 0.000000e+00 : f32
+! CHECK:   acc.yield %[[INIT]] : f32
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
+! CHECK:   %[[COMBINED:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}} : f32
+! CHECK:   acc.yield %[[COMBINED]] : f32
+! CHECK: }
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_i32 : i32 reduction_operator <add> init {
+! CHECK: ^bb0(%{{.*}}: i32):
+! CHECK:   %[[INIT:.*]] = arith.constant 0 : i32
+! CHECK:   acc.yield %[[INIT]] : i32
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+! CHECK:   %[[COMBINED:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+! CHECK:   acc.yield %[[COMBINED]] : i32
+! CHECK: }
+
+subroutine acc_reduction_add_int(a, b)
+  integer :: a(100)
+  integer :: i, b
+
+  !$acc loop reduction(+:b)
+  do i = 1, 100
+    b = b + a(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_int(
+! CHECK-SAME:  %{{.*}}: !fir.ref<!fir.array<100xi32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<i32> {fir.bindc_name = "b"})
+! CHECK:       acc.loop reduction(@reduction_add_i32 -> %[[B]] : !fir.ref<i32>) {
+
+subroutine acc_reduction_add_float(a, b)
+  real :: a(100), b
+  integer :: i
+
+  !$acc loop reduction(+:b)
+  do i = 1, 100
+    b = b + a(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_float(
+! CHECK-SAME:  %{{.*}}: !fir.ref<!fir.array<100xf32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<f32> {fir.bindc_name = "b"})
+! CHECK:       acc.loop reduction(@reduction_add_f32 -> %[[B]] : !fir.ref<f32>)
index fab270a..da5a285 100644 (file)
@@ -498,7 +498,7 @@ template <typename Op>
 static LogicalResult
 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
                     mlir::OperandRange operands, llvm::StringRef operandName,
-                    llvm::StringRef symbolName) {
+                    llvm::StringRef symbolName, bool checkOperandType = true) {
   if (!operands.empty()) {
     if (!attributes || attributes->size() != operands.size())
       return op->emitOpError()
@@ -527,7 +527,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
              << "expected symbol reference " << symbolRef << " to point to a "
              << operandName << " declaration";
 
-    if (decl.getType() && decl.getType() != varType)
+    if (checkOperandType && decl.getType() && decl.getType() != varType)
       return op->emitOpError() << "expected " << operandName << " (" << varType
                                << ") to be the same type as " << operandName
                                << " declaration (" << decl.getType() << ")";
@@ -751,7 +751,7 @@ LogicalResult acc::LoopOp::verify() {
 
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
 
   // Check non-empty body().