[flang][hlfir] Lower WHERE to HLFIR
authorJean Perier <jperier@nvidia.com>
Tue, 9 May 2023 07:21:09 +0000 (09:21 +0200)
committerJean Perier <jperier@nvidia.com>
Tue, 9 May 2023 07:21:27 +0000 (09:21 +0200)
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere
operations.

Differential Revision: https://reviews.llvm.org/D149950

flang/lib/Lower/Bridge.cpp
flang/test/Lower/HLFIR/where.f90 [new file with mode: 0644]

index fe86fe8..acf3768 100644 (file)
@@ -3154,7 +3154,7 @@ private:
     // Gather some information about the assignment that will impact how it is
     // lowered.
     const bool isWholeAllocatableAssignment =
-        !userDefinedAssignment &&
+        !userDefinedAssignment && !isInsideHlfirWhere() &&
         Fortran::lower::isWholeAllocatable(assign.lhs);
     std::optional<Fortran::evaluate::DynamicType> lhsType =
         assign.lhs.GetType();
@@ -3243,8 +3243,6 @@ private:
   void genAssignment(const Fortran::evaluate::Assignment &assign) {
     mlir::Location loc = toLocation();
     if (lowerToHighLevelFIR()) {
-      if (!implicitIterSpace.empty())
-        TODO(loc, "HLFIR assignment inside WHERE");
       std::visit(
           Fortran::common::visitors{
               [&](const Fortran::evaluate::Assignment::Intrinsic &) {
@@ -3452,23 +3450,47 @@ private:
       Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
   }
 
-  bool isInsideHlfirForallOrWhere() const {
+  // Is the insertion point of the builder directly or indirectly set
+  // inside any operation of type "Op"?
+  template <typename... Op>
+  bool isInsideOp() const {
     mlir::Block *block = builder->getInsertionBlock();
     mlir::Operation *op = block ? block->getParentOp() : nullptr;
     while (op) {
-      if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
+      if (mlir::isa<Op...>(op))
         return true;
       op = op->getParentOp();
     }
     return false;
   }
+  bool isInsideHlfirForallOrWhere() const {
+    return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
+  }
+  bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); }
 
   void genFIR(const Fortran::parser::WhereConstruct &c) {
-    implicitIterSpace.growStack();
+    mlir::Location loc = getCurrentLocation();
+    hlfir::WhereOp whereOp;
+
+    if (!lowerToHighLevelFIR()) {
+      implicitIterSpace.growStack();
+    } else {
+      whereOp = builder->create<hlfir::WhereOp>(loc);
+      builder->createBlock(&whereOp.getMaskRegion());
+    }
+
+    // Lower the where mask. For HLFIR, this is done in the hlfir.where mask
+    // region.
     genNestedStatement(
         std::get<
             Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
             c.t));
+
+    // Lower WHERE body. For HLFIR, this is done in the hlfir.where body
+    // region.
+    if (whereOp)
+      builder->createBlock(&whereOp.getBody());
+
     for (const auto &body :
          std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t))
       genFIR(body);
@@ -3484,6 +3506,13 @@ private:
     genNestedStatement(
         std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
             c.t));
+
+    if (whereOp) {
+      // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
+      // in the hlfir.where if it had no elsewhere.
+      builder->create<fir::FirEndOp>(loc);
+      builder->setInsertionPointAfter(whereOp);
+    }
   }
   void genFIR(const Fortran::parser::WhereBodyConstruct &body) {
     std::visit(
@@ -3499,24 +3528,61 @@ private:
         },
         body.u);
   }
+
+  /// Lower a Where or Elsewhere mask into an hlfir mask region.
+  void lowerWhereMaskToHlfir(mlir::Location loc,
+                             const Fortran::semantics::SomeExpr *maskExpr) {
+    assert(maskExpr && "mask semantic analysis failed");
+    Fortran::lower::StatementContext maskContext;
+    hlfir::Entity mask = Fortran::lower::convertExprToHLFIR(
+        loc, *this, *maskExpr, localSymbols, maskContext);
+    mask = hlfir::loadTrivialScalar(loc, *builder, mask);
+    auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
+    genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
+  }
   void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
