From 4203b062fbf70c6394bd02e1645bc18c607b3826 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Fri, 16 Dec 2022 12:59:12 +0100 Subject: [PATCH] [flang] Lower procedure ref to user defined elemental procedures (part 1) Lower procedure ref to user defined elemental procedure when: - there are no arguments that may be dynamically optional - for functions, the result has no length parameters - the reference can be unordered - there are no character by value arguments This uses the recently added hlfir.elemental operation and tools. The "core" of the argument preparation is shared between elemental and non elemental calls (genUserCalls is code moved without any functional changes) Differential Revision: https://reviews.llvm.org/D140118 --- flang/include/flang/Optimizer/Builder/HLFIRTools.h | 10 ++ flang/lib/Lower/CallInterface.cpp | 9 +- flang/lib/Lower/ConvertCall.cpp | 149 ++++++++++++++++++--- flang/lib/Optimizer/Builder/HLFIRTools.cpp | 20 +++ .../Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp | 24 +--- .../Lower/HLFIR/elemental-user-procedure-ref.f90 | 92 +++++++++++++ 6 files changed, 259 insertions(+), 45 deletions(-) create mode 100644 flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 08cd7d1..0cdac83 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -227,6 +227,12 @@ genBounds(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity); mlir::Value genShape(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity); +/// Generate a vector of extents with index type from a fir.shape +/// or fir.shape_shift value. +llvm::SmallVector getIndexExtents(mlir::Location loc, + fir::FirOpBuilder &builder, + mlir::Value shape); + /// Read length parameters into result if this entity has any.
void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, @@ -260,6 +266,10 @@ hlfir::ElementalOp genElementalOp(mlir::Location loc, std::pair> genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents); +inline std::pair> +genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) { + return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape)); +} /// Inline the body of an hlfir.elemental at the current insertion point /// given a list of one based indices. This generates the computation diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp index 2fae18b..8f0c0db 100644 --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -1110,7 +1110,14 @@ bool Fortran::lower::CallInterface::PassedEntity::mayBeModifiedByCall() const { if (!characteristics) return true; - return characteristics->GetIntent() != Fortran::common::Intent::In; + if (characteristics->GetIntent() == Fortran::common::Intent::In) + return false; + const auto *dummy = + std::get_if( + &characteristics->u); + return !dummy || + !dummy->attrs.test( + Fortran::evaluate::characteristics::DummyDataObject::Attr::Value); } template bool Fortran::lower::CallInterface::PassedEntity::mayBeReadByCall() const { diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 8a11bfe..a2852b3 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -429,6 +429,15 @@ isStatementFunctionCall(const Fortran::evaluate::ProcedureRef &procRef) { namespace { class CallBuilder { +private: + struct PreparedActualArgument { + hlfir::Entity actual; + bool handleDynamicOptional; + }; + using PreparedActualArguments = + llvm::SmallVector>; + using PassBy = Fortran::lower::CallerInterface::PassEntityBy; + public: CallBuilder(mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symMap, @@ -439,20 +448,18 
@@ public: gen(const Fortran::evaluate::ProcedureRef &procRef, llvm::Optional resultType) { mlir::Location loc = getLoc(); - fir::FirOpBuilder &builder = getBuilder(); - if (isElementalProcWithArrayArgs(procRef)) - TODO(loc, "lowering elemental call to HLFIR"); - if (auto *specific = procRef.proc().GetSpecificIntrinsic()) + if (auto *specific = procRef.proc().GetSpecificIntrinsic()) { + if (isElementalProcWithArrayArgs(procRef)) + TODO(loc, "lowering elemental intrinsic call to HLFIR"); return genIntrinsicRef(procRef, resultType, *specific); + } if (isStatementFunctionCall(procRef)) TODO(loc, "lowering Statement function call to HLFIR"); Fortran::lower::CallerInterface caller(procRef, converter); - using PassBy = Fortran::lower::CallerInterface::PassEntityBy; mlir::FunctionType callSiteType = caller.genFunctionType(); - llvm::SmallVector> - loweredActuals; + PreparedActualArguments loweredActuals; // Lower the actual arguments for (const Fortran::lower::CallInterface< Fortran::lower::CallerInterface>::PassedEntity &arg : @@ -461,41 +468,62 @@ public: const auto *expr = actual->UnwrapExpr(); if (!expr) TODO(loc, "assumed type actual argument"); - loweredActuals.emplace_back(Fortran::lower::convertExprToHLFIR( - loc, getConverter(), *expr, getSymMap(), getStmtCtx())); + + const bool handleDynamicOptional = + arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional( + *expr, getConverter().getFoldingContext()); + auto loweredActual = Fortran::lower::convertExprToHLFIR( + loc, getConverter(), *expr, getSymMap(), getStmtCtx()); + loweredActuals.emplace_back( + PreparedActualArgument{loweredActual, handleDynamicOptional}); } else { // Optional dummy argument for which there is no actual argument. 
loweredActuals.emplace_back(std::nullopt); } + if (isElementalProcWithArrayArgs(procRef)) { + bool isImpure = false; + if (const Fortran::semantics::Symbol *procSym = + procRef.proc().GetSymbol()) + isImpure = !Fortran::semantics::IsPureProcedure(*procSym); + return genElementalUserCall(loweredActuals, caller, resultType, + callSiteType, isImpure); + } + return genUserCall(loweredActuals, caller, resultType, callSiteType); + } +private: + llvm::Optional + genUserCall(PreparedActualArguments &loweredActuals, + Fortran::lower::CallerInterface &caller, + llvm::Optional resultType, + mlir::FunctionType callSiteType) { + mlir::Location loc = getLoc(); + fir::FirOpBuilder &builder = getBuilder(); llvm::SmallVector exprAssociations; - for (auto [actual, arg] : + for (auto [preparedActual, arg] : llvm::zip(loweredActuals, caller.getPassedArguments())) { mlir::Type argTy = callSiteType.getInput(arg.firArgument); - if (!actual) { + if (!preparedActual) { // Optional dummy argument for which there is no actual argument. caller.placeInput(arg, builder.create(loc, argTy)); continue; } - + hlfir::Entity actual = preparedActual->actual; const auto *expr = arg.entity->UnwrapExpr(); if (!expr) TODO(loc, "assumed type actual argument"); - const bool actualMayBeDynamicallyAbsent = - arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional( - *expr, getConverter().getFoldingContext()); - if (actualMayBeDynamicallyAbsent) + if (preparedActual->handleDynamicOptional) TODO(loc, "passing optional arguments in HLFIR"); const bool isSimplyContiguous = - actual->isScalar() || Fortran::evaluate::IsSimplyContiguous( - *expr, getConverter().getFoldingContext()); + actual.isScalar() || Fortran::evaluate::IsSimplyContiguous( + *expr, getConverter().getFoldingContext()); switch (arg.passBy) { case PassBy::Value: { // True pass-by-value semantics. 
- auto value = hlfir::loadTrivialScalar(loc, builder, *actual); + auto value = hlfir::loadTrivialScalar(loc, builder, actual); if (!value.isValue()) TODO(loc, "Passing CPTR an CFUNCTPTR VALUE in HLFIR"); caller.placeInput(arg, builder.createConvert(loc, argTy, value)); @@ -506,7 +534,7 @@ public: } break; case PassBy::BaseAddress: case PassBy::BoxChar: { - hlfir::Entity entity = *actual; + hlfir::Entity entity = actual; if (entity.isVariable()) { entity = hlfir::derefPointersAndAllocatables(loc, builder, entity); // Copy-in non contiguous variable @@ -556,11 +584,88 @@ public: builder.create(loc, associate); if (!fir::getBase(result)) return std::nullopt; // subroutine call. - return extendedValueToHlfirEntity(result, ".tmp.func_result"); // TODO: "move" non pointer results into hlfir.expr. + return extendedValueToHlfirEntity(result, ".tmp.func_result"); + } + + llvm::Optional + genElementalUserCall(PreparedActualArguments &loweredActuals, + Fortran::lower::CallerInterface &caller, + llvm::Optional resultType, + mlir::FunctionType callSiteType, bool isImpure) { + mlir::Location loc = getLoc(); + fir::FirOpBuilder &builder = getBuilder(); + assert(loweredActuals.size() == caller.getPassedArguments().size()); + unsigned numArgs = loweredActuals.size(); + // Step 1: dereference pointers/allocatables and compute elemental shape. + mlir::Value shape; + // 10.1.4 p5. Impure elemental procedures must be called in element order. + bool mustBeOrdered = isImpure; + for (unsigned i = 0; i < numArgs; ++i) { + const auto &arg = caller.getPassedArguments()[i]; + auto &preparedActual = loweredActuals[i]; + if (preparedActual) { + hlfir::Entity &actual = preparedActual->actual; + // Elemental procedure dummy arguments cannot be pointer/allocatables + // (C15100), so it is safe to dereference any pointer or allocatable + // actual argument now instead of doing this inside the elemental + // region. 
+ actual = hlfir::derefPointersAndAllocatables(loc, builder, actual); + // Better to load scalars outside of the loop when possible. + if (!preparedActual->handleDynamicOptional && + (arg.passBy == PassBy::Value || + arg.passBy == PassBy::BaseAddressValueAttribute)) + actual = hlfir::loadTrivialScalar(loc, builder, actual); + // TODO: merge shape instead of using the first one. + if (!shape && actual.isArray()) { + if (preparedActual->handleDynamicOptional) + TODO(loc, "deal with optional with shapes in HLFIR elemental call"); + shape = hlfir::genShape(loc, builder, actual); + } + // 15.8.3 p1. Elemental procedure with intent(out)/intent(inout) + // arguments must be called in element order. + if (arg.mayBeModifiedByCall()) + mustBeOrdered = true; + } + } + assert(shape && + "elemental array calls must have at least one array arguments"); + if (mustBeOrdered) + TODO(loc, "ordered elemental calls in HLFIR"); + if (!resultType) { + // Subroutine case. Generate call inside loop nest. + auto [innerLoop, oneBasedIndices] = + hlfir::genLoopNest(loc, builder, shape); + auto insPt = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(innerLoop.getBody()); + for (auto &preparedActual : loweredActuals) + if (preparedActual) + preparedActual->actual = hlfir::getElementAt( + loc, builder, preparedActual->actual, oneBasedIndices); + genUserCall(loweredActuals, caller, resultType, callSiteType); + builder.restoreInsertionPoint(insPt); + return std::nullopt; + } + // Function case: generate call inside hlfir.elemental + mlir::Type elementType = hlfir::getFortranElementType(*resultType); + // Get result length parameters. 
+ llvm::SmallVector typeParams; + if (elementType.isa() || + fir::isRecordWithTypeParameters(elementType)) + TODO(loc, "compute elemental function result length parameters in HLFIR"); + auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b, + mlir::ValueRange oneBasedIndices) -> hlfir::Entity { + for (auto &preparedActual : loweredActuals) + if (preparedActual) + preparedActual->actual = hlfir::getElementAt( + l, b, preparedActual->actual, oneBasedIndices); + return *genUserCall(loweredActuals, caller, resultType, callSiteType); + }; + // TODO: deal with hlfir.elemental result destruction. + return hlfir::EntityWithAttributes{hlfir::genElementalOp( + loc, builder, elementType, shape, typeParams, genKernel)}; } -private: hlfir::EntityWithAttributes genIntrinsicRef(const Fortran::evaluate::ProcedureRef &procRef, llvm::Optional resultType, diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index cdb78bd..d096ca9 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -396,6 +396,26 @@ mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder, return builder.create(loc, extents); } +llvm::SmallVector +hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, + mlir::Value shape) { + llvm::SmallVector extents; + if (auto s = shape.getDefiningOp()) { + auto e = s.getExtents(); + extents.insert(extents.end(), e.begin(), e.end()); + } else if (auto s = shape.getDefiningOp()) { + auto e = s.getExtents(); + extents.insert(extents.end(), e.begin(), e.end()); + } else { + // TODO: add fir.get_extent ops on fir.shape<> ops. 
+ TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op"); + } + mlir::Type indexType = builder.getIndexType(); + for (auto &extent : extents) + extent = builder.createConvert(loc, indexType, extent); + return extents; +} + void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, llvm::SmallVectorImpl &result) { diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 722a269..a6b4492 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -95,26 +95,6 @@ static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) { TODO(bufferizedExpr.getLoc(), "general extract storage case"); } -static llvm::SmallVector -getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::Value shape) { - llvm::SmallVector extents; - if (auto s = shape.getDefiningOp()) { - auto e = s.getExtents(); - extents.insert(extents.end(), e.begin(), e.end()); - } else if (auto s = shape.getDefiningOp()) { - auto e = s.getExtents(); - extents.insert(extents.end(), e.begin(), e.end()); - } else { - // TODO: add fir.get_extent ops on fir.shape<> ops. 
- TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op"); - } - mlir::Type indexType = builder.getIndexType(); - for (auto &extent : extents) - extent = builder.createConvert(loc, indexType, extent); - return extents; -} - static std::pair createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity mold) { @@ -128,7 +108,7 @@ createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type sequenceType = hlfir::getFortranElementOrSequenceType(mold.getType()); shape = hlfir::genShape(loc, builder, mold); - auto extents = getIndexExtents(loc, builder, shape); + auto extents = hlfir::getIndexExtents(loc, builder, shape); alloc = builder.createHeapTemporary(loc, sequenceType, tmpName, extents, lenParams); isHeapAlloc = builder.createBool(loc, true); @@ -369,7 +349,7 @@ struct ElementalOpConversion builder.setListener(&listener); mlir::Value shape = adaptor.getShape(); - auto extents = getIndexExtents(loc, builder, shape); + auto extents = hlfir::getIndexExtents(loc, builder, shape); auto [temp, cleanup] = createArrayTemp(loc, builder, elemental.getType(), shape, extents, adaptor.getTypeparams()); diff --git a/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 b/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 new file mode 100644 index 0000000..9eb8f2e --- /dev/null +++ b/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 @@ -0,0 +1,92 @@ +! Test lowering of user defined elemental procedure reference to HLFIR +! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s + +subroutine by_addr(x, y) + integer :: x + real :: y(100) + interface + real elemental function elem(a, b) + integer, intent(in) :: a + real, intent(in) :: b + end function + end interface + call baz(elem(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_addr( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_4:[^)]*]]) {{.*}}y +! 
CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_4]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> { +! CHECK: ^bb0(%[[VAL_7:.*]]: index): +! CHECK: %[[VAL_8:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_7]]) : (!fir.ref>, index) -> !fir.ref +! CHECK: %[[VAL_9:.*]] = fir.call @_QPelem(%[[VAL_2]]#1, %[[VAL_8]]) fastmath : (!fir.ref, !fir.ref) -> f32 +! CHECK: hlfir.yield_element %[[VAL_9]] : f32 +! CHECK: } + +subroutine by_value(x, y) + integer :: x + real :: y(10, 20) + interface + real elemental function elem_val(a, b) + integer, value :: a + real, value :: b + end function + end interface + call baz(elem_val(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_value( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y +! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref +! CHECK: %[[VAL_8:.*]] = hlfir.elemental %[[VAL_5]] : (!fir.shape<2>) -> !hlfir.expr<10x20xf32> { +! CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_10]]) : (!fir.ref>, index, index) -> !fir.ref +! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref +! CHECK: %[[VAL_13:.*]] = fir.call @_QPelem_val(%[[VAL_7]], %[[VAL_12]]) fastmath : (i32, f32) -> f32 +! CHECK: hlfir.yield_element %[[VAL_13]] : f32 +! CHECK: } + +subroutine by_boxaddr(x, y) + character(*) :: x + character(*) :: y(100) + interface + real elemental function char_elem(a, b) + character(*), intent(in) :: a + character(*), intent(in) :: b + end function + end interface + call baz2(char_elem(x, y)) +end subroutine +! CHECK-LABEL: func.func @_QPby_boxaddr( +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2:.*]]#0 typeparams %[[VAL_2]]#1 {{.*}}x +! CHECK: %[[VAL_6:.*]] = arith.constant 100 : index +! CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_5:.*]](%[[VAL_7:.*]]) typeparams %[[VAL_4:.*]]#1 {{.*}}y +! 
CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_7]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> { +! CHECK: ^bb0(%[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_10]]) typeparams %[[VAL_4]]#1 : (!fir.box>>, index, index) -> !fir.boxchar<1> +! CHECK: %[[VAL_12:.*]] = fir.call @_QPchar_elem(%[[VAL_3]]#0, %[[VAL_11]]) fastmath : (!fir.boxchar<1>, !fir.boxchar<1>) -> f32 +! CHECK: hlfir.yield_element %[[VAL_12]] : f32 +! CHECK: } + +subroutine sub(x, y) + integer :: x + real :: y(10, 20) + interface + elemental subroutine elem_sub(a, b) + integer, intent(in) :: a + real, intent(in) :: b + end subroutine + end interface + call elem_sub(x, y) +end subroutine +! CHECK-LABEL: func.func @_QPsub( +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x +! CHECK: %[[VAL_3:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_4:.*]] = arith.constant 20 : index +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y +! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_7]] { +! CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] { +! CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_8]]) : (!fir.ref>, index, index) -> !fir.ref +! CHECK: fir.call @_QPelem_sub(%[[VAL_2]]#1, %[[VAL_10]]) fastmath : (!fir.ref, !fir.ref) -> () +! CHECK: } +! CHECK: } -- 2.7.4