From 54c88fc9dfa5854a5891cf3d68d3d2c4a4ba0f25 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Tue, 9 May 2023 09:21:09 +0200 Subject: [PATCH] [flang][hlfir] Lower WHERE to HLFIR Lower WHERE to the newly added hlfir.where and hlfir.elsewhere operations. Differential Revision: https://reviews.llvm.org/D149950 --- flang/lib/Lower/Bridge.cpp | 108 +++++++++++++++++++++---- flang/test/Lower/HLFIR/where.f90 | 170 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+), 14 deletions(-) create mode 100644 flang/test/Lower/HLFIR/where.f90 diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index fe86fe8..acf3768 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3154,7 +3154,7 @@ private: // Gather some information about the assignment that will impact how it is // lowered. const bool isWholeAllocatableAssignment = - !userDefinedAssignment && + !userDefinedAssignment && !isInsideHlfirWhere() && Fortran::lower::isWholeAllocatable(assign.lhs); std::optional lhsType = assign.lhs.GetType(); @@ -3243,8 +3243,6 @@ private: void genAssignment(const Fortran::evaluate::Assignment &assign) { mlir::Location loc = toLocation(); if (lowerToHighLevelFIR()) { - if (!implicitIterSpace.empty()) - TODO(loc, "HLFIR assignment inside WHERE"); std::visit( Fortran::common::visitors{ [&](const Fortran::evaluate::Assignment::Intrinsic &) { @@ -3452,23 +3450,47 @@ private: Fortran::lower::createArrayMergeStores(*this, explicitIterSpace); } - bool isInsideHlfirForallOrWhere() const { + // Is the insertion point of the builder directly or indirectly set + // inside any operation of type "Op"? + template + bool isInsideOp() const { mlir::Block *block = builder->getInsertionBlock(); mlir::Operation *op = block ? block->getParentOp() : nullptr; while (op) { - if (mlir::isa(op)) + if (mlir::isa(op)) return true; op = op->getParentOp(); } return false; } + bool isInsideHlfirForallOrWhere() const { + return isInsideOp(); + } + bool isInsideHlfirWhere() const { return isInsideOp(); } void genFIR(const Fortran::parser::WhereConstruct &c) { - implicitIterSpace.growStack(); + mlir::Location loc = getCurrentLocation(); + hlfir::WhereOp whereOp; + + if (!lowerToHighLevelFIR()) { + implicitIterSpace.growStack(); + } else { + whereOp = builder->create(loc); + builder->createBlock(&whereOp.getMaskRegion()); + } + + // Lower the where mask. For HLFIR, this is done in the hlfir.where mask + // region. genNestedStatement( std::get< Fortran::parser::Statement>( c.t)); + + // Lower WHERE body. For HLFIR, this is done in the hlfir.where body + // region. + if (whereOp) + builder->createBlock(&whereOp.getBody()); + for (const auto &body : std::get>(c.t)) genFIR(body); @@ -3484,6 +3506,13 @@ private: genNestedStatement( std::get>( c.t)); + + if (whereOp) { + // For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or + // in the hlfir.where if it had no elsewhere. + builder->create(loc); + builder->setInsertionPointAfter(whereOp); + } } void genFIR(const Fortran::parser::WhereBodyConstruct &body) { std::visit( @@ -3499,24 +3528,61 @@ private: }, body.u); } + + /// Lower a Where or Elsewhere mask into an hlfir mask region. + void lowerWhereMaskToHlfir(mlir::Location loc, + const Fortran::semantics::SomeExpr *maskExpr) { + assert(maskExpr && "mask semantic analysis failed"); + Fortran::lower::StatementContext maskContext; + hlfir::Entity mask = Fortran::lower::convertExprToHLFIR( + loc, *this, *maskExpr, localSymbols, maskContext); + mask = hlfir::loadTrivialScalar(loc, *builder, mask); + auto yieldOp = builder->create(loc, mask); + genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext); + } void genFIR(const Fortran::parser::WhereConstructStmt &stmt) { - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get(stmt.t))); + const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr( + std::get(stmt.t)); + if (lowerToHighLevelFIR()) + lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr); + else + implicitIterSpace.append(maskExpr); } void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) { + mlir::Location loc = getCurrentLocation(); + hlfir::ElseWhereOp elsewhereOp; + if (lowerToHighLevelFIR()) { + elsewhereOp = builder->create(loc); + // Lower mask in the mask region. + builder->createBlock(&elsewhereOp.getMaskRegion()); + } genNestedStatement( std::get< Fortran::parser::Statement>( ew.t)); + + // For HLFIR, lower the body in the hlfir.elsewhere body region. + if (elsewhereOp) + builder->createBlock(&elsewhereOp.getBody()); + for (const auto &body : std::get>(ew.t)) genFIR(body); } void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) { - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get(stmt.t))); + const auto *maskExpr = Fortran::semantics::GetExpr( + std::get(stmt.t)); + if (lowerToHighLevelFIR()) + lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr); + else + implicitIterSpace.append(maskExpr); } void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) { + if (lowerToHighLevelFIR()) { + auto elsewhereOp = + builder->create(getCurrentLocation()); + builder->createBlock(&elsewhereOp.getBody()); + } genNestedStatement( std::get>( ew.t)); @@ -3525,18 +3591,32 @@ private: genFIR(body); } void genFIR(const Fortran::parser::ElsewhereStmt &stmt) { - implicitIterSpace.append(nullptr); + if (!lowerToHighLevelFIR()) + implicitIterSpace.append(nullptr); } void genFIR(const Fortran::parser::EndWhereStmt &) { - implicitIterSpace.shrinkStack(); + if (!lowerToHighLevelFIR()) + implicitIterSpace.shrinkStack(); } void genFIR(const Fortran::parser::WhereStmt &stmt) { Fortran::lower::StatementContext stmtCtx; const auto &assign = std::get(stmt.t); + const auto *mask = Fortran::semantics::GetExpr( + std::get(stmt.t)); + if (lowerToHighLevelFIR()) { + mlir::Location loc = getCurrentLocation(); + auto whereOp = builder->create(loc); + builder->createBlock(&whereOp.getMaskRegion()); + lowerWhereMaskToHlfir(loc, mask); + builder->createBlock(&whereOp.getBody()); + genAssignment(*assign.typedAssignment->v); + builder->create(loc); + builder->setInsertionPointAfter(whereOp); + return; + } implicitIterSpace.growStack(); - implicitIterSpace.append(Fortran::semantics::GetExpr( - std::get(stmt.t))); + implicitIterSpace.append(mask); genAssignment(*assign.typedAssignment->v); implicitIterSpace.shrinkStack(); } diff --git a/flang/test/Lower/HLFIR/where.f90 b/flang/test/Lower/HLFIR/where.f90 new file mode 100644 index 0000000..88e49c9 --- /dev/null +++ b/flang/test/Lower/HLFIR/where.f90 @@ -0,0 +1,170 @@ +! Test lowering of WHERE construct and statements to HLFIR. +! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s + +module where_defs + logical :: mask(10) + real :: x(10), y(10) + real, allocatable :: a(:), b(:) + interface + function return_temporary_mask() + logical, allocatable :: return_temporary_mask(:) + end function + function return_temporary_array() + real, allocatable :: return_temporary_array(:) + end function + end interface +end module + +subroutine simple_where() + use where_defs, only: mask, x, y + where (mask) x = y +end subroutine +! CHECK-LABEL: func.func @_QPsimple_where() { +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref> +! CHECK: } +! CHECK: } +! CHECK: return +! CHECK:} + +subroutine where_construct() + use where_defs + where (mask) + x = y + a = b + end where +end subroutine +! CHECK-LABEL: func.func @_QPwhere_construct() { +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs, uniq_name = "_QMwhere_defsEa"} +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs, uniq_name = "_QMwhere_defsEb"} +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref>>> +! CHECK: hlfir.yield %[[VAL_16]] : !fir.box>> +! CHECK: } to { +! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref>>> +! CHECK: hlfir.yield %[[VAL_17]] : !fir.box>> +! CHECK: } +! CHECK: } +! CHECK: return +! CHECK:} + +subroutine where_cleanup() + use where_defs, only: x, return_temporary_mask, return_temporary_array + where (return_temporary_mask()) x = return_temporary_array() +end subroutine +! CHECK-LABEL: func.func @_QPwhere_cleanup() { +! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box>> {bindc_name = ".result"} +! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box>>> {bindc_name = ".result"} +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: hlfir.where { +! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath : () -> !fir.box>>> +! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box>>>, !fir.ref>>>> +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref>>>>) -> (!fir.ref>>>>, !fir.ref>>>>) +! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref>>>> +! CHECK: hlfir.yield %[[VAL_8]] : !fir.box>>> cleanup { +! CHECK: fir.freemem +! CHECK: } +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath : () -> !fir.box>> +! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box>>, !fir.ref>>> +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) +! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref>>> +! CHECK: hlfir.yield %[[VAL_16]] : !fir.box>> cleanup { +! CHECK: fir.freemem +! CHECK: } +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref> +! CHECK: } +! CHECK: } + +subroutine simple_elsewhere() + use where_defs + where (mask) + x = y + elsewhere + y = x + end where +end subroutine +! CHECK-LABEL: func.func @_QPsimple_elsewhere() { +! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } +! CHECK: hlfir.elsewhere do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref> +! CHECK: } +! CHECK: } +! CHECK: } + +subroutine elsewhere_2(mask2) + use where_defs, only : mask, x, y + logical :: mask2(:) + where (mask) + x = y + elsewhere(mask2) + y = x + elsewhere + x = foo() + end where +end subroutine +! CHECK-LABEL: func.func @_QPelsewhere_2( +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2 +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex +! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey +! CHECK: hlfir.where { +! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } +! CHECK: hlfir.elsewhere mask { +! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box>> +! CHECK: } do { +! CHECK: hlfir.region_assign { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref> +! CHECK: } +! CHECK: hlfir.elsewhere do { +! CHECK: hlfir.region_assign { +! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath : () -> f32 +! CHECK: hlfir.yield %[[VAL_16]] : f32 +! CHECK: } to { +! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref> +! CHECK: } +! CHECK: } +! CHECK: } +! CHECK: } -- 2.7.4