-    implicitIterSpace.append(Fortran::semantics::GetExpr(
-        std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+    const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
+        std::get<Fortran::parser::LogicalExpr>(stmt.t));
+    if (lowerToHighLevelFIR())
+      lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+    else
+      implicitIterSpace.append(maskExpr);
   }
   void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
+    mlir::Location loc = getCurrentLocation();
+    hlfir::ElseWhereOp elsewhereOp;
+    if (lowerToHighLevelFIR()) {
+      elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc);
+      // Lower mask in the mask region.
+      builder->createBlock(&elsewhereOp.getMaskRegion());
+    }
     genNestedStatement(
         std::get<
             Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
             ew.t));
+
+    // For HLFIR, lower the body in the hlfir.elsewhere body region.
+    if (elsewhereOp)
+      builder->createBlock(&elsewhereOp.getBody());
+
     for (const auto &body :
          std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t))
       genFIR(body);
   }
   void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) {
-    implicitIterSpace.append(Fortran::semantics::GetExpr(
-        std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+    const auto *maskExpr = Fortran::semantics::GetExpr(
+        std::get<Fortran::parser::LogicalExpr>(stmt.t));
+    if (lowerToHighLevelFIR())
+      lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
+    else
+      implicitIterSpace.append(maskExpr);
   }
   void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) {
+    if (lowerToHighLevelFIR()) {
+      auto elsewhereOp =
+          builder->create<hlfir::ElseWhereOp>(getCurrentLocation());
+      builder->createBlock(&elsewhereOp.getBody());
+    }
     genNestedStatement(
         std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
             ew.t));
@@ -3525,18 +3591,32 @@ private:
       genFIR(body);
   }
   void genFIR(const Fortran::parser::ElsewhereStmt &stmt) {
-    implicitIterSpace.append(nullptr);
+    if (!lowerToHighLevelFIR())
+      implicitIterSpace.append(nullptr);
   }
   void genFIR(const Fortran::parser::EndWhereStmt &) {
-    implicitIterSpace.shrinkStack();
+    if (!lowerToHighLevelFIR())
+      implicitIterSpace.shrinkStack();
   }
 
   void genFIR(const Fortran::parser::WhereStmt &stmt) {
     Fortran::lower::StatementContext stmtCtx;
     const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t);
+    const auto *mask = Fortran::semantics::GetExpr(
+        std::get<Fortran::parser::LogicalExpr>(stmt.t));
+    if (lowerToHighLevelFIR()) {
+      mlir::Location loc = getCurrentLocation();
+      auto whereOp = builder->create<hlfir::WhereOp>(loc);
+      builder->createBlock(&whereOp.getMaskRegion());
+      lowerWhereMaskToHlfir(loc, mask);
+      builder->createBlock(&whereOp.getBody());
+      genAssignment(*assign.typedAssignment->v);
+      builder->create<fir::FirEndOp>(loc);
+      builder->setInsertionPointAfter(whereOp);
+      return;
+    }
     implicitIterSpace.growStack();
-    implicitIterSpace.append(Fortran::semantics::GetExpr(
-        std::get<Fortran::parser::LogicalExpr>(stmt.t)));
+    implicitIterSpace.append(mask);
     genAssignment(*assign.typedAssignment->v);
     implicitIterSpace.shrinkStack();
   }
diff --git a/flang/test/Lower/HLFIR/where.f90 b/flang/test/Lower/HLFIR/where.f90
new file mode 100644 (file)
index 0000000..88e49c9
--- /dev/null
@@ -0,0 +1,170 @@
+! Test lowering of WHERE construct and statements to HLFIR.
+! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s
+
+module where_defs
+  logical :: mask(10)
+  real :: x(10), y(10)
+  real, allocatable :: a(:), b(:)
+  interface
+    function return_temporary_mask()
+      logical, allocatable :: return_temporary_mask(:)
+    end function
+    function return_temporary_array()
+      real, allocatable :: return_temporary_array(:)
+    end function
+  end interface
+end module
+
+subroutine simple_where()
+  use where_defs, only: mask, x, y
+  where (mask) x = y
+end subroutine
+! CHECK-LABEL:   func.func @_QPsimple_where() {
+! CHECK:  %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK:  %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.where {
+! CHECK:    hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK:  } do {
+! CHECK:    hlfir.region_assign {
+! CHECK:      hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    } to {
+! CHECK:      hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    }
+! CHECK:  }
+! CHECK:  return
+! CHECK:}
+
+subroutine where_construct()
+  use where_defs
+  where (mask)
+    x = y
+    a = b
+  end where
+end subroutine
+! CHECK-LABEL:   func.func @_QPwhere_construct() {
+! CHECK:  %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"}
+! CHECK:  %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"}
+! CHECK:  %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.where {
+! CHECK:    hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK:  } do {
+! CHECK:    hlfir.region_assign {
+! CHECK:      hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    } to {
+! CHECK:      hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    }
+! CHECK:    hlfir.region_assign {
+! CHECK:      %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK:      hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK:    } to {
+! CHECK:      %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK:      hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK:    }
+! CHECK:  }
+! CHECK:  return
+! CHECK:}
+
+subroutine where_cleanup()
+  use where_defs, only: x, return_temporary_mask, return_temporary_array
+  where (return_temporary_mask()) x = return_temporary_array()
+end subroutine
+! CHECK-LABEL:   func.func @_QPwhere_cleanup() {
+! CHECK:  %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"}
+! CHECK:  %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"}
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  hlfir.where {
+! CHECK:    %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
+! CHECK:    fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK:    %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>)
+! CHECK:    %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
+! CHECK:    hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup {
+! CHECK:        fir.freemem
+! CHECK:    }
+! CHECK:  } do {
+! CHECK:    hlfir.region_assign {
+! CHECK:      %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK:      fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK:      %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+! CHECK:      %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK:      hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup {
+! CHECK:          fir.freemem
+! CHECK:      }
+! CHECK:    } to {
+! CHECK:      hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    }
+! CHECK:  }
+
+subroutine simple_elsewhere()
+  use where_defs
+  where (mask)
+    x = y
+  elsewhere
+    y = x
+  end where
+end subroutine
+! CHECK-LABEL:   func.func @_QPsimple_elsewhere() {
+! CHECK:  %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.where {
+! CHECK:    hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK:  } do {
+! CHECK:    hlfir.region_assign {
+! CHECK:      hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    } to {
+! CHECK:      hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    }
+! CHECK:    hlfir.elsewhere do {
+! CHECK:      hlfir.region_assign {
+! CHECK:        hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:      } to {
+! CHECK:        hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:      }
+! CHECK:    }
+! CHECK:  }
+
+subroutine elsewhere_2(mask2)
+  use where_defs, only : mask, x, y
+  logical :: mask2(:)
+  where (mask)
+    x = y
+  elsewhere(mask2)
+    y = x
+  elsewhere
+    x = foo()
+  end where
+end subroutine
+! CHECK-LABEL:   func.func @_QPelsewhere_2(
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2
+! CHECK:  %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.where {
+! CHECK:    hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
+! CHECK:  } do {
+! CHECK:    hlfir.region_assign {
+! CHECK:      hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    } to {
+! CHECK:      hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:    }
+! CHECK:    hlfir.elsewhere mask {
+! CHECK:      hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>>
+! CHECK:    } do {
+! CHECK:      hlfir.region_assign {
+! CHECK:        hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:      } to {
+! CHECK:        hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:      }
+! CHECK:      hlfir.elsewhere do {
+! CHECK:        hlfir.region_assign {
+! CHECK:          %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32
+! CHECK:          hlfir.yield %[[VAL_16]] : f32
+! CHECK:        } to {
+! CHECK:          hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
+! CHECK:        }
+! CHECK:      }
+! CHECK:    }
+! CHECK:  }