tree-optimization/100981 - fix SLP patterns involving reductions
authorRichard Biener <rguenther@suse.de>
Wed, 9 Jun 2021 12:48:35 +0000 (14:48 +0200)
committerRichard Biener <rguenther@suse.de>
Wed, 9 Jun 2021 14:33:18 +0000 (16:33 +0200)
The following fixes the SLP FMA patterns to preserve reduction
info and the reduction vectorization to consider internal function
call defs for the reduction stmt.

2021-06-09  Richard Biener  <rguenther@suse.de>

PR tree-optimization/100981
gcc/
* tree-vect-loop.c (vect_create_epilog_for_reduction): Use
gimple_get_lhs to also handle calls.
* tree-vect-slp-patterns.c (complex_pattern::build): Transfer
reduction info.

gcc/testsuite/
* gfortran.dg/vect/pr100981-1.f90: New testcase.

libgomp/
* testsuite/libgomp.fortran/pr100981-2.f90: New testcase.

gcc/testsuite/gfortran.dg/vect/pr100981-1.f90 [new file with mode: 0644]
gcc/tree-vect-loop.c
gcc/tree-vect-slp-patterns.c
libgomp/testsuite/libgomp.fortran/pr100981-2.f90 [new file with mode: 0644]

diff --git a/gcc/testsuite/gfortran.dg/vect/pr100981-1.f90 b/gcc/testsuite/gfortran.dg/vect/pr100981-1.f90
new file mode 100644 (file)
index 0000000..6f11121
--- /dev/null
@@ -0,0 +1,22 @@
+! { dg-do compile }
+! { dg-additional-options "-O3 -ftree-parallelize-loops=2 -fno-signed-zeros -fno-trapping-math" }
+! { dg-additional-options "-march=armv8.3-a" { target aarch64*-*-* } }
+
+complex function cdcdot(n, cx)
+  implicit none
+
+  integer :: n, i, kx
+  complex :: cx(*)
+  double precision :: dsdotr, dsdoti, dt1, dt3
+
+  kx = 1
+  do i = 1, n
+     dt1 = real(cx(kx))
+     dt3 = aimag(cx(kx))
+     dsdotr = dsdotr + dt1 * 2 - dt3 * 2
+     dsdoti = dsdoti + dt1 * 2 + dt3 * 2
+     kx = kx + 1
+  end do
+  cdcdot = cmplx(real(dsdotr), real(dsdoti))
+  return
+end function cdcdot
index ba36348..ee79808 100644 (file)
@@ -5247,7 +5247,7 @@ vect_create_epilog_for_reduction (loop_vec_info loop_vinfo,
       gcc_assert (STMT_VINFO_RELATED_STMT (orig_stmt_info) == stmt_info);
     }
   
-  scalar_dest = gimple_assign_lhs (orig_stmt_info->stmt);
+  scalar_dest = gimple_get_lhs (orig_stmt_info->stmt);
   scalar_type = TREE_TYPE (scalar_dest);
   scalar_results.create (group_size); 
   new_scalar_dest = vect_create_destination_var (scalar_dest, NULL);
index b25655c..2ed49cd 100644 (file)
@@ -544,6 +544,8 @@ complex_pattern::build (vec_info *vinfo)
     {
       /* Calculate the location of the statement in NODE to replace.  */
       stmt_info = SLP_TREE_REPRESENTATIVE (node);
+      stmt_vec_info reduc_def
+       = STMT_VINFO_REDUC_DEF (vect_orig_stmt (stmt_info));
       gimple* old_stmt = STMT_VINFO_STMT (stmt_info);
       tree lhs_old_stmt = gimple_get_lhs (old_stmt);
       tree type = TREE_TYPE (lhs_old_stmt);
@@ -568,9 +570,10 @@ complex_pattern::build (vec_info *vinfo)
        = vinfo->add_pattern_stmt (call_stmt, stmt_info);
 
       /* Make sure to mark the representative statement pure_slp and
-        relevant. */
+        relevant and transfer reduction info. */
       STMT_VINFO_RELEVANT (call_stmt_info) = vect_used_in_scope;
       STMT_SLP_TYPE (call_stmt_info) = pure_slp;
+      STMT_VINFO_REDUC_DEF (call_stmt_info) = reduc_def;
 
       gimple_set_bb (call_stmt, gimple_bb (stmt_info->stmt));
       STMT_VINFO_VECTYPE (call_stmt_info) = SLP_TREE_VECTYPE (node);
diff --git a/libgomp/testsuite/libgomp.fortran/pr100981-2.f90 b/libgomp/testsuite/libgomp.fortran/pr100981-2.f90
new file mode 100644 (file)
index 0000000..12836d4
--- /dev/null
@@ -0,0 +1,31 @@
+! { dg-do run }
+! { dg-additional-options "-O3 -ftree-parallelize-loops=2 -fno-signed-zeros -fno-trapping-math" }
+
+complex function cdcdot(n, cx)
+  implicit none
+
+  integer :: n, i, kx
+  complex :: cx(*)
+  double precision :: dsdotr, dsdoti, dt1, dt3
+
+  kx = 1
+  do i = 1, n
+     dt1 = real(cx(kx))
+     dt3 = aimag(cx(kx))
+     dsdotr = dsdotr + dt1 * 2 - dt3 * 2
+     dsdoti = dsdoti + dt1 * 2 + dt3 * 2
+     kx = kx + 1
+  end do
+  cdcdot = cmplx(real(dsdotr), real(dsdoti))
+  return
+end function cdcdot
+program test
+  implicit none
+  complex :: cx(100), ct, cdcdot
+  integer :: i
+  do i = 1, 100
+    cx(i) = cmplx(2*i, i)
+  end do
+  ct = cdcdot (100, cx)
+  if (ct.ne.cmplx(10100.0000,30300.0000)) call abort
+end