[Flang] Set constructExit for Where and Forall constructs
authorJonathon Penix <jpenix@quicinc.com>
Fri, 15 Jul 2022 19:32:36 +0000 (12:32 -0700)
committerUsman Nadeem <mnadeem@quicinc.com>
Fri, 15 Jul 2022 19:34:28 +0000 (12:34 -0700)
Evaluations for the Where and Forall constructs previously did
not have their constructExit field fixed up. This could lead to
falling through to subsequent case blocks in select case
statements if either a Where or Forall construct was the final part
of one case block. Setting the constructExit field results in the
proper branching behavior.

Fixes issue: https://github.com/llvm/llvm-project/issues/56500

Differential Revision: https://reviews.llvm.org/D129879

Change-Id: Ia868df12084520a935f087524e118bcdf47f6d7a

flang/lib/Lower/PFTBuilder.cpp
flang/test/Lower/select-case-statement.f90

index 49ae531..16cfde1 100644 (file)
@@ -939,6 +939,7 @@ private:
             eval.constructExit = &eval.evaluationList->back();
           },
           [&](const parser::DoConstruct &) { setConstructExit(eval); },
+          [&](const parser::ForallConstruct &) { setConstructExit(eval); },
           [&](const parser::IfConstruct &) { setConstructExit(eval); },
           [&](const parser::SelectRankConstruct &) {
             setConstructExit(eval);
@@ -948,6 +949,7 @@ private:
             setConstructExit(eval);
             eval.isUnstructured = true;
           },
+          [&](const parser::WhereConstruct &) { setConstructExit(eval); },
 
           // Default - Common analysis for IO statements; otherwise nop.
           [&](const auto &stmt) {
index e188f88..0a666b9 100644 (file)
     print*, n
   end subroutine
 
+  ! CHECK-LABEL: func @_QPswhere
+  subroutine swhere(num)
+    implicit none
+
+    integer, intent(in) :: num
+    real, dimension(1) :: array
+
+    array = 0.0
+
+    select case (num)
+    ! CHECK: ^bb1:  // pred: ^bb0
+    case (1)
+      where (array >= 0.0)
+        array = 42
+      end where
+    ! CHECK: cf.br ^bb3
+    ! CHECK: ^bb2:  // pred: ^bb0
+    case default
+      array = -1
+    end select
+    ! CHECK: cf.br ^bb3
+    ! CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+    print*, array(1)
+  end subroutine swhere
+
+  ! CHECK-LABEL: func @_QPsforall
+  subroutine sforall(num)
+    implicit none
+
+    integer, intent(in) :: num
+    real, dimension(1) :: array
+
+    array = 0.0
+
+    select case (num)
+    ! CHECK: ^bb1:  // pred: ^bb0
+    case (1)
+      where (array >= 0.0)
+        array = 42
+      end where
+    ! CHECK: cf.br ^bb3
+    ! CHECK: ^bb2:  // pred: ^bb0
+    case default
+      array = -1
+    end select
+    ! CHECK: cf.br ^bb3
+    ! CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+    print*, array(1)
+  end subroutine sforall
+
   ! CHECK-LABEL: main
   program p
     integer sinteger, v(10)
     call scharacter2('22 ')   ! expected output:  9 -2
     call scharacter2('.  ')   ! expected output:  9 -2
     call scharacter2('   ')   ! expected output:  9 -2
+
+    print*
+    call swhere(1)            ! expected output: 42.
+    call sforall(1)           ! expected output: 42.
   end