[flang] Lower procedure ref to user defined elemental procedures (part 1)
authorJean Perier <jperier@nvidia.com>
Fri, 16 Dec 2022 11:59:12 +0000 (12:59 +0100)
committerJean Perier <jperier@nvidia.com>
Fri, 16 Dec 2022 12:04:04 +0000 (13:04 +0100)
Lower procedure ref to user defined elemental procedure when:
- there are no arguments that may be dynamically optional
- for functions, the result has no length parameters
- the reference can be unordered
- there are not character by value arguments

This uses the recently added hlfir.elemental operation and tools.
The "core" of the argument preparation is shared between elemental
and non elemental calls (genUserCalls is code moved without any
functional changes)

Differential Revision: https://reviews.llvm.org/D140118

flang/include/flang/Optimizer/Builder/HLFIRTools.h
flang/lib/Lower/CallInterface.cpp
flang/lib/Lower/ConvertCall.cpp
flang/lib/Optimizer/Builder/HLFIRTools.cpp
flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 [new file with mode: 0644]

index 08cd7d1..0cdac83 100644 (file)
@@ -227,6 +227,12 @@ genBounds(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity);
 mlir::Value genShape(mlir::Location loc, fir::FirOpBuilder &builder,
                      Entity entity);
 
+/// Generate a vector of extents with index type from a fir.shape
+/// of fir.shape_shift value.
+llvm::SmallVector<mlir::Value> getIndexExtents(mlir::Location loc,
+                                               fir::FirOpBuilder &builder,
+                                               mlir::Value shape);
+
 /// Read length parameters into result if this entity has any.
 void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
                          Entity entity,
@@ -260,6 +266,10 @@ hlfir::ElementalOp genElementalOp(mlir::Location loc,
 std::pair<fir::DoLoopOp, llvm::SmallVector<mlir::Value>>
 genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
             mlir::ValueRange extents);
+inline std::pair<fir::DoLoopOp, llvm::SmallVector<mlir::Value>>
+genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) {
+  return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape));
+}
 
 /// Inline the body of an hlfir.elemental at the current insertion point
 /// given a list of one based indices. This generates the computation
index 2fae18b..8f0c0db 100644 (file)
@@ -1110,7 +1110,14 @@ bool Fortran::lower::CallInterface<T>::PassedEntity::mayBeModifiedByCall()
     const {
   if (!characteristics)
     return true;
-  return characteristics->GetIntent() != Fortran::common::Intent::In;
+  if (characteristics->GetIntent() == Fortran::common::Intent::In)
+    return false;
+  const auto *dummy =
+      std::get_if<Fortran::evaluate::characteristics::DummyDataObject>(
+          &characteristics->u);
+  return !dummy ||
+         !dummy->attrs.test(
+             Fortran::evaluate::characteristics::DummyDataObject::Attr::Value);
 }
 template <typename T>
 bool Fortran::lower::CallInterface<T>::PassedEntity::mayBeReadByCall() const {
index 8a11bfe..a2852b3 100644 (file)
@@ -429,6 +429,15 @@ isStatementFunctionCall(const Fortran::evaluate::ProcedureRef &procRef) {
 
 namespace {
 class CallBuilder {
+private:
+  struct PreparedActualArgument {
+    hlfir::Entity actual;
+    bool handleDynamicOptional;
+  };
+  using PreparedActualArguments =
+      llvm::SmallVector<llvm::Optional<PreparedActualArgument>>;
+  using PassBy = Fortran::lower::CallerInterface::PassEntityBy;
+
 public:
   CallBuilder(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
               Fortran::lower::SymMap &symMap,
@@ -439,20 +448,18 @@ public:
   gen(const Fortran::evaluate::ProcedureRef &procRef,
       llvm::Optional<mlir::Type> resultType) {
     mlir::Location loc = getLoc();
-    fir::FirOpBuilder &builder = getBuilder();
-    if (isElementalProcWithArrayArgs(procRef))
-      TODO(loc, "lowering elemental call to HLFIR");
-    if (auto *specific = procRef.proc().GetSpecificIntrinsic())
+    if (auto *specific = procRef.proc().GetSpecificIntrinsic()) {
+      if (isElementalProcWithArrayArgs(procRef))
+        TODO(loc, "lowering elemental intrinsic call to HLFIR");
       return genIntrinsicRef(procRef, resultType, *specific);
+    }
     if (isStatementFunctionCall(procRef))
       TODO(loc, "lowering Statement function call to HLFIR");
 
     Fortran::lower::CallerInterface caller(procRef, converter);
-    using PassBy = Fortran::lower::CallerInterface::PassEntityBy;
     mlir::FunctionType callSiteType = caller.genFunctionType();
 
-    llvm::SmallVector<llvm::Optional<hlfir::EntityWithAttributes>>
-        loweredActuals;
+    PreparedActualArguments loweredActuals;
     // Lower the actual arguments
     for (const Fortran::lower::CallInterface<
              Fortran::lower::CallerInterface>::PassedEntity &arg :
@@ -461,41 +468,62 @@ public:
         const auto *expr = actual->UnwrapExpr();
         if (!expr)
           TODO(loc, "assumed type actual argument");
-        loweredActuals.emplace_back(Fortran::lower::convertExprToHLFIR(
-            loc, getConverter(), *expr, getSymMap(), getStmtCtx()));
+
+        const bool handleDynamicOptional =
+            arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional(
+                                    *expr, getConverter().getFoldingContext());
+        auto loweredActual = Fortran::lower::convertExprToHLFIR(
+            loc, getConverter(), *expr, getSymMap(), getStmtCtx());
+        loweredActuals.emplace_back(
+            PreparedActualArgument{loweredActual, handleDynamicOptional});
       } else {
         // Optional dummy argument for which there is no actual argument.
         loweredActuals.emplace_back(std::nullopt);
       }
+    if (isElementalProcWithArrayArgs(procRef)) {
+      bool isImpure = false;
+      if (const Fortran::semantics::Symbol *procSym =
+              procRef.proc().GetSymbol())
+        isImpure = !Fortran::semantics::IsPureProcedure(*procSym);
+      return genElementalUserCall(loweredActuals, caller, resultType,
+                                  callSiteType, isImpure);
+    }
+    return genUserCall(loweredActuals, caller, resultType, callSiteType);
+  }
 
+private:
+  llvm::Optional<hlfir::EntityWithAttributes>
+  genUserCall(PreparedActualArguments &loweredActuals,
+              Fortran::lower::CallerInterface &caller,
+              llvm::Optional<mlir::Type> resultType,
+              mlir::FunctionType callSiteType) {
+    mlir::Location loc = getLoc();
+    fir::FirOpBuilder &builder = getBuilder();
     llvm::SmallVector<hlfir::AssociateOp> exprAssociations;
-    for (auto [actual, arg] :
+    for (auto [preparedActual, arg] :
          llvm::zip(loweredActuals, caller.getPassedArguments())) {
       mlir::Type argTy = callSiteType.getInput(arg.firArgument);
-      if (!actual) {
+      if (!preparedActual) {
         // Optional dummy argument for which there is no actual argument.
         caller.placeInput(arg, builder.create<fir::AbsentOp>(loc, argTy));
         continue;
       }
-
+      hlfir::Entity actual = preparedActual->actual;
       const auto *expr = arg.entity->UnwrapExpr();
       if (!expr)
         TODO(loc, "assumed type actual argument");
 
-      const bool actualMayBeDynamicallyAbsent =
-          arg.isOptional() && Fortran::evaluate::MayBePassedAsAbsentOptional(
-                                  *expr, getConverter().getFoldingContext());
-      if (actualMayBeDynamicallyAbsent)
+      if (preparedActual->handleDynamicOptional)
         TODO(loc, "passing optional arguments in HLFIR");
 
       const bool isSimplyContiguous =
-          actual->isScalar() || Fortran::evaluate::IsSimplyContiguous(
-                                    *expr, getConverter().getFoldingContext());
+          actual.isScalar() || Fortran::evaluate::IsSimplyContiguous(
+                                   *expr, getConverter().getFoldingContext());
 
       switch (arg.passBy) {
       case PassBy::Value: {
         // True pass-by-value semantics.
-        auto value = hlfir::loadTrivialScalar(loc, builder, *actual);
+        auto value = hlfir::loadTrivialScalar(loc, builder, actual);
         if (!value.isValue())
           TODO(loc, "Passing CPTR an CFUNCTPTR VALUE in HLFIR");
         caller.placeInput(arg, builder.createConvert(loc, argTy, value));
@@ -506,7 +534,7 @@ public:
       } break;
       case PassBy::BaseAddress:
       case PassBy::BoxChar: {
-        hlfir::Entity entity = *actual;
+        hlfir::Entity entity = actual;
         if (entity.isVariable()) {
           entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
           // Copy-in non contiguous variable
@@ -556,11 +584,88 @@ public:
       builder.create<hlfir::EndAssociateOp>(loc, associate);
     if (!fir::getBase(result))
       return std::nullopt; // subroutine call.
-    return extendedValueToHlfirEntity(result, ".tmp.func_result");
     // TODO: "move" non pointer results into hlfir.expr.
+    return extendedValueToHlfirEntity(result, ".tmp.func_result");
+  }
+
+  llvm::Optional<hlfir::EntityWithAttributes>
+  genElementalUserCall(PreparedActualArguments &loweredActuals,
+                       Fortran::lower::CallerInterface &caller,
+                       llvm::Optional<mlir::Type> resultType,
+                       mlir::FunctionType callSiteType, bool isImpure) {
+    mlir::Location loc = getLoc();
+    fir::FirOpBuilder &builder = getBuilder();
+    assert(loweredActuals.size() == caller.getPassedArguments().size());
+    unsigned numArgs = loweredActuals.size();
+    // Step 1: dereference pointers/allocatables and compute elemental shape.
+    mlir::Value shape;
+    // 10.1.4 p5. Impure elemental procedures must be called in element order.
+    bool mustBeOrdered = isImpure;
+    for (unsigned i = 0; i < numArgs; ++i) {
+      const auto &arg = caller.getPassedArguments()[i];
+      auto &preparedActual = loweredActuals[i];
+      if (preparedActual) {
+        hlfir::Entity &actual = preparedActual->actual;
+        // Elemental procedure dummy arguments cannot be pointer/allocatables
+        // (C15100), so it is safe to dereference any pointer or allocatable
+        // actual argument now instead of doing this inside the elemental
+        // region.
+        actual = hlfir::derefPointersAndAllocatables(loc, builder, actual);
+        // Better to load scalars outside of the loop when possible.
+        if (!preparedActual->handleDynamicOptional &&
+            (arg.passBy == PassBy::Value ||
+             arg.passBy == PassBy::BaseAddressValueAttribute))
+          actual = hlfir::loadTrivialScalar(loc, builder, actual);
+        // TODO: merge shape instead of using the first one.
+        if (!shape && actual.isArray()) {
+          if (preparedActual->handleDynamicOptional)
+            TODO(loc, "deal with optional with shapes in HLFIR elemental call");
+          shape = hlfir::genShape(loc, builder, actual);
+        }
+        // 15.8.3 p1. Elemental procedure with intent(out)/intent(inout)
+        // arguments must be called in element order.
+        if (arg.mayBeModifiedByCall())
+          mustBeOrdered = true;
+      }
+    }
+    assert(shape &&
+           "elemental array calls must have at least one array arguments");
+    if (mustBeOrdered)
+      TODO(loc, "ordered elemental calls in HLFIR");
+    if (!resultType) {
+      // Subroutine case. Generate call inside loop nest.
+      auto [innerLoop, oneBasedIndices] =
+          hlfir::genLoopNest(loc, builder, shape);
+      auto insPt = builder.saveInsertionPoint();
+      builder.setInsertionPointToStart(innerLoop.getBody());
+      for (auto &preparedActual : loweredActuals)
+        if (preparedActual)
+          preparedActual->actual = hlfir::getElementAt(
+              loc, builder, preparedActual->actual, oneBasedIndices);
+      genUserCall(loweredActuals, caller, resultType, callSiteType);
+      builder.restoreInsertionPoint(insPt);
+      return std::nullopt;
+    }
+    // Function case: generate call inside hlfir.elemental
+    mlir::Type elementType = hlfir::getFortranElementType(*resultType);
+    // Get result length parameters.
+    llvm::SmallVector<mlir::Value> typeParams;
+    if (elementType.isa<fir::CharacterType>() ||
+        fir::isRecordWithTypeParameters(elementType))
+      TODO(loc, "compute elemental function result length parameters in HLFIR");
+    auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b,
+                         mlir::ValueRange oneBasedIndices) -> hlfir::Entity {
+      for (auto &preparedActual : loweredActuals)
+        if (preparedActual)
+          preparedActual->actual = hlfir::getElementAt(
+              l, b, preparedActual->actual, oneBasedIndices);
+      return *genUserCall(loweredActuals, caller, resultType, callSiteType);
+    };
+    // TODO: deal with hlfir.elemental result destruction.
+    return hlfir::EntityWithAttributes{hlfir::genElementalOp(
+        loc, builder, elementType, shape, typeParams, genKernel)};
   }
 
-private:
   hlfir::EntityWithAttributes
   genIntrinsicRef(const Fortran::evaluate::ProcedureRef &procRef,
                   llvm::Optional<mlir::Type> resultType,
index cdb78bd..d096ca9 100644 (file)
@@ -396,6 +396,26 @@ mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder,
   return builder.create<fir::ShapeOp>(loc, extents);
 }
 
+llvm::SmallVector<mlir::Value>
+hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
+                       mlir::Value shape) {
+  llvm::SmallVector<mlir::Value> extents;
+  if (auto s = shape.getDefiningOp<fir::ShapeOp>()) {
+    auto e = s.getExtents();
+    extents.insert(extents.end(), e.begin(), e.end());
+  } else if (auto s = shape.getDefiningOp<fir::ShapeShiftOp>()) {
+    auto e = s.getExtents();
+    extents.insert(extents.end(), e.begin(), e.end());
+  } else {
+    // TODO: add fir.get_extent ops on fir.shape<> ops.
+    TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op");
+  }
+  mlir::Type indexType = builder.getIndexType();
+  for (auto &extent : extents)
+    extent = builder.createConvert(loc, indexType, extent);
+  return extents;
+}
+
 void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
                                 Entity entity,
                                 llvm::SmallVectorImpl<mlir::Value> &result) {
index 722a269..a6b4492 100644 (file)
@@ -95,26 +95,6 @@ static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) {
   TODO(bufferizedExpr.getLoc(), "general extract storage case");
 }
 
-static llvm::SmallVector<mlir::Value>
-getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
-                mlir::Value shape) {
-  llvm::SmallVector<mlir::Value> extents;
-  if (auto s = shape.getDefiningOp<fir::ShapeOp>()) {
-    auto e = s.getExtents();
-    extents.insert(extents.end(), e.begin(), e.end());
-  } else if (auto s = shape.getDefiningOp<fir::ShapeShiftOp>()) {
-    auto e = s.getExtents();
-    extents.insert(extents.end(), e.begin(), e.end());
-  } else {
-    // TODO: add fir.get_extent ops on fir.shape<> ops.
-    TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op");
-  }
-  mlir::Type indexType = builder.getIndexType();
-  for (auto &extent : extents)
-    extent = builder.createConvert(loc, indexType, extent);
-  return extents;
-}
-
 static std::pair<hlfir::Entity, mlir::Value>
 createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
                    hlfir::Entity mold) {
@@ -128,7 +108,7 @@ createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
     mlir::Type sequenceType =
         hlfir::getFortranElementOrSequenceType(mold.getType());
     shape = hlfir::genShape(loc, builder, mold);
-    auto extents = getIndexExtents(loc, builder, shape);
+    auto extents = hlfir::getIndexExtents(loc, builder, shape);
     alloc = builder.createHeapTemporary(loc, sequenceType, tmpName, extents,
                                         lenParams);
     isHeapAlloc = builder.createBool(loc, true);
@@ -369,7 +349,7 @@ struct ElementalOpConversion
     builder.setListener(&listener);
 
     mlir::Value shape = adaptor.getShape();
-    auto extents = getIndexExtents(loc, builder, shape);
+    auto extents = hlfir::getIndexExtents(loc, builder, shape);
     auto [temp, cleanup] =
         createArrayTemp(loc, builder, elemental.getType(), shape, extents,
                         adaptor.getTypeparams());
diff --git a/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90 b/flang/test/Lower/HLFIR/elemental-user-procedure-ref.f90
new file mode 100644 (file)
index 0000000..9eb8f2e
--- /dev/null
@@ -0,0 +1,92 @@
+! Test lowering of user defined elemental procedure reference to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+
+subroutine by_addr(x, y)
+  integer :: x
+  real :: y(100)
+  interface
+    real elemental function elem(a, b)
+      integer, intent(in) :: a
+      real, intent(in) :: b
+    end function
+  end interface
+  call baz(elem(x, y))
+end subroutine
+! CHECK-LABEL: func.func @_QPby_addr(
+! CHECK:  %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_4:[^)]*]]) {{.*}}y
+! CHECK:  %[[VAL_6:.*]] = hlfir.elemental %[[VAL_4]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> {
+! CHECK:  ^bb0(%[[VAL_7:.*]]: index):
+! CHECK:    %[[VAL_8:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_7]])  : (!fir.ref<!fir.array<100xf32>>, index) -> !fir.ref<f32>
+! CHECK:    %[[VAL_9:.*]] = fir.call @_QPelem(%[[VAL_2]]#1, %[[VAL_8]]) fastmath<contract> : (!fir.ref<i32>, !fir.ref<f32>) -> f32
+! CHECK:    hlfir.yield_element %[[VAL_9]] : f32
+! CHECK:  }
+
+subroutine by_value(x, y)
+  integer :: x
+  real :: y(10, 20)
+  interface
+    real elemental function elem_val(a, b)
+      integer, value :: a
+      real, value :: b
+    end function
+  end interface
+  call baz(elem_val(x, y))
+end subroutine
+! CHECK-LABEL: func.func @_QPby_value(
+! CHECK:  %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y
+! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<i32>
+! CHECK:  %[[VAL_8:.*]] = hlfir.elemental %[[VAL_5]] : (!fir.shape<2>) -> !hlfir.expr<10x20xf32> {
+! CHECK:  ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: index):
+! CHECK:    %[[VAL_11:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_10]])  : (!fir.ref<!fir.array<10x20xf32>>, index, index) -> !fir.ref<f32>
+! CHECK:    %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref<f32>
+! CHECK:    %[[VAL_13:.*]] = fir.call @_QPelem_val(%[[VAL_7]], %[[VAL_12]]) fastmath<contract> : (i32, f32) -> f32
+! CHECK:    hlfir.yield_element %[[VAL_13]] : f32
+! CHECK:  }
+
+subroutine by_boxaddr(x, y)
+  character(*) :: x
+  character(*) :: y(100)
+  interface
+    real elemental function char_elem(a, b)
+      character(*), intent(in) :: a
+      character(*), intent(in) :: b
+    end function
+  end interface
+  call baz2(char_elem(x, y))
+end subroutine
+! CHECK-LABEL: func.func @_QPby_boxaddr(
+! CHECK:  %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2:.*]]#0 typeparams %[[VAL_2]]#1 {{.*}}x
+! CHECK:  %[[VAL_6:.*]] = arith.constant 100 : index
+! CHECK:  %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_5:.*]](%[[VAL_7:.*]]) typeparams %[[VAL_4:.*]]#1 {{.*}}y
+! CHECK:  %[[VAL_9:.*]] = hlfir.elemental %[[VAL_7]] : (!fir.shape<1>) -> !hlfir.expr<100xf32> {
+! CHECK:  ^bb0(%[[VAL_10:.*]]: index):
+! CHECK:    %[[VAL_11:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_10]])  typeparams %[[VAL_4]]#1 : (!fir.box<!fir.array<100x!fir.char<1,?>>>, index, index) -> !fir.boxchar<1>
+! CHECK:    %[[VAL_12:.*]] = fir.call @_QPchar_elem(%[[VAL_3]]#0, %[[VAL_11]]) fastmath<contract> : (!fir.boxchar<1>, !fir.boxchar<1>) -> f32
+! CHECK:    hlfir.yield_element %[[VAL_12]] : f32
+! CHECK:  }
+
+subroutine sub(x, y)
+  integer :: x
+  real :: y(10, 20)
+  interface
+    elemental subroutine elem_sub(a, b)
+      integer, intent(in) :: a
+      real, intent(in) :: b
+    end subroutine
+  end interface
+  call elem_sub(x, y)
+end subroutine
+! CHECK-LABEL: func.func @_QPsub(
+! CHECK:  %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0:.*]] {{.*}}x
+! CHECK:  %[[VAL_3:.*]] = arith.constant 10 : index
+! CHECK:  %[[VAL_4:.*]] = arith.constant 20 : index
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1:.*]](%[[VAL_5:[^)]*]]) {{.*}}y
+! CHECK:  %[[VAL_7:.*]] = arith.constant 1 : index
+! CHECK:  fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_7]] {
+! CHECK:    fir.do_loop %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] {
+! CHECK:      %[[VAL_10:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]], %[[VAL_8]])  : (!fir.ref<!fir.array<10x20xf32>>, index, index) -> !fir.ref<f32>
+! CHECK:      fir.call @_QPelem_sub(%[[VAL_2]]#1, %[[VAL_10]]) fastmath<contract> : (!fir.ref<i32>, !fir.ref<f32>) -> ()
+! CHECK:    }
+! CHECK:  }