[flang][openacc] Lower reduction for compute constructs
authorRazvan Lupusoru <rlupusoru@nvidia.com>
Wed, 7 Jun 2023 18:13:44 +0000 (11:13 -0700)
committerRazvan Lupusoru <rlupusoru@nvidia.com>
Wed, 7 Jun 2023 20:44:25 +0000 (13:44 -0700)
Parallel and serial constructs support reduction clause. Extend
recent D151564 loop reduction clause support to also include these
compute constructs.

Reviewed By: clementval, vzakhari

Differential Revision: https://reviews.llvm.org/D151955

flang/lib/Lower/OpenACC.cpp
flang/test/Lower/OpenACC/acc-kernels-loop.f90
flang/test/Lower/OpenACC/acc-loop.f90
flang/test/Lower/OpenACC/acc-parallel-loop.f90
flang/test/Lower/OpenACC/acc-parallel.f90
flang/test/Lower/OpenACC/acc-serial-loop.f90
flang/test/Lower/OpenACC/acc-serial.f90
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

index c59be17f2c6e138cd954c76c3075dbfa80ae1482..abdf320cdfd985999266c5c7b299fe7eafddae56 100644 (file)
@@ -997,7 +997,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
-  llvm::SmallVector<mlir::Attribute> privatizations;
+  llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
 
   // Async, wait and self clause have optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -1151,8 +1151,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genObjectList(firstprivateClause->v, converter, semanticsContext, stmtCtx,
                     firstprivateOperands);
-    } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
-      TODO(clauseLocation, "compute construct reduction clause lowering");
+    } else if (const auto *reductionClause =
+                   std::get_if<Fortran::parser::AccClause::Reduction>(
+                       &clause.u)) {
+      genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
+                    reductionOperands, reductionRecipes);
     }
   }
 
@@ -1194,6 +1197,9 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
           mlir::ArrayAttr::get(builder.getContext(), privatizations));
+    if (!reductionRecipes.empty())
+      computeOp.setReductionRecipesAttr(
+          mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
   }
 
   auto insPt = builder.saveInsertionPoint();
index 6aad08b13499c3a2e1b7d2b6e2390751a27a8354..33c3a8c447cf4cc3406bcabbb7cdc3c545641e81 100644 (file)
@@ -16,6 +16,8 @@ subroutine acc_kernels_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -709,6 +711,20 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
+! CHECK-NEXT: }{{$}}
+
+  !$acc kernels loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.kernels {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
 end subroutine
index ec8eb0f73b74e7379dfa0b3ec9fa804e06c4eb44..5b84763e32d7bd6759e7b2ca653ffa32f9f1cf5e 100644 (file)
@@ -17,6 +17,8 @@ program acc_loop
   integer :: gangStatic = 8
   integer :: vectorLength = 128
   integer, parameter :: tileSize = 2
+  integer :: reduction_i
+  real :: reduction_r
 
 
   !$acc loop
@@ -270,4 +272,15 @@ program acc_loop
 !CHECK:        acc.yield
 !CHECK-NEXT: }{{$}}
 
+  !$acc loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        fir.do_loop
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end program
index 38df6228acc83a285e77cac06842739cd37bd933..b295a905bfd85b4507cb403ce86dc37b9c4a8a4a 100644 (file)
@@ -23,6 +23,8 @@ subroutine acc_parallel_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -729,6 +731,20 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
+  !$acc parallel loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
 end subroutine acc_parallel_loop
index d1c9d80c1fbb61ea448b7752663253ec68197a5a..acfab91f467107a9266c2f978d77c736ea1ab530 100644 (file)
@@ -21,6 +21,8 @@ subroutine acc_parallel
   logical :: ifCondition = .TRUE.
   real, dimension(10, 10) :: a, b, c
   real, pointer :: d, e
+  integer :: reduction_i
+  real :: reduction_r
 
 !CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
 !CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
@@ -302,4 +304,11 @@ subroutine acc_parallel
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
+!$acc parallel reduction(+:reduction_r) reduction(*:reduction_i)
+!$acc end parallel
+
+! CHECK:      acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end subroutine acc_parallel
index 2e26da8bb2c63bd7cc745753a5e5e03c76acebb4..bf83af8bf55fd63a6809b05dd4b8bafec0f367c3 100644 (file)
@@ -23,6 +23,8 @@ subroutine acc_serial_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -645,6 +647,20 @@ subroutine acc_serial_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
+  !$acc serial loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
 end subroutine acc_serial_loop
index d10a3ab7a0c4fe200a0b72c06de76e870a4ffb6b..4d17d58c241001e3dc8459cf6d7e162ecee2cabc 100644 (file)
@@ -21,6 +21,8 @@ subroutine acc_serial
   logical :: ifCondition = .TRUE.
   real, dimension(10, 10) :: a, b, c
   real, pointer :: d, e
+  integer :: reduction_i
+  real :: reduction_r
 
 ! CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
 ! CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
@@ -245,4 +247,11 @@ subroutine acc_serial
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
+!$acc serial reduction(+:reduction_r) reduction(*:reduction_i)
+!$acc end serial
+
+! CHECK:      acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end subroutine
index da5a2856aec21e7861b45d6a320224bfb4ba9bd8..b2998b736f991d6bd96f3e2eef07d2ba30e83c76 100644 (file)
@@ -558,7 +558,7 @@ LogicalResult acc::ParallelOp::verify() {
     return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
@@ -586,7 +586,7 @@ LogicalResult acc::SerialOp::verify() {
     return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }