[flang] Add lowering for host association
authorValentin Clement <clementval@gmail.com>
Mon, 7 Mar 2022 18:55:48 +0000 (19:55 +0100)
committerValentin Clement <clementval@gmail.com>
Mon, 7 Mar 2022 18:57:02 +0000 (19:57 +0100)
This patches adds the code to handle host association for
inner subroutines and functions.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D121134

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
16 files changed:
flang/include/flang/Lower/AbstractConverter.h
flang/include/flang/Lower/CallInterface.h
flang/include/flang/Optimizer/Builder/BoxValue.h
flang/include/flang/Optimizer/Builder/LowLevelIntrinsics.h [new file with mode: 0644]
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/CMakeLists.txt
flang/lib/Lower/CallInterface.cpp
flang/lib/Lower/ConvertExpr.cpp
flang/lib/Lower/ConvertVariable.cpp
flang/lib/Lower/HostAssociations.cpp [new file with mode: 0644]
flang/lib/Optimizer/Builder/BoxValue.cpp
flang/lib/Optimizer/Builder/CMakeLists.txt
flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp [new file with mode: 0644]
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/test/Lower/host-associated.f90 [new file with mode: 0644]

index 657c584..a62ce31 100644 (file)
@@ -126,6 +126,10 @@ public:
   /// which is itself a reference. Use bindTuple() to set this value.
   virtual mlir::Value hostAssocTupleValue() = 0;
 
+  /// Record a binding for the ssa-value of the host assoications tuple for this
+  /// function.
+  virtual void bindHostAssocTuple(mlir::Value val) = 0;
+
   //===--------------------------------------------------------------------===//
   // Types
   //===--------------------------------------------------------------------===//
index c39cbb8..e2dba7c 100644 (file)
@@ -43,6 +43,8 @@ class Location;
 
 namespace Fortran::lower {
 class AbstractConverter;
+class SymMap;
+class HostAssociations;
 namespace pft {
 struct FunctionLikeUnit;
 }
@@ -83,8 +85,8 @@ class CallInterfaceImpl;
 /// can be either a Symbol or an ActualArgument.
 /// It works in two passes: a first pass over the characteristics that decides
 /// how the interface must be. Then, the funcOp is created for it. Then a simple
-/// pass over fir arguments finalizes the interface information that must be
-/// passed back to the user (and may require having the funcOp). All these
+/// pass over fir arguments finalize the interface information that must be
+/// passed back to the user (and may require having the funcOp). All this
 /// passes are driven from the CallInterface constructor.
 template <typename T>
 class CallInterface {
@@ -110,7 +112,6 @@ public:
     // tuple.
     CharProcTuple
   };
-
   /// Different properties of an entity that can be passed/returned.
   /// One-to-One mapping with PassEntityBy but for
   /// PassEntityBy::AddressAndLength that has two properties.
@@ -138,7 +139,7 @@ public:
     /// Type for this input/output
     mlir::Type type;
     /// Position of related passedEntity in passedArguments.
-    /// (passedEntity is the passedResult this value is resultEntityPosition).
+    /// (passedEntity is the passedResult this value is resultEntityPosition.
     int passedEntityPosition;
     static constexpr int resultEntityPosition = -1;
     /// Indicate property of the entity passedEntityPosition that must be passed
@@ -370,10 +371,44 @@ public:
   /// argument symbols.
   mlir::FuncOp addEntryBlockAndMapArguments();
 
+  bool hasHostAssociated() const;
+  mlir::Type getHostAssociatedTy() const;
+  mlir::Value getHostAssociatedTuple() const;
+
 private:
   Fortran::lower::pft::FunctionLikeUnit &funit;
 };
 
+/// Translate a procedure characteristics to an mlir::FunctionType signature.
+mlir::FunctionType
+translateSignature(const Fortran::evaluate::ProcedureDesignator &,
+                   Fortran::lower::AbstractConverter &);
+
+/// Declare or find the mlir::FuncOp named \p name. If the mlir::FuncOp does
+/// not exist yet, declare it with the signature translated from the
+/// ProcedureDesignator argument.
+/// Due to Fortran implicit function typing rules, the returned FuncOp is not
+/// guaranteed to have the signature from ProcedureDesignator if the FuncOp was
+/// already declared.
+mlir::FuncOp
+getOrDeclareFunction(llvm::StringRef name,
+                     const Fortran::evaluate::ProcedureDesignator &,
+                     Fortran::lower::AbstractConverter &);
+
+/// Return the type of an argument that is a dummy procedure. This may be an
+/// mlir::FunctionType, but it can also be a more elaborate type based on the
+/// function type (like a tuple<function type, length type> for character
+/// functions).
+mlir::Type getDummyProcedureType(const Fortran::semantics::Symbol &dummyProc,
+                                 Fortran::lower::AbstractConverter &);
+
+/// Is it required to pass \p proc as a tuple<function address, result length> ?
+// This is required to convey  the length of character functions passed as dummy
+// procedures.
+bool mustPassLengthWithDummyProcedure(
+    const Fortran::evaluate::ProcedureDesignator &proc,
+    Fortran::lower::AbstractConverter &);
+
 } // namespace Fortran::lower
 
 #endif // FORTRAN_LOWER_FIRBUILDER_H
index a1ff8be..134b177 100644 (file)
@@ -346,7 +346,11 @@ public:
   bool isAllocatable() const {
     return getBoxTy().getEleTy().isa<fir::HeapType>();
   }
-  /// Does this entity have any non deferred length parameters ?
+  // Replace the fir.ref<fir.box>, keeping any non-deferred parameters.
+  MutableBoxValue clone(mlir::Value newBox) const {
+    return {newBox, lenParams, mutableProperties};
+  }
+  /// Does this entity has any non deferred length parameters ?
   bool hasNonDeferredLenParams() const { return !lenParams.empty(); }
   /// Return the non deferred length parameters.
   llvm::ArrayRef<mlir::Value> nonDeferredLenParams() const { return lenParams; }
@@ -354,7 +358,7 @@ public:
                                        const MutableBoxValue &);
   LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this; }
 
-  /// Set of variables is used instead of a descriptor to hold the entity
+  /// Set of variable is used instead of a descriptor to hold the entity
   /// properties instead of a fir.ref<fir.box<>>.
   bool isDescribedByVariables() const { return !mutableProperties.isEmpty(); }
 
diff --git a/flang/include/flang/Optimizer/Builder/LowLevelIntrinsics.h b/flang/include/flang/Optimizer/Builder/LowLevelIntrinsics.h
new file mode 100644 (file)
index 0000000..7ef5ff1
--- /dev/null
@@ -0,0 +1,33 @@
+//===-- LowLevelIntrinsics.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FLANG_OPTIMIZER_BUILDER_LOWLEVELINTRINSICS_H
+#define FLANG_OPTIMIZER_BUILDER_LOWLEVELINTRINSICS_H
+
+namespace mlir {
+class FuncOp;
+}
+namespace fir {
+class FirOpBuilder;
+}
+
+namespace fir::factory {
+
+/// Get the `llvm.stacksave` intrinsic.
+mlir::FuncOp getLlvmStackSave(FirOpBuilder &builder);
+
+/// Get the `llvm.stackrestore` intrinsic.
+mlir::FuncOp getLlvmStackRestore(FirOpBuilder &builder);
+
+} // namespace fir::factory
+
+#endif // FLANG_OPTIMIZER_BUILDER_LOWLEVELINTRINSICS_H
index 516d8e5..a0db083 100644 (file)
@@ -191,6 +191,9 @@ inline bool isRecordWithTypeParameters(mlir::Type ty) {
   return false;
 }
 
+/// Is this tuple type holding a character function and its result length ?
+bool isCharacterProcedureTuple(mlir::Type type, bool acceptRawFunc = true);
+
 /// Apply the components specified by `path` to `rootTy` to determine the type
 /// of the resulting component element. `rootTy` should be an aggregate type.
 /// Returns null on error.
index 78206c0..1eeab69 100644 (file)
@@ -117,11 +117,57 @@ public:
     }
     funit.setActiveEntry(0);
 
+    // Compute the set of host associated entities from the nested functions.
+    llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
+    for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
+      collectHostAssociatedVariables(f, escapeHost);
+    funit.setHostAssociatedSymbols(escapeHost);
+
     // Declare internal procedures
     for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
       declareFunction(f);
   }
 
+  /// Collects the canonical list of all host associated symbols. These bindings
+  /// must be aggregated into a tuple which can then be added to each of the
+  /// internal procedure declarations and passed at each call site.
+  void collectHostAssociatedVariables(
+      Fortran::lower::pft::FunctionLikeUnit &funit,
+      llvm::SetVector<const Fortran::semantics::Symbol *> &escapees) {
+    const Fortran::semantics::Scope *internalScope =
+        funit.getSubprogramSymbol().scope();
+    assert(internalScope && "internal procedures symbol must create a scope");
+    auto addToListIfEscapee = [&](const Fortran::semantics::Symbol &sym) {
+      const Fortran::semantics::Symbol &ultimate = sym.GetUltimate();
+      const auto *namelistDetails =
+          ultimate.detailsIf<Fortran::semantics::NamelistDetails>();
+      if (ultimate.has<Fortran::semantics::ObjectEntityDetails>() ||
+          Fortran::semantics::IsProcedurePointer(ultimate) ||
+          Fortran::semantics::IsDummy(sym) || namelistDetails) {
+        const Fortran::semantics::Scope &ultimateScope = ultimate.owner();
+        if (ultimateScope.kind() ==
+                Fortran::semantics::Scope::Kind::MainProgram ||
+            ultimateScope.kind() == Fortran::semantics::Scope::Kind::Subprogram)
+          if (ultimateScope != *internalScope &&
+              ultimateScope.Contains(*internalScope)) {
+            if (namelistDetails) {
+              // So far, namelist symbols are processed on the fly in IO and
+              // the related namelist data structure is not added to the symbol
+              // map, so it cannot be passed to the internal procedures.
+              // Instead, all the symbols of the host namelist used in the
+              // internal procedure must be considered as host associated so
+              // that IO lowering can find them when needed.
+              for (const auto &namelistObject : namelistDetails->objects())
+                escapees.insert(&*namelistObject);
+            } else {
+              escapees.insert(&ultimate);
+            }
+          }
+      }
+    };
+    Fortran::lower::pft::visitAllSymbols(funit, addToListIfEscapee);
+  }
+
   //===--------------------------------------------------------------------===//
   // AbstractConverter overrides
   //===--------------------------------------------------------------------===//
@@ -342,9 +388,9 @@ public:
         if (arg.entity.has_value()) {
           addSymbol(arg.entity->get(), arg.firArgument);
         } else {
-          // assert(funit.parentHasHostAssoc());
-          // funit.parentHostAssoc().internalProcedureBindings(*this,
-          //                                                   localSymbols);
+          assert(funit.parentHasHostAssoc());
+          funit.parentHostAssoc().internalProcedureBindings(*this,
+                                                            localSymbols);
         }
       }
     };
@@ -394,22 +440,105 @@ public:
 
     mapDummiesAndResults(funit, callee);
 
+    // Note: not storing Variable references because getOrderedSymbolTable
+    // below returns a temporary.
+    llvm::SmallVector<Fortran::lower::pft::Variable> deferredFuncResultList;
+
+    // Backup actual argument for entry character results
+    // with different lengths. It needs to be added to the non
+    // primary results symbol before mapSymbolAttributes is called.
+    Fortran::lower::SymbolBox resultArg;
+    if (std::optional<Fortran::lower::CalleeInterface::PassedEntity>
+            passedResult = callee.getPassedResult())
+      resultArg = lookupSymbol(passedResult->entity->get());
+
     Fortran::lower::AggregateStoreMap storeMap;
+    // The front-end is currently not adding module variables referenced
+    // in a module procedure as host associated. As a result we need to
+    // instantiate all module variables here if this is a module procedure.
+    // It is likely that the front-end behavior should change here.
+    // This also applies to internal procedures inside module procedures.
+    if (auto *module = Fortran::lower::pft::getAncestor<
+            Fortran::lower::pft::ModuleLikeUnit>(funit))
+      for (const Fortran::lower::pft::Variable &var :
+           module->getOrderedSymbolTable())
+        instantiateVar(var, storeMap);
+
+    mlir::Value primaryFuncResultStorage;
     for (const Fortran::lower::pft::Variable &var :
          funit.getOrderedSymbolTable()) {
+      // Always instantiate aggregate storage blocks.
+      if (var.isAggregateStore()) {
+        instantiateVar(var, storeMap);
+        continue;
+      }
       const Fortran::semantics::Symbol &sym = var.getSymbol();
+      if (funit.parentHasHostAssoc()) {
+        // Never instantitate host associated variables, as they are already
+        // instantiated from an argument tuple. Instead, just bind the symbol to
+        // the reference to the host variable, which must be in the map.
+        const Fortran::semantics::Symbol &ultimate = sym.GetUltimate();
+        if (funit.parentHostAssoc().isAssociated(ultimate)) {
+          Fortran::lower::SymbolBox hostBox =
+              localSymbols.lookupSymbol(ultimate);
+          assert(hostBox && "host association is not in map");
+          localSymbols.addSymbol(sym, hostBox.toExtendedValue());
+          continue;
+        }
+      }
       if (!sym.IsFuncResult() || !funit.primaryResult) {
         instantiateVar(var, storeMap);
       } else if (&sym == funit.primaryResult) {
         instantiateVar(var, storeMap);
+        primaryFuncResultStorage = getSymbolAddress(sym);
+      } else {
+        deferredFuncResultList.push_back(var);
       }
     }
 
+    // If this is a host procedure with host associations, then create the tuple
+    // of pointers for passing to the internal procedures.
+    if (!funit.getHostAssoc().empty())
+      funit.getHostAssoc().hostProcedureBindings(*this, localSymbols);
+
+    /// TODO: should use same mechanism as equivalence?
+    /// One blocking point is character entry returns that need special handling
+    /// since they are not locally allocated but come as argument. CHARACTER(*)
+    /// is not something that fit wells with equivalence lowering.
+    for (const Fortran::lower::pft::Variable &altResult :
+         deferredFuncResultList) {
+      if (std::optional<Fortran::lower::CalleeInterface::PassedEntity>
+              passedResult = callee.getPassedResult())
+        addSymbol(altResult.getSymbol(), resultArg.getAddr());
+      Fortran::lower::StatementContext stmtCtx;
+      Fortran::lower::mapSymbolAttributes(*this, altResult, localSymbols,
+                                          stmtCtx, primaryFuncResultStorage);
+    }
+
     // Create most function blocks in advance.
     createEmptyGlobalBlocks(funit.evaluationList);
 
     // Reinstate entry block as the current insertion point.
     builder->setInsertionPointToEnd(&func.front());
+
+    if (callee.hasAlternateReturns()) {
+      // Create a local temp to hold the alternate return index.
+      // Give it an integer index type and the subroutine name (for dumps).
+      // Attach it to the subroutine symbol in the localSymbols map.
+      // Initialize it to zero, the "fallthrough" alternate return value.
+      const Fortran::semantics::Symbol &symbol = funit.getSubprogramSymbol();
+      mlir::Location loc = toLocation();
+      mlir::Type idxTy = builder->getIndexType();
+      mlir::Value altResult =
+          builder->createTemporary(loc, idxTy, toStringRef(symbol.name()));
+      addSymbol(symbol, altResult);
+      mlir::Value zero = builder->createIntegerConstant(loc, idxTy, 0);
+      builder->create<fir::StoreOp>(loc, zero, altResult);
+    }
+
+    if (Fortran::lower::pft::Evaluation *alternateEntryEval =
+            funit.getEntryEval())
+      genFIRBranch(alternateEntryEval->lexicalSuccessor->block);
   }
 
   /// Create global blocks for the current function.  This eliminates the
@@ -432,7 +561,11 @@ public:
         if (eval.lowerAsUnstructured()) {
           createEmptyGlobalBlocks(eval.getNestedEvaluations());
         } else if (eval.hasNestedEvaluations()) {
-          TODO(toLocation(), "Constructs with nested evaluations");
+          // A structured construct that is a target starts a new block.
+          Fortran::lower::pft::Evaluation &constructStmt =
+              eval.getFirstNestedEvaluation();
+          if (constructStmt.isNewBlock)
+            constructStmt.block = builder->createBlock(region);
         }
       }
     }
@@ -440,6 +573,14 @@ public:
 
   /// Lower a procedure (nest).
   void lowerFunc(Fortran::lower::pft::FunctionLikeUnit &funit) {
+    if (!funit.isMainProgram()) {
+      const Fortran::semantics::Symbol &procSymbol =
+          funit.getSubprogramSymbol();
+      if (procSymbol.owner().IsSubmodule()) {
+        TODO(toLocation(), "support submodules");
+        return;
+      }
+    }
     setCurrentPosition(funit.getStartingSourceLoc());
     for (int entryIndex = 0, last = funit.entryPointList.size();
          entryIndex < last; ++entryIndex) {
@@ -491,6 +632,12 @@ public:
 
   mlir::Value hostAssocTupleValue() override final { return hostAssocTuple; }
 
+  /// Record a binding for the ssa-value of the tuple for this function.
+  void bindHostAssocTuple(mlir::Value val) override final {
+    assert(!hostAssocTuple && val);
+    hostAssocTuple = val;
+  }
+
 private:
   FirConverter() = delete;
   FirConverter(const FirConverter &) = delete;
@@ -500,6 +647,12 @@ private:
   // Helper member functions
   //===--------------------------------------------------------------------===//
 
+  mlir::Value createFIRExpr(mlir::Location loc,
+                            const Fortran::lower::SomeExpr *expr,
+                            Fortran::lower::StatementContext &stmtCtx) {
+    return fir::getBase(genExprValue(*expr, stmtCtx, &loc));
+  }
+
   /// Find the symbol in the local map or return null.
   Fortran::lower::SymbolBox
   lookupSymbol(const Fortran::semantics::Symbol &sym) {
@@ -548,6 +701,39 @@ private:
     builder->create<cf::BranchOp>(toLocation(), targetBlock);
   }
 
+  void genFIRConditionalBranch(mlir::Value cond, mlir::Block *trueTarget,
+                               mlir::Block *falseTarget) {
+    assert(trueTarget && "missing conditional branch true block");
+    assert(falseTarget && "missing conditional branch false block");
+    mlir::Location loc = toLocation();
+    mlir::Value bcc = builder->createConvert(loc, builder->getI1Type(), cond);
+    builder->create<mlir::cf::CondBranchOp>(loc, bcc, trueTarget, llvm::None,
+                                            falseTarget, llvm::None);
+  }
+  void genFIRConditionalBranch(mlir::Value cond,
+                               Fortran::lower::pft::Evaluation *trueTarget,
+                               Fortran::lower::pft::Evaluation *falseTarget) {
+    genFIRConditionalBranch(cond, trueTarget->block, falseTarget->block);
+  }
+  void genFIRConditionalBranch(const Fortran::parser::ScalarLogicalExpr &expr,
+                               mlir::Block *trueTarget,
+                               mlir::Block *falseTarget) {
+    Fortran::lower::StatementContext stmtCtx;
+    mlir::Value cond =
+        createFIRExpr(toLocation(), Fortran::semantics::GetExpr(expr), stmtCtx);
+    stmtCtx.finalize();
+    genFIRConditionalBranch(cond, trueTarget, falseTarget);
+  }
+  void genFIRConditionalBranch(const Fortran::parser::ScalarLogicalExpr &expr,
+                               Fortran::lower::pft::Evaluation *trueTarget,
+                               Fortran::lower::pft::Evaluation *falseTarget) {
+    Fortran::lower::StatementContext stmtCtx;
+    mlir::Value cond =
+        createFIRExpr(toLocation(), Fortran::semantics::GetExpr(expr), stmtCtx);
+    stmtCtx.finalize();
+    genFIRConditionalBranch(cond, trueTarget->block, falseTarget->block);
+  }
+
   //===--------------------------------------------------------------------===//
   // Termination of symbolically referenced execution units
   //===--------------------------------------------------------------------===//
@@ -608,6 +794,29 @@ private:
     }
   }
 
+  //
+  // Statements that have control-flow semantics
+  //
+
+  /// Generate an If[Then]Stmt condition or its negation.
+  template <typename A>
+  mlir::Value genIfCondition(const A *stmt, bool negate = false) {
+    mlir::Location loc = toLocation();
+    Fortran::lower::StatementContext stmtCtx;
+    mlir::Value condExpr = createFIRExpr(
+        loc,
+        Fortran::semantics::GetExpr(
+            std::get<Fortran::parser::ScalarLogicalExpr>(stmt->t)),
+        stmtCtx);
+    stmtCtx.finalize();
+    mlir::Value cond =
+        builder->createConvert(loc, builder->getI1Type(), condExpr);
+    if (negate)
+      cond = builder->create<mlir::arith::XOrIOp>(
+          loc, cond, builder->createIntegerConstant(loc, cond.getType(), 1));
+    return cond;
+  }
+
   [[maybe_unused]] static bool
   isFuncResultDesignator(const Fortran::lower::SomeExpr &expr) {
     const Fortran::semantics::Symbol *sym =
@@ -769,7 +978,59 @@ private:
   }
 
   void genFIR(const Fortran::parser::IfConstruct &) {
-    TODO(toLocation(), "IfConstruct lowering");
+    mlir::Location loc = toLocation();
+    Fortran::lower::pft::Evaluation &eval = getEval();
+    if (eval.lowerAsStructured()) {
+      // Structured fir.if nest.
+      fir::IfOp topIfOp, currentIfOp;
+      for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
+        auto genIfOp = [&](mlir::Value cond) {
+          auto ifOp = builder->create<fir::IfOp>(loc, cond, /*withElse=*/true);
+          builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+          return ifOp;
+        };
+        if (auto *s = e.getIf<Fortran::parser::IfThenStmt>()) {
+          topIfOp = currentIfOp = genIfOp(genIfCondition(s, e.negateCondition));
+        } else if (auto *s = e.getIf<Fortran::parser::IfStmt>()) {
+          topIfOp = currentIfOp = genIfOp(genIfCondition(s, e.negateCondition));
+        } else if (auto *s = e.getIf<Fortran::parser::ElseIfStmt>()) {
+          builder->setInsertionPointToStart(
+              &currentIfOp.getElseRegion().front());
+          currentIfOp = genIfOp(genIfCondition(s));
+        } else if (e.isA<Fortran::parser::ElseStmt>()) {
+          builder->setInsertionPointToStart(
+              &currentIfOp.getElseRegion().front());
+        } else if (e.isA<Fortran::parser::EndIfStmt>()) {
+          builder->setInsertionPointAfter(topIfOp);
+        } else {
+          genFIR(e, /*unstructuredContext=*/false);
+        }
+      }
+      return;
+    }
+
+    // Unstructured branch sequence.
+    for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
+      auto genIfBranch = [&](mlir::Value cond) {
+        if (e.lexicalSuccessor == e.controlSuccessor) // empty block -> exit
+          genFIRConditionalBranch(cond, e.parentConstruct->constructExit,
+                                  e.controlSuccessor);
+        else // non-empty block
+          genFIRConditionalBranch(cond, e.lexicalSuccessor, e.controlSuccessor);
+      };
+      if (auto *s = e.getIf<Fortran::parser::IfThenStmt>()) {
+        maybeStartBlock(e.block);
+        genIfBranch(genIfCondition(s, e.negateCondition));
+      } else if (auto *s = e.getIf<Fortran::parser::IfStmt>()) {
+        maybeStartBlock(e.block);
+        genIfBranch(genIfCondition(s, e.negateCondition));
+      } else if (auto *s = e.getIf<Fortran::parser::ElseIfStmt>()) {
+        startBlock(e.block);
+        genIfBranch(genIfCondition(s));
+      } else {
+        genFIR(e);
+      }
+    }
   }
 
   void genFIR(const Fortran::parser::CaseConstruct &) {
index 297cc9b..6503c8a 100644 (file)
@@ -8,10 +8,11 @@ add_flang_library(FortranLower
   ConvertExpr.cpp
   ConvertType.cpp
   ConvertVariable.cpp
-  IntrinsicCall.cpp
-  IO.cpp
   ComponentPath.cpp
   DumpEvaluateExpr.cpp
+  HostAssociations.cpp
+  IntrinsicCall.cpp
+  IO.cpp
   IterationSpace.cpp
   Mangler.cpp
   OpenACC.cpp
index da00d7f..d1cd6f5 100644 (file)
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
+#include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Lower/Todo.h"
+#include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
@@ -30,6 +32,26 @@ static std::string getMangledName(const Fortran::semantics::Symbol &symbol) {
   return bindName ? *bindName : Fortran::lower::mangle::mangleName(symbol);
 }
 
+/// Return the type of a dummy procedure given its characteristic (if it has
+/// one).
+mlir::Type getProcedureDesignatorType(
+    const Fortran::evaluate::characteristics::Procedure *,
+    Fortran::lower::AbstractConverter &converter) {
+  // TODO: Get actual function type of the dummy procedure, at least when an
+  // interface is given. The result type should be available even if the arity
+  // and type of the arguments is not.
+  llvm::SmallVector<mlir::Type> resultTys;
+  llvm::SmallVector<mlir::Type> inputTys;
+  // In general, that is a nice to have but we cannot guarantee to find the
+  // function type that will match the one of the calls, we may not even know
+  // how many arguments the dummy procedure accepts (e.g. if a procedure
+  // pointer is only transiting through the current procedure without being
+  // called), so a function type cast must always be inserted.
+  auto *context = &converter.getMLIRContext();
+  auto untypedFunc = mlir::FunctionType::get(context, inputTys, resultTys);
+  return fir::BoxProcType::get(context, untypedFunc);
+}
+
 //===----------------------------------------------------------------------===//
 // Caller side interface implementation
 //===----------------------------------------------------------------------===//
@@ -193,11 +215,7 @@ void Fortran::lower::CallerInterface::walkResultLengths(
             dynamicType.GetCharLength())
       visitor(toEvExpr(*length));
   } else if (dynamicType.category() == common::TypeCategory::Derived) {
-    const Fortran::semantics::DerivedTypeSpec &derivedTypeSpec =
-        dynamicType.GetDerivedTypeSpec();
-    if (Fortran::semantics::CountLenParameters(derivedTypeSpec) > 0)
-      TODO(converter.getCurrentLocation(),
-           "function result with derived type length parameters");
+    TODO(converter.getCurrentLocation(), "walkResultLengths derived type");
   }
 }
 
@@ -336,8 +354,22 @@ mlir::FuncOp Fortran::lower::CalleeInterface::addEntryBlockAndMapArguments() {
   return func;
 }
 
+bool Fortran::lower::CalleeInterface::hasHostAssociated() const {
+  return funit.parentHasHostAssoc();
+}
+
+mlir::Type Fortran::lower::CalleeInterface::getHostAssociatedTy() const {
+  assert(hasHostAssociated());
+  return funit.parentHostAssoc().getArgumentType(converter);
+}
+
+mlir::Value Fortran::lower::CalleeInterface::getHostAssociatedTuple() const {
+  assert(hasHostAssociated() || !funit.getHostAssoc().empty());
+  return converter.hostAssocTupleValue();
+}
+
 //===----------------------------------------------------------------------===//
-// CallInterface implementation: this part is common to both callee and caller
+// CallInterface implementation: this part is common to both caller and caller
 // sides.
 //===----------------------------------------------------------------------===//
 
@@ -455,10 +487,20 @@ getResultEntity(Fortran::lower::pft::FunctionLikeUnit &funit) {
       .result();
 }
 
-//===----------------------------------------------------------------------===//
-// CallInterface implementation: this part is common to both caller and caller
-// sides.
-//===----------------------------------------------------------------------===//
+/// Bypass helpers to manipulate entities since they are not any symbol/actual
+/// argument to associate. See SignatureBuilder below.
+using FakeEntity = bool;
+using FakeEntities = llvm::SmallVector<FakeEntity>;
+static FakeEntities
+getEntityContainer(const Fortran::evaluate::characteristics::Procedure &proc) {
+  FakeEntities enities(proc.dummyArguments.size());
+  return enities;
+}
+static const FakeEntity &getDataObjectEntity(const FakeEntity &e) { return e; }
+static FakeEntity
+getResultEntity(const Fortran::evaluate::characteristics::Procedure &proc) {
+  return false;
+}
 
 /// This is the actual part that defines the FIR interface based on the
 /// characteristic. It directly mutates the CallInterface members.
@@ -552,6 +594,51 @@ public:
     }
   }
 
+  void appendHostAssocTupleArg(mlir::Type tupTy) {
+    MLIRContext *ctxt = tupTy.getContext();
+    addFirOperand(tupTy, nextPassedArgPosition(), Property::BaseAddress,
+                  {mlir::NamedAttribute{
+                      mlir::StringAttr::get(ctxt, fir::getHostAssocAttrName()),
+                      mlir::UnitAttr::get(ctxt)}});
+    interface.passedArguments.emplace_back(
+        PassedEntity{PassEntityBy::BaseAddress, std::nullopt,
+                     interface.side().getHostAssociatedTuple(), emptyValue()});
+  }
+
+  static llvm::Optional<Fortran::evaluate::DynamicType> getResultDynamicType(
+      const Fortran::evaluate::characteristics::Procedure &procedure) {
+    if (const std::optional<Fortran::evaluate::characteristics::FunctionResult>
+            &result = procedure.functionResult)
+      if (const auto *resultTypeAndShape = result->GetTypeAndShape())
+        return resultTypeAndShape->type();
+    return llvm::None;
+  }
+
+  static bool mustPassLengthWithDummyProcedure(
+      const Fortran::evaluate::characteristics::Procedure &procedure) {
+    // When passing a character function designator `bar` as dummy procedure to
+    // `foo` (e.g. `foo(bar)`), pass the result length of `bar` to `foo` so that
+    // `bar` can be called inside `foo` even if its length is assumed there.
+    // From an ABI perspective, the extra length argument must be handled
+    // exactly as if passing a character object. Using an argument of
+    // fir.boxchar type gives the expected behavior: after codegen, the
+    // fir.boxchar lengths are added after all the arguments as extra value
+    // arguments (the extra arguments order is the order of the fir.boxchar).
+
+    // This ABI is compatible with ifort, nag, nvfortran, and xlf, but not
+    // gfortran. Gfortran does not pass the length and is therefore unable to
+    // handle later call to `bar` in `foo` where the length would be assumed. If
+    // the result is an array, nag and ifort and xlf still pass the length, but
+    // not nvfortran (and gfortran). It is not clear it is possible to call an
+    // array function with assumed length (f18 forbides defining such
+    // interfaces). Hence, passing the length is most likely useless, but stick
+    // with ifort/nag/xlf interface here.
+    if (llvm::Optional<Fortran::evaluate::DynamicType> type =
+            getResultDynamicType(procedure))
+      return type->category() == Fortran::common::TypeCategory::Character;
+    return false;
+  }
+
 private:
   void handleImplicitResult(
       const Fortran::evaluate::characteristics::FunctionResult &result) {
@@ -567,8 +654,13 @@ private:
       handleImplicitCharacterResult(dynamicType);
     } else if (dynamicType.category() ==
                Fortran::common::TypeCategory::Derived) {
-      TODO(interface.converter.getCurrentLocation(),
-           "implicit result derived type");
+      // Derived result need to be allocated by the caller and the result value
+      // must be saved. Derived type in implicit interface cannot have length
+      // parameters.
+      setSaveResult();
+      mlir::Type mlirType = translateDynamicType(dynamicType);
+      addFirResult(mlirType, FirPlaceHolder::resultEntityPosition,
+                   Property::Value);
     } else {
       // All result other than characters/derived are simply returned by value
       // in implicit interfaces
@@ -578,7 +670,6 @@ private:
                    Property::Value);
     }
   }
-
   void
   handleImplicitCharacterResult(const Fortran::evaluate::DynamicType &type) {
     int resultPosition = FirPlaceHolder::resultEntityPosition;
@@ -597,62 +688,6 @@ private:
     addFirResult(boxCharTy, resultPosition, Property::BoxChar);
   }
 
-  void handleExplicitResult(
-      const Fortran::evaluate::characteristics::FunctionResult &result) {
-    using Attr = Fortran::evaluate::characteristics::FunctionResult::Attr;
-
-    if (result.IsProcedurePointer())
-      TODO(interface.converter.getCurrentLocation(),
-           "procedure pointer results");
-    const Fortran::evaluate::characteristics::TypeAndShape *typeAndShape =
-        result.GetTypeAndShape();
-    assert(typeAndShape && "expect type for non proc pointer result");
-    mlir::Type mlirType = translateDynamicType(typeAndShape->type());
-    fir::SequenceType::Shape bounds = getBounds(typeAndShape->shape());
-    if (!bounds.empty())
-      mlirType = fir::SequenceType::get(bounds, mlirType);
-    if (result.attrs.test(Attr::Allocatable))
-      mlirType = fir::BoxType::get(fir::HeapType::get(mlirType));
-    if (result.attrs.test(Attr::Pointer))
-      mlirType = fir::BoxType::get(fir::PointerType::get(mlirType));
-
-    if (fir::isa_char(mlirType)) {
-      // Character scalar results must be passed as arguments in lowering so
-      // that an assumed length character function callee can access the result
-      // length. A function with a result requiring an explicit interface does
-      // not have to be compatible with assumed length function, but most
-      // compilers supports it.
-      handleImplicitCharacterResult(typeAndShape->type());
-      return;
-    }
-
-    addFirResult(mlirType, FirPlaceHolder::resultEntityPosition,
-                 Property::Value);
-    // Explicit results require the caller to allocate the storage and save the
-    // function result in the storage with a fir.save_result.
-    setSaveResult();
-  }
-
-  fir::SequenceType::Shape getBounds(const Fortran::evaluate::Shape &shape) {
-    fir::SequenceType::Shape bounds;
-    for (const std::optional<Fortran::evaluate::ExtentExpr> &extent : shape) {
-      fir::SequenceType::Extent bound = fir::SequenceType::getUnknownExtent();
-      if (std::optional<std::int64_t> i = toInt64(extent))
-        bound = *i;
-      bounds.emplace_back(bound);
-    }
-    return bounds;
-  }
-  std::optional<std::int64_t>
-  toInt64(std::optional<
-          Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger>>
-              expr) {
-    if (expr)
-      return Fortran::evaluate::ToInt64(Fortran::evaluate::Fold(
-          getConverter().getFoldingContext(), toEvExpr(*expr)));
-    return std::nullopt;
-  }
-
   /// Return a vector with an attribute with the name of the argument if this
   /// is a callee interface and the name is available. Otherwise, just return
   /// an empty vector.
@@ -674,6 +709,30 @@ private:
     return {};
   }
 
+  void handleImplicitDummy(
+      const DummyCharacteristics *characteristics,
+      const Fortran::evaluate::characteristics::DummyDataObject &obj,
+      const FortranEntity &entity) {
+    Fortran::evaluate::DynamicType dynamicType = obj.type.type();
+    if (dynamicType.category() == Fortran::common::TypeCategory::Character) {
+      mlir::Type boxCharTy =
+          fir::BoxCharType::get(&mlirContext, dynamicType.kind());
+      addFirOperand(boxCharTy, nextPassedArgPosition(), Property::BoxChar,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BoxChar, entity, characteristics);
+    } else {
+      // non-PDT derived type allowed in implicit interface.
+      mlir::Type type = translateDynamicType(dynamicType);
+      fir::SequenceType::Shape bounds = getBounds(obj.type.shape());
+      if (!bounds.empty())
+        type = fir::SequenceType::get(bounds, type);
+      mlir::Type refType = fir::ReferenceType::get(type);
+      addFirOperand(refType, nextPassedArgPosition(), Property::BaseAddress,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BaseAddress, entity, characteristics);
+    }
+  }
+
   // Define when an explicit argument must be passed in a fir.box.
   bool dummyRequiresBox(
       const Fortran::evaluate::characteristics::DummyDataObject &obj) {
@@ -701,7 +760,7 @@ private:
     // DERIVED
     if (cat == Fortran::common::TypeCategory::Derived) {
       TODO(interface.converter.getCurrentLocation(),
-           "[translateDynamicType] Derived");
+           "[translateDynamicType] Derived types");
     }
     // CHARACTER with compile time constant length.
     if (cat == Fortran::common::TypeCategory::Character)
@@ -804,37 +863,92 @@ private:
 
   void handleImplicitDummy(
       const DummyCharacteristics *characteristics,
-      const Fortran::evaluate::characteristics::DummyDataObject &obj,
+      const Fortran::evaluate::characteristics::DummyProcedure &proc,
       const FortranEntity &entity) {
-    Fortran::evaluate::DynamicType dynamicType = obj.type.type();
-    if (dynamicType.category() == Fortran::common::TypeCategory::Character) {
-      mlir::Type boxCharTy =
-          fir::BoxCharType::get(&mlirContext, dynamicType.kind());
-      addFirOperand(boxCharTy, nextPassedArgPosition(), Property::BoxChar,
-                    dummyNameAttr(entity));
-      addPassedArg(PassEntityBy::BoxChar, entity, characteristics);
-    } else {
-      // non-PDT derived type allowed in implicit interface.
-      Fortran::common::TypeCategory cat = dynamicType.category();
-      mlir::Type type = getConverter().genType(cat, dynamicType.kind());
-      fir::SequenceType::Shape bounds = getBounds(obj.type.shape());
-      if (!bounds.empty())
-        type = fir::SequenceType::get(bounds, type);
-      mlir::Type refType = fir::ReferenceType::get(type);
-      addFirOperand(refType, nextPassedArgPosition(), Property::BaseAddress,
-                    dummyNameAttr(entity));
-      addPassedArg(PassEntityBy::BaseAddress, entity, characteristics);
+    if (proc.attrs.test(
+            Fortran::evaluate::characteristics::DummyProcedure::Attr::Pointer))
+      TODO(interface.converter.getCurrentLocation(),
+           "procedure pointer arguments");
+    // Otherwise, it is a dummy procedure.
+    const Fortran::evaluate::characteristics::Procedure &procedure =
+        proc.procedure.value();
+    mlir::Type funcType =
+        getProcedureDesignatorType(&procedure, interface.converter);
+    llvm::Optional<Fortran::evaluate::DynamicType> resultTy =
+        getResultDynamicType(procedure);
+    if (resultTy && mustPassLengthWithDummyProcedure(procedure)) {
+      // The result length of dummy procedures that are character functions must
+      // be passed so that the dummy procedure can be called if it has assumed
+      // length on the callee side.
+      mlir::Type tupleType =
+          fir::factory::getCharacterProcedureTupleType(funcType);
+      llvm::StringRef charProcAttr = fir::getCharacterProcedureDummyAttrName();
+      addFirOperand(tupleType, nextPassedArgPosition(), Property::CharProcTuple,
+                    {mlir::NamedAttribute{
+                        mlir::StringAttr::get(&mlirContext, charProcAttr),
+                        mlir::UnitAttr::get(&mlirContext)}});
+      addPassedArg(PassEntityBy::CharProcTuple, entity, characteristics);
+      return;
     }
+    addFirOperand(funcType, nextPassedArgPosition(), Property::BaseAddress);
+    addPassedArg(PassEntityBy::BaseAddress, entity, characteristics);
   }
 
-  void handleImplicitDummy(
-      const DummyCharacteristics *characteristics,
-      const Fortran::evaluate::characteristics::DummyProcedure &proc,
-      const FortranEntity &entity) {
-    TODO(interface.converter.getCurrentLocation(),
-         "handleImlicitDummy DummyProcedure");
+  void handleExplicitResult(
+      const Fortran::evaluate::characteristics::FunctionResult &result) {
+    using Attr = Fortran::evaluate::characteristics::FunctionResult::Attr;
+
+    if (result.IsProcedurePointer())
+      TODO(interface.converter.getCurrentLocation(),
+           "procedure pointer results");
+    const Fortran::evaluate::characteristics::TypeAndShape *typeAndShape =
+        result.GetTypeAndShape();
+    assert(typeAndShape && "expect type for non proc pointer result");
+    mlir::Type mlirType = translateDynamicType(typeAndShape->type());
+    fir::SequenceType::Shape bounds = getBounds(typeAndShape->shape());
+    if (!bounds.empty())
+      mlirType = fir::SequenceType::get(bounds, mlirType);
+    if (result.attrs.test(Attr::Allocatable))
+      mlirType = fir::BoxType::get(fir::HeapType::get(mlirType));
+    if (result.attrs.test(Attr::Pointer))
+      mlirType = fir::BoxType::get(fir::PointerType::get(mlirType));
+
+    if (fir::isa_char(mlirType)) {
+      // Character scalar results must be passed as arguments in lowering so
+      // that an assumed length character function callee can access the result
+      // length. A function with a result requiring an explicit interface does
+      // not have to be compatible with assumed length function, but most
+      // compilers supports it.
+      handleImplicitCharacterResult(typeAndShape->type());
+      return;
+    }
+
+    addFirResult(mlirType, FirPlaceHolder::resultEntityPosition,
+                 Property::Value);
+    // Explicit results require the caller to allocate the storage and save the
+    // function result in the storage with a fir.save_result.
+    setSaveResult();
   }
 
+  fir::SequenceType::Shape getBounds(const Fortran::evaluate::Shape &shape) {
+    fir::SequenceType::Shape bounds;
+    for (const std::optional<Fortran::evaluate::ExtentExpr> &extent : shape) {
+      fir::SequenceType::Extent bound = fir::SequenceType::getUnknownExtent();
+      if (std::optional<std::int64_t> i = toInt64(extent))
+        bound = *i;
+      bounds.emplace_back(bound);
+    }
+    return bounds;
+  }
+  std::optional<std::int64_t>
+  toInt64(std::optional<
+          Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger>>
+              expr) {
+    if (expr)
+      return Fortran::evaluate::ToInt64(Fortran::evaluate::Fold(
+          getConverter().getFoldingContext(), toEvExpr(*expr)));
+    return std::nullopt;
+  }
   void
   addFirOperand(mlir::Type type, int entityPosition, Property p,
                 llvm::ArrayRef<mlir::NamedAttribute> attributes = llvm::None) {
@@ -850,7 +964,7 @@ private:
   void addPassedArg(PassEntityBy p, FortranEntity entity,
                     const DummyCharacteristics *characteristics) {
     interface.passedArguments.emplace_back(
-        PassedEntity{p, entity, {}, {}, characteristics});
+        PassedEntity{p, entity, emptyValue(), emptyValue(), characteristics});
   }
   void setPassedResult(PassEntityBy p, FortranEntity entity) {
     interface.passedResult =
@@ -903,6 +1017,13 @@ void Fortran::lower::CallInterface<T>::determineInterface(
     impl.buildImplicitInterface(procedure);
   else
     impl.buildExplicitInterface(procedure);
+  // We only expect the extra host asspciations argument from the callee side as
+  // the definition of internal procedures will be present, and we'll always
+  // have a FuncOp definition in the ModuleOp, when lowering.
+  if constexpr (std::is_same_v<T, Fortran::lower::CalleeInterface>) {
+    if (side().hasHostAssociated())
+      impl.appendHostAssocTupleArg(side().getHostAssociatedTy());
+  }
 }
 
 template <typename T>
@@ -917,5 +1038,169 @@ mlir::FunctionType Fortran::lower::CallInterface<T>::genFunctionType() {
                                  returnTys);
 }
 
+template <typename T>
+llvm::SmallVector<mlir::Type>
+Fortran::lower::CallInterface<T>::getResultType() const {
+  llvm::SmallVector<mlir::Type> types;
+  for (const FirPlaceHolder &out : outputs)
+    types.emplace_back(out.type);
+  return types;
+}
+
 template class Fortran::lower::CallInterface<Fortran::lower::CalleeInterface>;
 template class Fortran::lower::CallInterface<Fortran::lower::CallerInterface>;
+
+//===----------------------------------------------------------------------===//
+// Function Type Translation
+//===----------------------------------------------------------------------===//
+
+/// Build signature from characteristics when there is no Fortran entity to
+/// associate with the arguments (i.e, this is not a call site or a procedure
+/// declaration. This is needed when dealing with function pointers/dummy
+/// arguments.
+
+class SignatureBuilder;
+template <>
+struct Fortran::lower::PassedEntityTypes<SignatureBuilder> {
+  using FortranEntity = FakeEntity;
+  using FirValue = int;
+};
+
+/// SignatureBuilder is a CRTP implementation of CallInterface intended to
+/// help translating characteristics::Procedure to mlir::FunctionType using
+/// the CallInterface translation.
+class SignatureBuilder
+    : public Fortran::lower::CallInterface<SignatureBuilder> {
+public:
+  SignatureBuilder(const Fortran::evaluate::characteristics::Procedure &p,
+                   Fortran::lower::AbstractConverter &c, bool forceImplicit)
+      : CallInterface{c}, proc{p} {
+    bool isImplicit = forceImplicit || proc.CanBeCalledViaImplicitInterface();
+    determineInterface(isImplicit, proc);
+  }
+  /// Does the procedure characteristics being translated have alternate
+  /// returns ?
+  bool hasAlternateReturns() const {
+    for (const Fortran::evaluate::characteristics::DummyArgument &dummy :
+         proc.dummyArguments)
+      if (std::holds_alternative<
+              Fortran::evaluate::characteristics::AlternateReturn>(dummy.u))
+        return true;
+    return false;
+  };
+
+  /// This is only here to fulfill CRTP dependencies and should not be called.
+  std::string getMangledName() const {
+    llvm_unreachable("trying to get name from SignatureBuilder");
+  }
+
+  /// This is only here to fulfill CRTP dependencies and should not be called.
+  mlir::Location getCalleeLocation() const {
+    llvm_unreachable("trying to get callee location from SignatureBuilder");
+  }
+
+  /// This is only here to fulfill CRTP dependencies and should not be called.
+  const Fortran::semantics::Symbol *getProcedureSymbol() const {
+    llvm_unreachable("trying to get callee symbol from SignatureBuilder");
+  };
+
+  Fortran::evaluate::characteristics::Procedure characterize() const {
+    return proc;
+  }
+  /// SignatureBuilder cannot be used on main program.
+  static constexpr bool isMainProgram() { return false; }
+
+  /// Return the characteristics::Procedure that is being translated to
+  /// mlir::FunctionType.
+  const Fortran::evaluate::characteristics::Procedure &
+  getCallDescription() const {
+    return proc;
+  }
+
+  /// This is not the description of an indirect call.
+  static constexpr bool isIndirectCall() { return false; }
+
+  /// Return the translated signature.
+  mlir::FunctionType getFunctionType() { return genFunctionType(); }
+
+  // Copy of base implementation.
+  static constexpr bool hasHostAssociated() { return false; }
+  mlir::Type getHostAssociatedTy() const {
+    llvm_unreachable("getting host associated type in SignatureBuilder");
+  }
+
+private:
+  const Fortran::evaluate::characteristics::Procedure &proc;
+};
+
+mlir::FunctionType Fortran::lower::translateSignature(
+    const Fortran::evaluate::ProcedureDesignator &proc,
+    Fortran::lower::AbstractConverter &converter) {
+  std::optional<Fortran::evaluate::characteristics::Procedure> characteristics =
+      Fortran::evaluate::characteristics::Procedure::Characterize(
+          proc, converter.getFoldingContext());
+  // Most unrestricted intrinsic characteristic has the Elemental attribute
+  // which triggers CanBeCalledViaImplicitInterface to return false. However,
+  // using implicit interface rules is just fine here.
+  bool forceImplicit = proc.GetSpecificIntrinsic();
+  return SignatureBuilder{characteristics.value(), converter, forceImplicit}
+      .getFunctionType();
+}
+
+mlir::FuncOp Fortran::lower::getOrDeclareFunction(
+    llvm::StringRef name, const Fortran::evaluate::ProcedureDesignator &proc,
+    Fortran::lower::AbstractConverter &converter) {
+  mlir::ModuleOp module = converter.getModuleOp();
+  mlir::FuncOp func = fir::FirOpBuilder::getNamedFunction(module, name);
+  if (func)
+    return func;
+
+  const Fortran::semantics::Symbol *symbol = proc.GetSymbol();
+  assert(symbol && "non user function in getOrDeclareFunction");
+  // getOrDeclareFunction is only used for functions not defined in the current
+  // program unit, so use the location of the procedure designator symbol, which
+  // is the first occurrence of the procedure in the program unit.
+  mlir::Location loc = converter.genLocation(symbol->name());
+  std::optional<Fortran::evaluate::characteristics::Procedure> characteristics =
+      Fortran::evaluate::characteristics::Procedure::Characterize(
+          proc, converter.getFoldingContext());
+  mlir::FunctionType ty = SignatureBuilder{characteristics.value(), converter,
+                                           /*forceImplicit=*/false}
+                              .getFunctionType();
+  mlir::FuncOp newFunc =
+      fir::FirOpBuilder::createFunction(loc, module, name, ty);
+  addSymbolAttribute(newFunc, *symbol, converter.getMLIRContext());
+  return newFunc;
+}
+
+// Is it required to pass a dummy procedure with \p characteristics as a tuple
+// containing the function address and the result length ?
+static bool mustPassLengthWithDummyProcedure(
+    const std::optional<Fortran::evaluate::characteristics::Procedure>
+        &characteristics) {
+  return characteristics &&
+         Fortran::lower::CallInterfaceImpl<SignatureBuilder>::
+             mustPassLengthWithDummyProcedure(*characteristics);
+}
+
+bool Fortran::lower::mustPassLengthWithDummyProcedure(
+    const Fortran::evaluate::ProcedureDesignator &procedure,
+    Fortran::lower::AbstractConverter &converter) {
+  std::optional<Fortran::evaluate::characteristics::Procedure> characteristics =
+      Fortran::evaluate::characteristics::Procedure::Characterize(
+          procedure, converter.getFoldingContext());
+  return ::mustPassLengthWithDummyProcedure(characteristics);
+}
+
+mlir::Type Fortran::lower::getDummyProcedureType(
+    const Fortran::semantics::Symbol &dummyProc,
+    Fortran::lower::AbstractConverter &converter) {
+  std::optional<Fortran::evaluate::characteristics::Procedure> iface =
+      Fortran::evaluate::characteristics::Procedure::Characterize(
+          dummyProc, converter.getFoldingContext());
+  mlir::Type procType = getProcedureDesignatorType(
+      iface.has_value() ? &*iface : nullptr, converter);
+  if (::mustPassLengthWithDummyProcedure(iface))
+    return fir::factory::getCharacterProcedureTupleType(procType);
+  return procType;
+}
index 73889cb..4962da9 100644 (file)
@@ -26,6 +26,7 @@
 #include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/Factory.h"
+#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Semantics/expression.h"
@@ -1116,8 +1117,16 @@ public:
     // will be used only if there is no explicit length in the local interface).
     mlir::Value funcPointer;
     mlir::Value charFuncPointerLength;
-    if (caller.getIfIndirectCallSymbol()) {
-      TODO(loc, "genCallOpAndResult indirect call");
+    if (const Fortran::semantics::Symbol *sym =
+            caller.getIfIndirectCallSymbol()) {
+      funcPointer = symMap.lookupSymbol(*sym).getAddr();
+      if (!funcPointer)
+        fir::emitFatalError(loc, "failed to find indirect call symbol address");
+      if (fir::isCharacterProcedureTuple(funcPointer.getType(),
+                                         /*acceptRawFunc=*/false))
+        std::tie(funcPointer, charFuncPointerLength) =
+            fir::factory::extractCharacterProcedureTuple(builder, loc,
+                                                         funcPointer);
     }
 
     mlir::IndexType idxTy = builder.getIndexType();
@@ -1156,7 +1165,20 @@ public:
       }
 
       if (!extents.empty() || !lengths.empty()) {
-        TODO(loc, "genCallOpResult extents and length");
+        auto *bldr = &converter.getFirOpBuilder();
+        auto stackSaveFn = fir::factory::getLlvmStackSave(builder);
+        auto stackSaveSymbol = bldr->getSymbolRefAttr(stackSaveFn.getName());
+        mlir::Value sp =
+            bldr->create<fir::CallOp>(loc, stackSaveFn.getType().getResults(),
+                                      stackSaveSymbol, mlir::ValueRange{})
+                .getResult(0);
+        stmtCtx.attachCleanup([bldr, loc, sp]() {
+          auto stackRestoreFn = fir::factory::getLlvmStackRestore(*bldr);
+          auto stackRestoreSymbol =
+              bldr->getSymbolRefAttr(stackRestoreFn.getName());
+          bldr->create<fir::CallOp>(loc, stackRestoreFn.getType().getResults(),
+                                    stackRestoreSymbol, mlir::ValueRange{sp});
+        });
       }
       mlir::Value temp =
           builder.createTemporary(loc, type, ".result", extents, resultLengths);
@@ -1302,7 +1324,11 @@ public:
       allocatedResult->match(
           [&](const fir::MutableBoxValue &box) {
             if (box.isAllocatable()) {
-              TODO(loc, "allocatedResult for allocatable");
+              // 9.7.3.2 point 4. Finalize allocatables.
+              fir::FirOpBuilder *bldr = &converter.getFirOpBuilder();
+              stmtCtx.attachCleanup([bldr, loc, box]() {
+                fir::factory::genFinalization(*bldr, loc, box);
+              });
             }
           },
           [](const auto &) {});
index ba2d2e6..029ea16 100644 (file)
@@ -899,7 +899,40 @@ void Fortran::lower::mapSymbolAttributes(
       //===--------------------------------------------------------------===//
 
       [&](const Fortran::lower::details::ScalarDynamicChar &x) {
-        TODO(loc, "ScalarDynamicChar variable lowering");
+        // type is a CHARACTER, determine the LEN value
+        auto charLen = x.charLen();
+        if (replace) {
+          Fortran::lower::SymbolBox symBox = symMap.lookupSymbol(sym);
+          mlir::Value boxAddr = symBox.getAddr();
+          mlir::Value len;
+          mlir::Type addrTy = boxAddr.getType();
+          if (addrTy.isa<fir::BoxCharType>() || addrTy.isa<fir::BoxType>()) {
+            std::tie(boxAddr, len) = charHelp.createUnboxChar(symBox.getAddr());
+          } else {
+            // dummy from an other entry case: we cannot get a dynamic length
+            // for it, it's illegal for the user program to use it. However,
+            // since we are lowering all function unit statements regardless
+            // of whether the execution will reach them or not, we need to
+            // fill a value for the length here.
+            len = builder.createIntegerConstant(
+                loc, builder.getCharacterLengthType(), 1);
+          }
+          // Override LEN with an expression
+          if (charLen)
+            len = genExplicitCharLen(charLen);
+          symMap.addCharSymbol(sym, boxAddr, len, true);
+          return;
+        }
+        // local CHARACTER variable
+        mlir::Value len = genExplicitCharLen(charLen);
+        if (preAlloc) {
+          symMap.addCharSymbol(sym, preAlloc, len);
+          return;
+        }
+        llvm::SmallVector<mlir::Value> lengths = {len};
+        mlir::Value local =
+            createNewLocal(converter, loc, var, preAlloc, llvm::None, lengths);
+        symMap.addCharSymbol(sym, local, len);
       },
 
       //===--------------------------------------------------------------===//
diff --git a/flang/lib/Lower/HostAssociations.cpp b/flang/lib/Lower/HostAssociations.cpp
new file mode 100644 (file)
index 0000000..4c84884
--- /dev/null
@@ -0,0 +1,558 @@
+//===-- HostAssociations.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Lower/HostAssociations.h"
+#include "flang/Evaluate/check-expression.h"
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/Allocatable.h"
+#include "flang/Lower/BoxAnalyzer.h"
+#include "flang/Lower/CallInterface.h"
+#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/PFTBuilder.h"
+#include "flang/Lower/SymbolMap.h"
+#include "flang/Lower/Todo.h"
+#include "flang/Optimizer/Builder/Character.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Support/FatalError.h"
+#include "flang/Semantics/tools.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-host-assoc"
+
+// Host association inside internal procedures is implemented by allocating an
+// mlir tuple (a struct) inside the host containing the addresses and properties
+// of variables that are accessed by internal procedures. The address of this
+// tuple is passed as an argument by the host when calling internal procedures.
+// Internal procedures propagate a reference to this tuple when calling other
+// internal procedures of the host.
+//
+// This file defines how the type of the host tuple is built, how the tuple
+// value is created inside the host, and how the host associated variables are
+// instantiated inside the internal procedures from the tuple value. The
+// CapturedXXX classes define each of these three actions for a specific
+// kind of variables by providing a `getType`, a `instantiateHostTuple`, and a
+// `getFromTuple` method. These classes are structured as follow:
+//
+//   class CapturedKindOfVar : public CapturedSymbols<CapturedKindOfVar> {
+//     // Return the type of the tuple element for a host associated
+//     // variable given its symbol inside the host. This is called when
+//     // building function interfaces.
+//     static mlir::Type getType();
+//     // Build the tuple element value for a host associated variable given its
+//     // value inside the host. This is called when lowering the host body.
+//     static void instantiateHostTuple();
+//     // Instantiate a host variable inside an internal procedure given its
+//     // tuple element value. This is called when lowering internal procedure
+//     // bodies.
+//     static void getFromTuple();
+//   };
+//
+// If a new kind of variable requires ad-hoc handling, a new CapturedXXX class
+// should be added to handle it, and `walkCaptureCategories` should be updated
+// to dispatch this new kind of variable to this new class.
+
+/// Struct to be used as argument in walkCaptureCategories when building the
+/// tuple element type for a host associated variable.
+struct GetTypeInTuple {
+  /// walkCaptureCategories must return a type.
+  using Result = mlir::Type;
+};
+
+/// Struct to be used as argument in walkCaptureCategories when building the
+/// tuple element value for a host associated variable.
+struct InstantiateHostTuple {
+  /// walkCaptureCategories returns nothing.
+  using Result = void;
+  /// Value of the variable inside the host procedure.
+  fir::ExtendedValue hostValue;
+  /// Address of the tuple element of the variable.
+  mlir::Value addrInTuple;
+  mlir::Location loc;
+};
+
+/// Struct to be used as argument in walkCaptureCategories when instantiating a
+/// host associated variables from its tuple element value.
+struct GetFromTuple {
+  /// walkCaptureCategories returns nothing.
+  using Result = void;
+  /// Symbol map inside the internal procedure.
+  Fortran::lower::SymMap &symMap;
+  /// Value of the tuple element for the host associated variable.
+  mlir::Value valueInTuple;
+  mlir::Location loc;
+};
+
+/// Base class that must be inherited with CRTP by classes defining
+/// how host association is implemented for a type of symbol.
+/// It simply dispatches visit() calls to the implementations according
+/// to the argument type.
+template <typename SymbolCategory>
+class CapturedSymbols {
+public:
+  template <typename T>
+  static void visit(const T &, Fortran::lower::AbstractConverter &,
+                    const Fortran::semantics::Symbol &,
+                    const Fortran::lower::BoxAnalyzer &) {
+    static_assert(!std::is_same_v<T, T> &&
+                  "default visit must not be instantiated");
+  }
+  static mlir::Type visit(const GetTypeInTuple &,
+                          Fortran::lower::AbstractConverter &converter,
+                          const Fortran::semantics::Symbol &sym,
+                          const Fortran::lower::BoxAnalyzer &) {
+    return SymbolCategory::getType(converter, sym);
+  }
+  static void visit(const InstantiateHostTuple &args,
+                    Fortran::lower::AbstractConverter &converter,
+                    const Fortran::semantics::Symbol &sym,
+                    const Fortran::lower::BoxAnalyzer &) {
+    return SymbolCategory::instantiateHostTuple(args, converter, sym);
+  }
+  static void visit(const GetFromTuple &args,
+                    Fortran::lower::AbstractConverter &converter,
+                    const Fortran::semantics::Symbol &sym,
+                    const Fortran::lower::BoxAnalyzer &ba) {
+    return SymbolCategory::getFromTuple(args, converter, sym, ba);
+  }
+};
+
+/// Class defining simple scalars are captured in internal procedures.
+/// Simple scalars are non character intrinsic scalars. They are captured
+/// as `!fir.ref<T>`, for example `!fir.ref<i32>` for `INTEGER*4`.
+class CapturedSimpleScalars : public CapturedSymbols<CapturedSimpleScalars> {
+public:
+  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
+                            const Fortran::semantics::Symbol &sym) {
+    return fir::ReferenceType::get(converter.genType(sym));
+  }
+
+  static void instantiateHostTuple(const InstantiateHostTuple &args,
+                                   Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
+    assert(typeInTuple && "addrInTuple must be an address");
+    mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
+                                                fir::getBase(args.hostValue));
+    builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
+  }
+
+  static void getFromTuple(const GetFromTuple &args,
+                           Fortran::lower::AbstractConverter &,
+                           const Fortran::semantics::Symbol &sym,
+                           const Fortran::lower::BoxAnalyzer &) {
+    args.symMap.addSymbol(sym, args.valueInTuple);
+  }
+};
+
+/// Class defining how dummy procedures and procedure pointers
+/// are captured in internal procedures.
+class CapturedProcedure : public CapturedSymbols<CapturedProcedure> {
+public:
+  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
+                            const Fortran::semantics::Symbol &sym) {
+    if (Fortran::semantics::IsPointer(sym))
+      TODO(converter.getCurrentLocation(),
+           "capture procedure pointer in internal procedure");
+    return Fortran::lower::getDummyProcedureType(sym, converter);
+  }
+
+  static void instantiateHostTuple(const InstantiateHostTuple &args,
+                                   Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
+    assert(typeInTuple && "addrInTuple must be an address");
+    mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
+                                                fir::getBase(args.hostValue));
+    builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
+  }
+
+  static void getFromTuple(const GetFromTuple &args,
+                           Fortran::lower::AbstractConverter &,
+                           const Fortran::semantics::Symbol &sym,
+                           const Fortran::lower::BoxAnalyzer &) {
+    args.symMap.addSymbol(sym, args.valueInTuple);
+  }
+};
+
+/// Class defining how character scalars are captured in internal procedures.
+/// Character scalars are passed as !fir.boxchar<kind> in the tuple.
+class CapturedCharacterScalars
+    : public CapturedSymbols<CapturedCharacterScalars> {
+public:
+  // Note: so far, do not specialize constant length characters. They can be
+  // implemented by only passing the address. This could be done later in
+  // lowering or a CapturedStaticLenCharacterScalars class could be added here.
+
+  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
+                            const Fortran::semantics::Symbol &sym) {
+    fir::KindTy kind =
+        converter.genType(sym).cast<fir::CharacterType>().getFKind();
+    return fir::BoxCharType::get(&converter.getMLIRContext(), kind);
+  }
+
+  static void instantiateHostTuple(const InstantiateHostTuple &args,
+                                   Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &) {
+    const fir::CharBoxValue *charBox = args.hostValue.getCharBox();
+    assert(charBox && "host value must be a fir::CharBoxValue");
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Value boxchar = fir::factory::CharacterExprHelper(builder, args.loc)
+                              .createEmbox(*charBox);
+    builder.create<fir::StoreOp>(args.loc, boxchar, args.addrInTuple);
+  }
+
+  static void getFromTuple(const GetFromTuple &args,
+                           Fortran::lower::AbstractConverter &converter,
+                           const Fortran::semantics::Symbol &sym,
+                           const Fortran::lower::BoxAnalyzer &) {
+    fir::factory::CharacterExprHelper charHelp(converter.getFirOpBuilder(),
+                                               args.loc);
+    std::pair<mlir::Value, mlir::Value> unboxchar =
+        charHelp.createUnboxChar(args.valueInTuple);
+    args.symMap.addCharSymbol(sym, unboxchar.first, unboxchar.second);
+  }
+};
+
+/// Is \p sym a derived type entity with length parameters ?
+static bool
+isDerivedWithLengthParameters(const Fortran::semantics::Symbol &sym) {
+  if (const auto *declTy = sym.GetType())
+    if (const auto *derived = declTy->AsDerived())
+      return Fortran::semantics::CountLenParameters(*derived) != 0;
+  return false;
+}
+
+/// Class defining how allocatable and pointers entities are captured in
+/// internal procedures. Allocatable and pointers are simply captured by placing
+/// their !fir.ref<fir.box<>> address in the host tuple.
+class CapturedAllocatableAndPointer
+    : public CapturedSymbols<CapturedAllocatableAndPointer> {
+public:
+  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
+                            const Fortran::semantics::Symbol &sym) {
+    return fir::ReferenceType::get(converter.genType(sym));
+  }
+  static void instantiateHostTuple(const InstantiateHostTuple &args,
+                                   Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &) {
+    assert(args.hostValue.getBoxOf<fir::MutableBoxValue>() &&
+           "host value must be a fir::MutableBoxValue");
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
+    assert(typeInTuple && "addrInTuple must be an address");
+    mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
+                                                fir::getBase(args.hostValue));
+    builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
+  }
+  static void getFromTuple(const GetFromTuple &args,
+                           Fortran::lower::AbstractConverter &converter,
+                           const Fortran::semantics::Symbol &sym,
+                           const Fortran::lower::BoxAnalyzer &ba) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = args.loc;
+    // Non deferred type parameters impact the semantics of some statements
+    // where allocatables/pointer can appear. For instance, assignment to a
+    // scalar character allocatable with has a different semantics in F2003 and
+    // later if the length is non deferred vs when it is deferred. So it is
+    // important to keep track of the non deferred parameters here.
+    llvm::SmallVector<mlir::Value> nonDeferredLenParams;
+    if (ba.isChar()) {
+      mlir::IndexType idxTy = builder.getIndexType();
+      if (llvm::Optional<int64_t> len = ba.getCharLenConst()) {
+        nonDeferredLenParams.push_back(
+            builder.createIntegerConstant(loc, idxTy, *len));
+      } else if (Fortran::semantics::IsAssumedLengthCharacter(sym) ||
+                 ba.getCharLenExpr()) {
+        // Read length from fir.box (explicit expr cannot safely be re-evaluated
+        // here).
+        auto readLength = [&]() {
+          fir::BoxValue boxLoad =
+              builder.create<fir::LoadOp>(loc, fir::getBase(args.valueInTuple))
+                  .getResult();
+          return fir::factory::readCharLen(builder, loc, boxLoad);
+        };
+        if (Fortran::semantics::IsOptional(sym)) {
+          // It is not safe to unconditionally read boxes of optionals in case
+          // they are absents. According to 15.5.2.12 3 (9), it is illegal to
+          // inquire the length of absent optional, even if non deferred, so
+          // it's fine to use undefOp in this case.
+          auto isPresent = builder.create<fir::IsPresentOp>(
+              loc, builder.getI1Type(), fir::getBase(args.valueInTuple));
+          mlir::Value len =
+              builder.genIfOp(loc, {idxTy}, isPresent, true)
+                  .genThen([&]() {
+                    builder.create<fir::ResultOp>(loc, readLength());
+                  })
+                  .genElse([&]() {
+                    auto undef = builder.create<fir::UndefOp>(loc, idxTy);
+                    builder.create<fir::ResultOp>(loc, undef.getResult());
+                  })
+                  .getResults()[0];
+          nonDeferredLenParams.push_back(len);
+        } else {
+          nonDeferredLenParams.push_back(readLength());
+        }
+      }
+    } else if (isDerivedWithLengthParameters(sym)) {
+      TODO(loc, "host associated derived type allocatable or pointer with "
+                "length parameters");
+    }
+    args.symMap.addSymbol(
+        sym, fir::MutableBoxValue(args.valueInTuple, nonDeferredLenParams, {}));
+  }
+};
+
+/// Class defining how arrays are captured inside internal procedures.
+/// Array are captured via a `fir.box<fir.array<T>>` descriptor that belongs to
+/// the host tuple. This allows capturing lower bounds, which can be done by
+/// providing a ShapeShiftOp argument to the EmboxOp.
+class CapturedArrays : public CapturedSymbols<CapturedArrays> {
+
+  // Note: Constant shape arrays are not specialized (their base address would
+  // be sufficient information inside the tuple). They could be specialized in
+  // a later FIR pass, or a CapturedStaticShapeArrays could be added to deal
+  // with them here.
+public:
+  static mlir::Type getType(Fortran::lower::AbstractConverter &converter,
+                            const Fortran::semantics::Symbol &sym) {
+    mlir::Type type = converter.genType(sym);
+    assert(type.isa<fir::SequenceType>() && "must be a sequence type");
+    return fir::BoxType::get(type);
+  }
+
+  static void instantiateHostTuple(const InstantiateHostTuple &args,
+                                   Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &sym) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = args.loc;
+    fir::MutableBoxValue boxInTuple(args.addrInTuple, {}, {});
+    if (args.hostValue.getBoxOf<fir::BoxValue>() &&
+        Fortran::semantics::IsOptional(sym)) {
+      // The assumed shape optional case need some care because it is illegal to
+      // read the incoming box if it is absent (this would cause segfaults).
+      // Pointer association requires reading the target box, so it can only be
+      // done on present optional. For absent optionals, simply create a
+      // disassociated pointer (it is illegal to inquire about lower bounds or
+      // lengths of optional according to 15.5.2.12 3 (9) and 10.1.11 2 (7)b).
+      auto isPresent = builder.create<fir::IsPresentOp>(
+          loc, builder.getI1Type(), fir::getBase(args.hostValue));
+      builder.genIfThenElse(loc, isPresent)
+          .genThen([&]() {
+            fir::factory::associateMutableBox(builder, loc, boxInTuple,
+                                              args.hostValue,
+                                              /*lbounds=*/llvm::None);
+          })
+          .genElse([&]() {
+            fir::factory::disassociateMutableBox(builder, loc, boxInTuple);
+          })
+          .end();
+    } else {
+      fir::factory::associateMutableBox(builder, loc, boxInTuple,
+                                        args.hostValue, /*lbounds=*/llvm::None);
+    }
+  }
+
+  static void getFromTuple(const GetFromTuple &args,
+                           Fortran::lower::AbstractConverter &converter,
+                           const Fortran::semantics::Symbol &sym,
+                           const Fortran::lower::BoxAnalyzer &ba) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    mlir::Location loc = args.loc;
+    mlir::Value box = args.valueInTuple;
+    mlir::IndexType idxTy = builder.getIndexType();
+    llvm::SmallVector<mlir::Value> lbounds;
+    if (!ba.lboundIsAllOnes()) {
+      if (ba.isStaticArray()) {
+        for (std::int64_t lb : ba.staticLBound())
+          lbounds.emplace_back(builder.createIntegerConstant(loc, idxTy, lb));
+      } else {
+        // Cannot re-evaluate specification expressions here.
+        // Operands values may have changed. Get value from fir.box
+        const unsigned rank = sym.Rank();
+        for (unsigned dim = 0; dim < rank; ++dim) {
+          mlir::Value dimVal = builder.createIntegerConstant(loc, idxTy, dim);
+          auto dims = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
+                                                     box, dimVal);
+          lbounds.emplace_back(dims.getResult(0));
+        }
+      }
+    }
+
+    if (canReadCapturedBoxValue(converter, sym)) {
+      fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/llvm::None);
+      args.symMap.addSymbol(sym,
+                            fir::factory::readBoxValue(builder, loc, boxValue));
+    } else {
+      // Keep variable as a fir.box.
+      // If this is an optional that is absent, the fir.box needs to be an
+      // AbsentOp result, otherwise it will not work properly with IsPresentOp
+      // (absent boxes are null descriptor addresses, not descriptors containing
+      // a null base address).
+      if (Fortran::semantics::IsOptional(sym)) {
+        auto boxTy = box.getType().cast<fir::BoxType>();
+        auto eleTy = boxTy.getEleTy();
+        if (!fir::isa_ref_type(eleTy))
+          eleTy = builder.getRefType(eleTy);
+        auto addr = builder.create<fir::BoxAddrOp>(loc, eleTy, box);
+        mlir::Value isPresent = builder.genIsNotNull(loc, addr);
+        auto absentBox = builder.create<fir::AbsentOp>(loc, boxTy);
+        box = builder.create<mlir::arith::SelectOp>(loc, isPresent, box,
+                                                    absentBox);
+      }
+      fir::BoxValue boxValue(box, lbounds, /*explicitParams=*/llvm::None);
+      args.symMap.addSymbol(sym, boxValue);
+    }
+  }
+
+private:
+  /// Can the fir.box from the host link be read into simpler values ?
+  /// Later, without the symbol information, it might not be possible
+  /// to tell if the fir::BoxValue from the host link is contiguous.
+  static bool
+  canReadCapturedBoxValue(Fortran::lower::AbstractConverter &converter,
+                          const Fortran::semantics::Symbol &sym) {
+    bool isScalarOrContiguous =
+        sym.Rank() == 0 || Fortran::evaluate::IsSimplyContiguous(
+                               Fortran::evaluate::AsGenericExpr(sym).value(),
+                               converter.getFoldingContext());
+    const Fortran::semantics::DeclTypeSpec *type = sym.GetType();
+    bool isPolymorphic = type && type->IsPolymorphic();
+    return isScalarOrContiguous && !isPolymorphic &&
+           !isDerivedWithLengthParameters(sym);
+  }
+};
+
+/// Dispatch \p visitor to the CapturedSymbols which is handling how host
+/// association is implemented for this kind of symbols. This ensures the same
+/// dispatch decision is taken when building the tuple type, when creating the
+/// tuple, and when instantiating host associated variables from it.
+template <typename T>
+typename T::Result
+walkCaptureCategories(T visitor, Fortran::lower::AbstractConverter &converter,
+                      const Fortran::semantics::Symbol &sym) {
+  if (isDerivedWithLengthParameters(sym))
+    // Should be boxed.
+    TODO(converter.genLocation(sym.name()),
+         "host associated derived type with length parameters");
+  Fortran::lower::BoxAnalyzer ba;
+  // Do not analyze procedures, they may be subroutines with no types that would
+  // crash the analysis.
+  if (Fortran::semantics::IsProcedure(sym))
+    return CapturedProcedure::visit(visitor, converter, sym, ba);
+  ba.analyze(sym);
+  if (Fortran::evaluate::IsAllocatableOrPointer(sym))
+    return CapturedAllocatableAndPointer::visit(visitor, converter, sym, ba);
+  if (ba.isArray())
+    return CapturedArrays::visit(visitor, converter, sym, ba);
+  if (ba.isChar())
+    return CapturedCharacterScalars::visit(visitor, converter, sym, ba);
+  assert(ba.isTrivial() && "must be trivial scalar");
+  return CapturedSimpleScalars::visit(visitor, converter, sym, ba);
+}
+
+// `t` should be the result of getArgumentType, which has a type of
+// `!fir.ref<tuple<...>>`.
+static mlir::TupleType unwrapTupleTy(mlir::Type t) {
+  return fir::dyn_cast_ptrEleTy(t).cast<mlir::TupleType>();
+}
+
+static mlir::Value genTupleCoor(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Type varTy, mlir::Value tupleArg,
+                                mlir::Value offset) {
+  // fir.ref<fir.ref> and fir.ptr<fir.ref> are forbidden. Use
+  // fir.llvm_ptr if needed.
+  auto ty = varTy.isa<fir::ReferenceType>()
+                ? mlir::Type(fir::LLVMPointerType::get(varTy))
+                : mlir::Type(builder.getRefType(varTy));
+  return builder.create<fir::CoordinateOp>(loc, ty, tupleArg, offset);
+}
+
+void Fortran::lower::HostAssociations::hostProcedureBindings(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::SymMap &symMap) {
+  if (symbols.empty())
+    return;
+
+  // Create the tuple variable.
+  mlir::TupleType tupTy = unwrapTupleTy(getArgumentType(converter));
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::Location loc = converter.getCurrentLocation();
+  auto hostTuple = builder.create<fir::AllocaOp>(loc, tupTy);
+  mlir::IntegerType offTy = builder.getIntegerType(32);
+
+  // Walk the list of symbols and update the pointers in the tuple.
+  for (auto s : llvm::enumerate(symbols)) {
+    auto indexInTuple = s.index();
+    mlir::Value off = builder.createIntegerConstant(loc, offTy, indexInTuple);
+    mlir::Type varTy = tupTy.getType(indexInTuple);
+    mlir::Value eleOff = genTupleCoor(builder, loc, varTy, hostTuple, off);
+    InstantiateHostTuple instantiateHostTuple{
+        symMap.lookupSymbol(s.value()).toExtendedValue(), eleOff, loc};
+    walkCaptureCategories(instantiateHostTuple, converter, *s.value());
+  }
+
+  converter.bindHostAssocTuple(hostTuple);
+}
+
+void Fortran::lower::HostAssociations::internalProcedureBindings(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::SymMap &symMap) {
+  if (symbols.empty())
+    return;
+
+  // Find the argument with the tuple type. The argument ought to be appended.
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::Type argTy = getArgumentType(converter);
+  mlir::TupleType tupTy = unwrapTupleTy(argTy);
+  mlir::Location loc = converter.getCurrentLocation();
+  mlir::FuncOp func = builder.getFunction();
+  mlir::Value tupleArg;
+  for (auto [ty, arg] : llvm::reverse(
+           llvm::zip(func.getType().getInputs(), func.front().getArguments())))
+    if (ty == argTy) {
+      tupleArg = arg;
+      break;
+    }
+  if (!tupleArg)
+    fir::emitFatalError(loc, "no host association argument found");
+
+  converter.bindHostAssocTuple(tupleArg);
+
+  mlir::IntegerType offTy = builder.getIntegerType(32);
+
+  // Walk the list and add the bindings to the symbol table.
+  for (auto s : llvm::enumerate(symbols)) {
+    mlir::Value off = builder.createIntegerConstant(loc, offTy, s.index());
+    mlir::Type varTy = tupTy.getType(s.index());
+    mlir::Value eleOff = genTupleCoor(builder, loc, varTy, tupleArg, off);
+    mlir::Value valueInTuple = builder.create<fir::LoadOp>(loc, eleOff);
+    GetFromTuple getFromTuple{symMap, valueInTuple, loc};
+    walkCaptureCategories(getFromTuple, converter, *s.value());
+  }
+}
+
+mlir::Type Fortran::lower::HostAssociations::getArgumentType(
+    Fortran::lower::AbstractConverter &converter) {
+  if (symbols.empty())
+    return {};
+  if (argType)
+    return argType;
+
+  // Walk the list of Symbols and create their types. Wrap them in a reference
+  // to a tuple.
+  mlir::MLIRContext *ctxt = &converter.getMLIRContext();
+  llvm::SmallVector<mlir::Type> tupleTys;
+  for (const Fortran::semantics::Symbol *sym : symbols)
+    tupleTys.emplace_back(
+        walkCaptureCategories(GetTypeInTuple{}, converter, *sym));
+  argType = fir::ReferenceType::get(mlir::TupleType::get(ctxt, tupleTys));
+  return argType;
+}
index 0b713ec..3679507 100644 (file)
@@ -187,15 +187,13 @@ llvm::raw_ostream &fir::operator<<(llvm::raw_ostream &os,
 /// always be called, so it should not have any functional side effects,
 /// the const is here to enforce that.
 bool fir::MutableBoxValue::verify() const {
-  auto type = fir::dyn_cast_ptrEleTy(getAddr().getType());
+  mlir::Type type = fir::dyn_cast_ptrEleTy(getAddr().getType());
   if (!type)
     return false;
   auto box = type.dyn_cast<fir::BoxType>();
   if (!box)
     return false;
-  auto eleTy = box.getEleTy();
-  if (!eleTy.isa<fir::PointerType>() && !eleTy.isa<fir::HeapType>())
-    return false;
+  // A boxed value always takes a memory reference,
 
   auto nParams = lenParams.size();
   if (isCharacter()) {
index df2f526..d783961 100644 (file)
@@ -6,6 +6,7 @@ add_flang_library(FIRBuilder
   Complex.cpp
   DoLoopHelper.cpp
   FIRBuilder.cpp
+  LowLevelIntrinsics.cpp
   MutableBox.cpp
   Runtime/Assign.cpp
   Runtime/Character.cpp
diff --git a/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp b/flang/lib/Optimizer/Builder/LowLevelIntrinsics.cpp
new file mode 100644 (file)
index 0000000..f95a4fd
--- /dev/null
@@ -0,0 +1,38 @@
+//===-- LowLevelIntrinsics.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+//
+// Low level intrinsic functions.
+//
+// These include LLVM intrinsic calls and standard C library calls.
+// Target-specific calls, such as OS functions, should be factored in other
+// file(s).
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+
+mlir::FuncOp fir::factory::getLlvmStackSave(fir::FirOpBuilder &builder) {
+  auto ptrTy = builder.getRefType(builder.getIntegerType(8));
+  auto funcTy =
+      mlir::FunctionType::get(builder.getContext(), llvm::None, {ptrTy});
+  return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.stacksave",
+                                  funcTy);
+}
+
+mlir::FuncOp fir::factory::getLlvmStackRestore(fir::FirOpBuilder &builder) {
+  auto ptrTy = builder.getRefType(builder.getIntegerType(8));
+  auto funcTy =
+      mlir::FunctionType::get(builder.getContext(), {ptrTy}, llvm::None);
+  return builder.addNamedFunction(builder.getUnknownLoc(), "llvm.stackrestore",
+                                  funcTy);
+}
index 93b94f0..60234fc 100644 (file)
@@ -857,6 +857,14 @@ bool fir::VectorType::isValidElementType(mlir::Type t) {
   return isa_real(t) || isa_integer(t);
 }
 
+bool fir::isCharacterProcedureTuple(mlir::Type ty, bool acceptRawFunc) {
+  mlir::TupleType tuple = ty.dyn_cast<mlir::TupleType>();
+  return tuple && tuple.size() == 2 &&
+         (tuple.getType(0).isa<fir::BoxProcType>() ||
+          (acceptRawFunc && tuple.getType(0).isa<mlir::FunctionType>())) &&
+         fir::isa_integer(tuple.getType(1));
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
new file mode 100644 (file)
index 0000000..17aeba1
--- /dev/null
@@ -0,0 +1,106 @@
+! Test internal procedure host association lowering.
+! RUN: bbc %s -o - -emit-fir | FileCheck %s
+
+! -----------------------------------------------------------------------------
+!     Test non character intrinsic scalars
+! -----------------------------------------------------------------------------
+
+!!! Test scalar (with implicit none)
+
+! CHECK-LABEL: func @_QPtest1(
+subroutine test1
+  implicit none
+  integer i
+  ! CHECK-DAG: %[[i:.*]] = fir.alloca i32 {{.*}}uniq_name = "_QFtest1Ei"
+  ! CHECK-DAG: %[[tup:.*]] = fir.alloca tuple<!fir.ref<i32>>
+  ! CHECK: %[[addr:.*]] = fir.coordinate_of %[[tup]], %c0
+  ! CHECK: fir.store %[[i]] to %[[addr]] : !fir.llvm_ptr<!fir.ref<i32>>
+  ! CHECK: fir.call @_QFtest1Ptest1_internal(%[[tup]]) : (!fir.ref<tuple<!fir.ref<i32>>>) -> ()
+  call test1_internal
+  print *, i
+contains
+  ! CHECK-LABEL: func @_QFtest1Ptest1_internal(
+  ! CHECK-SAME: %[[arg:[^:]*]]: !fir.ref<tuple<!fir.ref<i32>>> {fir.host_assoc}) {
+  ! CHECK: %[[iaddr:.*]] = fir.coordinate_of %[[arg]], %c0
+  ! CHECK: %[[i:.*]] = fir.load %[[iaddr]] : !fir.llvm_ptr<!fir.ref<i32>>
+  ! CHECK: %[[val:.*]] = fir.call @_QPifoo() : () -> i32
+  ! CHECK: fir.store %[[val]] to %[[i]] : !fir.ref<i32>
+  subroutine test1_internal
+    integer, external :: ifoo
+    i = ifoo()
+  end subroutine test1_internal
+end subroutine test1
+
+!!! Test scalar
+
+! CHECK-LABEL: func @_QPtest2() {
+subroutine test2
+  a = 1.0
+  b = 2.0
+  ! CHECK: %[[tup:.*]] = fir.alloca tuple<!fir.ref<f32>, !fir.ref<f32>>
+  ! CHECK: %[[a0:.*]] = fir.coordinate_of %[[tup]], %c0
+  ! CHECK: fir.store %{{.*}} to %[[a0]] : !fir.llvm_ptr<!fir.ref<f32>>
+  ! CHECK: %[[b0:.*]] = fir.coordinate_of %[[tup]], %c1
+  ! CHECK: fir.store %{{.*}} to %[[b0]] : !fir.llvm_ptr<!fir.ref<f32>>
+  ! CHECK: fir.call @_QFtest2Ptest2_internal(%[[tup]]) : (!fir.ref<tuple<!fir.ref<f32>, !fir.ref<f32>>>) -> ()
+  call test2_internal
+  print *, a, b
+contains
+  ! CHECK-LABEL: func @_QFtest2Ptest2_internal(
+  ! CHECK-SAME: %[[arg:[^:]*]]: !fir.ref<tuple<!fir.ref<f32>, !fir.ref<f32>>> {fir.host_assoc}) {
+  subroutine test2_internal
+    ! CHECK: %[[a:.*]] = fir.coordinate_of %[[arg]], %c0
+    ! CHECK: %[[aa:.*]] = fir.load %[[a]] : !fir.llvm_ptr<!fir.ref<f32>>
+    ! CHECK: %[[b:.*]] = fir.coordinate_of %[[arg]], %c1
+    ! CHECK: %{{.*}} = fir.load %[[b]] : !fir.llvm_ptr<!fir.ref<f32>>
+    ! CHECK: fir.alloca
+    ! CHECK: fir.load %[[aa]] : !fir.ref<f32>
+    c = a
+    a = b
+    b = c
+    call test2_inner
+  end subroutine test2_internal
+
+  ! CHECK-LABEL: func @_QFtest2Ptest2_inner(
+  ! CHECK-SAME: %[[arg:[^:]*]]: !fir.ref<tuple<!fir.ref<f32>, !fir.ref<f32>>> {fir.host_assoc}) {
+  subroutine test2_inner
+    ! CHECK: %[[a:.*]] = fir.coordinate_of %[[arg]], %c0
+    ! CHECK: %[[aa:.*]] = fir.load %[[a]] : !fir.llvm_ptr<!fir.ref<f32>>
+    ! CHECK: %[[b:.*]] = fir.coordinate_of %[[arg]], %c1
+    ! CHECK: %[[bb:.*]] = fir.load %[[b]] : !fir.llvm_ptr<!fir.ref<f32>>
+    ! CHECK-DAG: %[[bd:.*]] = fir.load %[[bb]] : !fir.ref<f32>
+    ! CHECK-DAG: %[[ad:.*]] = fir.load %[[aa]] : !fir.ref<f32>
+    ! CHECK: %{{.*}} = arith.cmpf ogt, %[[ad]], %[[bd]] : f32
+    if (a > b) then
+       b = b + 2.0
+    end if
+  end subroutine test2_inner
+end subroutine test2
+
+! -----------------------------------------------------------------------------
+!     Test non character scalars
+! -----------------------------------------------------------------------------
+
+! CHECK-LABEL: func @_QPtest6(
+! CHECK-SAME: %[[c:.*]]: !fir.boxchar<1>
+subroutine test6(c)
+  character(*) :: c
+  ! CHECK: %[[cunbox:.*]]:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+  ! CHECK: %[[tup:.*]] = fir.alloca tuple<!fir.boxchar<1>>
+  ! CHECK: %[[coor:.*]] = fir.coordinate_of %[[tup]], %c0{{.*}} : (!fir.ref<tuple<!fir.boxchar<1>>>, i32) -> !fir.ref<!fir.boxchar<1>>
+  ! CHECK: %[[emboxchar:.*]] = fir.emboxchar %[[cunbox]]#0, %[[cunbox]]#1 : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
+  ! CHECK: fir.store %[[emboxchar]] to %[[coor]] : !fir.ref<!fir.boxchar<1>>
+  ! CHECK: fir.call @_QFtest6Ptest6_inner(%[[tup]]) : (!fir.ref<tuple<!fir.boxchar<1>>>) -> ()
+  call test6_inner
+  print *, c
+
+contains
+  ! CHECK-LABEL: func @_QFtest6Ptest6_inner(
+  ! CHECK-SAME: %[[tup:.*]]: !fir.ref<tuple<!fir.boxchar<1>>> {fir.host_assoc}) {
+  subroutine test6_inner
+    ! CHECK: %[[coor:.*]] = fir.coordinate_of %[[tup]], %c0{{.*}} : (!fir.ref<tuple<!fir.boxchar<1>>>, i32) -> !fir.ref<!fir.boxchar<1>>
+    ! CHECK: %[[load:.*]] = fir.load %[[coor]] : !fir.ref<!fir.boxchar<1>>
+    ! CHECK: fir.unboxchar %[[load]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+    c = "Hi there"
+  end subroutine test6_inner
+end subroutine test6