[flang] Handle lowering arguments in subroutine and function
authorValentin Clement <clementval@gmail.com>
Wed, 16 Feb 2022 19:27:23 +0000 (20:27 +0100)
committerValentin Clement <clementval@gmail.com>
Wed, 16 Feb 2022 19:28:07 +0000 (20:28 +0100)
This patch adds infrsatrcutrue to be able to lower
arguments in functions and subroutines.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D119957

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
flang/include/flang/Lower/CallInterface.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/CallInterface.cpp
flang/lib/Lower/ConvertVariable.cpp
flang/test/Lower/arguments.f90 [new file with mode: 0644]

index a8f08ac..896fde8 100644 (file)
@@ -85,6 +85,26 @@ class CallInterface {
   friend CallInterfaceImpl<T>;
 
 public:
+  /// Enum the different ways an entity can be passed-by
+  enum class PassEntityBy {
+    BaseAddress,
+    BoxChar,
+    // passing a read-only descriptor
+    Box,
+    // passing a writable descriptor
+    MutableBox,
+    AddressAndLength,
+    /// Value means passed by value at the mlir level, it is not necessarily
+    /// implied by Fortran Value attribute.
+    Value,
+    /// ValueAttribute means dummy has the the Fortran VALUE attribute.
+    BaseAddressValueAttribute,
+    CharBoxValueAttribute, // BoxChar with VALUE
+    // Passing a character procedure as a <procedure address, result length>
+    // tuple.
+    CharProcTuple
+  };
+
   /// Different properties of an entity that can be passed/returned.
   /// One-to-One mapping with PassEntityBy but for
   /// PassEntityBy::AddressAndLength that has two properties.
@@ -105,8 +125,10 @@ public:
   /// FirPlaceHolder are place holders for the mlir inputs and outputs that are
   /// created during the first pass before the mlir::FuncOp is created.
   struct FirPlaceHolder {
-    FirPlaceHolder(mlir::Type t, int passedPosition, Property p)
-        : type{t}, passedEntityPosition{passedPosition}, property{p} {}
+    FirPlaceHolder(mlir::Type t, int passedPosition, Property p,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs)
+        : type{t}, passedEntityPosition{passedPosition}, property{p},
+          attributes{attrs.begin(), attrs.end()} {}
     /// Type for this input/output
     mlir::Type type;
     /// Position of related passedEntity in passedArguments.
@@ -116,8 +138,41 @@ public:
     /// Indicate property of the entity passedEntityPosition that must be passed
     /// through this argument.
     Property property;
+    /// MLIR attributes for this argument
+    llvm::SmallVector<mlir::NamedAttribute> attributes;
   };
 
+  /// PassedEntity is what is provided back to the CallInterface user.
+  /// It describe how the entity is plugged in the interface
+  struct PassedEntity {
+    /// Is the dummy argument optional ?
+    bool isOptional() const;
+    /// Can the argument be modified by the callee ?
+    bool mayBeModifiedByCall() const;
+    /// Can the argument be read by the callee ?
+    bool mayBeReadByCall() const;
+    /// How entity is passed by.
+    PassEntityBy passBy;
+    /// What is the entity (SymbolRef for callee/ActualArgument* for caller)
+    /// What is the related mlir::FuncOp argument(s) (mlir::Value for callee /
+    /// index for the caller).
+    FortranEntity entity;
+    FirValue firArgument;
+    FirValue firLength; /* only for AddressAndLength */
+
+    /// Pointer to the argument characteristics. Nullptr for results.
+    const Fortran::evaluate::characteristics::DummyArgument *characteristics =
+        nullptr;
+  };
+
+  /// Return a container of Symbol/ActualArgument* and how they must
+  /// be plugged with the mlir::FuncOp.
+  llvm::ArrayRef<PassedEntity> getPassedArguments() const {
+    return passedArguments;
+  }
+  /// In case the result must be passed by the caller, indicate how.
+  /// nullopt if the result is not passed by the caller.
+  std::optional<PassedEntity> getPassedResult() const { return passedResult; }
   /// Returns the mlir function type
   mlir::FunctionType genFunctionType();
 
@@ -134,9 +189,16 @@ protected:
   /// Entry point to be called by child ctor to analyze the signature and
   /// create/find the mlir::FuncOp. Child needs to be initialized first.
   void declare();
+  /// Second pass entry point, once the mlir::FuncOp is created.
+  /// Nothing is done if it was already called.
+  void mapPassedEntities();
+  void mapBackInputToPassedEntity(const FirPlaceHolder &, FirValue);
 
   llvm::SmallVector<FirPlaceHolder> outputs;
+  llvm::SmallVector<FirPlaceHolder> inputs;
   mlir::FuncOp func;
+  llvm::SmallVector<PassedEntity> passedArguments;
+  std::optional<PassedEntity> passedResult;
 
   Fortran::lower::AbstractConverter &converter;
   /// Store characteristic once created, it is required for further information
@@ -165,6 +227,10 @@ public:
   Fortran::evaluate::characteristics::Procedure characterize() const;
   bool isMainProgram() const;
 
+  Fortran::lower::pft::FunctionLikeUnit &getCallDescription() const {
+    return funit;
+  }
+
   /// On the callee side it does not matter whether the procedure is
   /// called through pointers or not.
   bool isIndirectCall() const { return false; }
index 6e7f56c..cfb326c 100644 (file)
@@ -227,6 +227,59 @@ public:
     localSymbols.clear();
   }
 
+  /// Map mlir function block arguments to the corresponding Fortran dummy
+  /// variables. When the result is passed as a hidden argument, the Fortran
+  /// result is also mapped. The symbol map is used to hold this mapping.
+  void mapDummiesAndResults(Fortran::lower::pft::FunctionLikeUnit &funit,
+                            const Fortran::lower::CalleeInterface &callee) {
+    assert(builder && "require a builder object at this point");
+    using PassBy = Fortran::lower::CalleeInterface::PassEntityBy;
+    auto mapPassedEntity = [&](const auto arg) -> void {
+      if (arg.passBy == PassBy::AddressAndLength) {
+        // // TODO: now that fir call has some attributes regarding character
+        // // return, PassBy::AddressAndLength should be retired.
+        // mlir::Location loc = toLocation();
+        // fir::factory::CharacterExprHelper charHelp{*builder, loc};
+        // mlir::Value box =
+        //     charHelp.createEmboxChar(arg.firArgument, arg.firLength);
+        // addSymbol(arg.entity->get(), box);
+      } else {
+        if (arg.entity.has_value()) {
+          addSymbol(arg.entity->get(), arg.firArgument);
+        } else {
+          // assert(funit.parentHasHostAssoc());
+          // funit.parentHostAssoc().internalProcedureBindings(*this,
+          //                                                   localSymbols);
+        }
+      }
+    };
+    for (const Fortran::lower::CalleeInterface::PassedEntity &arg :
+         callee.getPassedArguments())
+      mapPassedEntity(arg);
+
+    // Allocate local skeleton instances of dummies from other entry points.
+    // Most of these locals will not survive into final generated code, but
+    // some will.  It is illegal to reference them at run time if they do.
+    for (const Fortran::semantics::Symbol *arg :
+         funit.nonUniversalDummyArguments) {
+      if (lookupSymbol(*arg))
+        continue;
+      mlir::Type type = genType(*arg);
+      // TODO: Account for VALUE arguments (and possibly other variants).
+      type = builder->getRefType(type);
+      addSymbol(*arg, builder->create<fir::UndefOp>(toLocation(), type));
+    }
+    if (std::optional<Fortran::lower::CalleeInterface::PassedEntity>
+            passedResult = callee.getPassedResult()) {
+      mapPassedEntity(*passedResult);
+      // FIXME: need to make sure things are OK here. addSymbol may not be OK
+      if (funit.primaryResult &&
+          passedResult->entity->get() != *funit.primaryResult)
+        addSymbol(*funit.primaryResult,
+                  getSymbolAddress(passedResult->entity->get()));
+    }
+  }
+
   /// Instantiate variable \p var and add it to the symbol map.
   /// See ConvertVariable.cpp.
   void instantiateVar(const Fortran::lower::pft::Variable &var) {
@@ -243,6 +296,8 @@ public:
     assert(builder && "FirOpBuilder did not instantiate");
     builder->setInsertionPointToStart(&func.front());
 
+    mapDummiesAndResults(funit, callee);
+
     for (const Fortran::lower::pft::Variable &var :
          funit.getOrderedSymbolTable()) {
       const Fortran::semantics::Symbol &sym = var.getSymbol();
@@ -319,6 +374,17 @@ private:
     return {};
   }
 
+  /// Add the symbol to the local map and return `true`. If the symbol is
+  /// already in the map and \p forced is `false`, the map is not updated.
+  /// Instead the value `false` is returned.
+  bool addSymbol(const Fortran::semantics::SymbolRef sym, mlir::Value val,
+                 bool forced = false) {
+    if (!forced && lookupSymbol(sym))
+      return false;
+    localSymbols.addSymbol(sym, val, forced);
+    return true;
+  }
+
   void genFIRBranch(mlir::Block *targetBlock) {
     assert(targetBlock && "missing unconditional target block");
     builder->create<cf::BranchOp>(toLocation(), targetBlock);
index 8bf110c..93c8f02 100644 (file)
@@ -77,6 +77,7 @@ mlir::FuncOp Fortran::lower::CalleeInterface::addEntryBlockAndMapArguments() {
   // On the callee side, directly map the mlir::value argument of
   // the function block to the Fortran symbols.
   func.addEntryBlock();
+  mapPassedEntities();
   return func;
 }
 
@@ -122,10 +123,58 @@ void Fortran::lower::CallInterface<T>::declare() {
       func = fir::FirOpBuilder::createFunction(loc, module, name, ty);
       if (const Fortran::semantics::Symbol *sym = side().getProcedureSymbol())
         addSymbolAttribute(func, *sym, converter.getMLIRContext());
+      for (const auto &placeHolder : llvm::enumerate(inputs))
+        if (!placeHolder.value().attributes.empty())
+          func.setArgAttrs(placeHolder.index(), placeHolder.value().attributes);
     }
   }
 }
 
+/// Once the signature has been analyzed and the mlir::FuncOp was built/found,
+/// map the fir inputs to Fortran entities (the symbols or expressions).
+template <typename T>
+void Fortran::lower::CallInterface<T>::mapPassedEntities() {
+  // map back fir inputs to passed entities
+  if constexpr (std::is_same_v<T, Fortran::lower::CalleeInterface>) {
+    assert(inputs.size() == func.front().getArguments().size() &&
+           "function previously created with different number of arguments");
+    for (auto [fst, snd] : llvm::zip(inputs, func.front().getArguments()))
+      mapBackInputToPassedEntity(fst, snd);
+  } else {
+    // On the caller side, map the index of the mlir argument position
+    // to Fortran ActualArguments.
+    int firPosition = 0;
+    for (const FirPlaceHolder &placeHolder : inputs)
+      mapBackInputToPassedEntity(placeHolder, firPosition++);
+  }
+}
+
+template <typename T>
+void Fortran::lower::CallInterface<T>::mapBackInputToPassedEntity(
+    const FirPlaceHolder &placeHolder, FirValue firValue) {
+  PassedEntity &passedEntity =
+      placeHolder.passedEntityPosition == FirPlaceHolder::resultEntityPosition
+          ? passedResult.value()
+          : passedArguments[placeHolder.passedEntityPosition];
+  if (placeHolder.property == Property::CharLength)
+    passedEntity.firLength = firValue;
+  else
+    passedEntity.firArgument = firValue;
+}
+
+static const std::vector<Fortran::semantics::Symbol *> &
+getEntityContainer(Fortran::lower::pft::FunctionLikeUnit &funit) {
+  return funit.getSubprogramSymbol()
+      .get<Fortran::semantics::SubprogramDetails>()
+      .dummyArgs();
+}
+
+static const Fortran::semantics::Symbol &
+getDataObjectEntity(const Fortran::semantics::Symbol *arg) {
+  assert(arg && "expect symbol for data object entity");
+  return *arg;
+}
+
 //===----------------------------------------------------------------------===//
 // CallInterface implementation: this part is common to both caller and caller
 // sides.
@@ -136,9 +185,14 @@ void Fortran::lower::CallInterface<T>::declare() {
 template <typename T>
 class Fortran::lower::CallInterfaceImpl {
   using CallInterface = Fortran::lower::CallInterface<T>;
+  using PassEntityBy = typename CallInterface::PassEntityBy;
+  using PassedEntity = typename CallInterface::PassedEntity;
+  using FortranEntity = typename CallInterface::FortranEntity;
   using FirPlaceHolder = typename CallInterface::FirPlaceHolder;
   using Property = typename CallInterface::Property;
   using TypeAndShape = Fortran::evaluate::characteristics::TypeAndShape;
+  using DummyCharacteristics =
+      Fortran::evaluate::characteristics::DummyArgument;
 
 public:
   CallInterfaceImpl(CallInterface &i)
@@ -153,6 +207,24 @@ public:
     else if (interface.side().hasAlternateReturns())
       addFirResult(mlir::IndexType::get(&mlirContext),
                    FirPlaceHolder::resultEntityPosition, Property::Value);
+    // Handle arguments
+    const auto &argumentEntities =
+        getEntityContainer(interface.side().getCallDescription());
+    for (auto pair : llvm::zip(procedure.dummyArguments, argumentEntities)) {
+      const Fortran::evaluate::characteristics::DummyArgument
+          &argCharacteristics = std::get<0>(pair);
+      std::visit(
+          Fortran::common::visitors{
+              [&](const auto &dummy) {
+                const auto &entity = getDataObjectEntity(std::get<1>(pair));
+                handleImplicitDummy(&argCharacteristics, dummy, entity);
+              },
+              [&](const Fortran::evaluate::characteristics::AlternateReturn &) {
+                // nothing to do
+              },
+          },
+          argCharacteristics.u);
+    }
   }
 
   void buildExplicitInterface(
@@ -248,9 +320,78 @@ private:
         getConverter().getFoldingContext(), std::move(expr)));
   }
 
-  void addFirResult(mlir::Type type, int entityPosition, Property p) {
-    interface.outputs.emplace_back(FirPlaceHolder{type, entityPosition, p});
+  /// Return a vector with an attribute with the name of the argument if this
+  /// is a callee interface and the name is available. Otherwise, just return
+  /// an empty vector.
+  llvm::SmallVector<mlir::NamedAttribute>
+  dummyNameAttr(const FortranEntity &entity) {
+    if constexpr (std::is_same_v<FortranEntity,
+                                 std::optional<Fortran::common::Reference<
+                                     const Fortran::semantics::Symbol>>>) {
+      if (entity.has_value()) {
+        const Fortran::semantics::Symbol *argument = &*entity.value();
+        // "fir.bindc_name" is used for arguments for the sake of consistency
+        // with other attributes carrying surface syntax names in FIR.
+        return {mlir::NamedAttribute(
+            mlir::StringAttr::get(&mlirContext, "fir.bindc_name"),
+            mlir::StringAttr::get(&mlirContext,
+                                  toStringRef(argument->name())))};
+      }
+    }
+    return {};
+  }
+
+  void handleImplicitDummy(
+      const DummyCharacteristics *characteristics,
+      const Fortran::evaluate::characteristics::DummyDataObject &obj,
+      const FortranEntity &entity) {
+    Fortran::evaluate::DynamicType dynamicType = obj.type.type();
+    if (dynamicType.category() == Fortran::common::TypeCategory::Character) {
+      mlir::Type boxCharTy =
+          fir::BoxCharType::get(&mlirContext, dynamicType.kind());
+      addFirOperand(boxCharTy, nextPassedArgPosition(), Property::BoxChar,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BoxChar, entity, characteristics);
+    } else {
+      // non-PDT derived type allowed in implicit interface.
+      Fortran::common::TypeCategory cat = dynamicType.category();
+      mlir::Type type = getConverter().genType(cat, dynamicType.kind());
+      fir::SequenceType::Shape bounds = getBounds(obj.type.shape());
+      if (!bounds.empty())
+        type = fir::SequenceType::get(bounds, type);
+      mlir::Type refType = fir::ReferenceType::get(type);
+      addFirOperand(refType, nextPassedArgPosition(), Property::BaseAddress,
+                    dummyNameAttr(entity));
+      addPassedArg(PassEntityBy::BaseAddress, entity, characteristics);
+    }
+  }
+
+  void handleImplicitDummy(
+      const DummyCharacteristics *characteristics,
+      const Fortran::evaluate::characteristics::DummyProcedure &proc,
+      const FortranEntity &entity) {
+    TODO(interface.converter.getCurrentLocation(),
+         "handleImlicitDummy DummyProcedure");
+  }
+
+  void
+  addFirOperand(mlir::Type type, int entityPosition, Property p,
+                llvm::ArrayRef<mlir::NamedAttribute> attributes = llvm::None) {
+    interface.inputs.emplace_back(
+        FirPlaceHolder{type, entityPosition, p, attributes});
+  }
+  void
+  addFirResult(mlir::Type type, int entityPosition, Property p,
+               llvm::ArrayRef<mlir::NamedAttribute> attributes = llvm::None) {
+    interface.outputs.emplace_back(
+        FirPlaceHolder{type, entityPosition, p, attributes});
+  }
+  void addPassedArg(PassEntityBy p, FortranEntity entity,
+                    const DummyCharacteristics *characteristics) {
+    interface.passedArguments.emplace_back(
+        PassedEntity{p, entity, {}, {}, characteristics});
   }
+  int nextPassedArgPosition() { return interface.passedArguments.size(); }
 
   Fortran::lower::AbstractConverter &getConverter() {
     return interface.converter;
@@ -273,9 +414,13 @@ void Fortran::lower::CallInterface<T>::determineInterface(
 template <typename T>
 mlir::FunctionType Fortran::lower::CallInterface<T>::genFunctionType() {
   llvm::SmallVector<mlir::Type> returnTys;
+  llvm::SmallVector<mlir::Type> inputTys;
   for (const FirPlaceHolder &placeHolder : outputs)
     returnTys.emplace_back(placeHolder.type);
-  return mlir::FunctionType::get(&converter.getMLIRContext(), {}, returnTys);
+  for (const FirPlaceHolder &placeHolder : inputs)
+    inputTys.emplace_back(placeHolder.type);
+  return mlir::FunctionType::get(&converter.getMLIRContext(), inputTys,
+                                 returnTys);
 }
 
 template class Fortran::lower::CallInterface<Fortran::lower::CalleeInterface>;
index c207f60..bd34736 100644 (file)
@@ -66,9 +66,25 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
                              Fortran::lower::SymMap &symMap) {
   assert(!var.isAlias());
   const Fortran::semantics::Symbol &sym = var.getSymbol();
+  const bool isDummy = Fortran::semantics::IsDummy(sym);
+  const bool isResult = Fortran::semantics::IsFunctionResult(sym);
   if (symMap.lookupSymbol(sym))
     return;
+
   const mlir::Location loc = converter.genLocation(sym.name());
+  if (isDummy) {
+    // This is an argument.
+    if (!symMap.lookupSymbol(sym))
+      mlir::emitError(loc, "symbol \"")
+          << toStringRef(sym.name()) << "\" must already be in map";
+    return;
+  } else if (isResult) {
+    // Some Fortran results may be passed by argument (e.g. derived
+    // types)
+    if (symMap.lookupSymbol(sym))
+      return;
+  }
+  // Otherwise, it's a local variable or function result.
   mlir::Value local = createNewLocal(converter, loc, var, {});
   symMap.addSymbol(sym, local);
 }
diff --git a/flang/test/Lower/arguments.f90 b/flang/test/Lower/arguments.f90
new file mode 100644 (file)
index 0000000..e451510
--- /dev/null
@@ -0,0 +1,48 @@
+! RUN: bbc %s -o "-" -emit-fir | FileCheck %s
+
+subroutine sub1(a, b)
+  integer, intent(in) :: a
+  logical :: b
+end
+
+! Check that arguments are correctly set and no local allocation is happening.
+! CHECK-LABEL: func @_QPsub1(
+! CHECK-SAME:    %{{.*}}: !fir.ref<i32> {fir.bindc_name = "a"}, %{{.*}}: !fir.ref<!fir.logical<4>> {fir.bindc_name = "b"})
+! CHECK-NOT:     fir.alloc
+! CHECK:         return
+
+subroutine sub2(i)
+  integer :: i(2, 5)
+end
+
+! CHECK-LABEL: func @_QPsub2(
+! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<2x5xi32>>{{.*}})
+
+subroutine sub3(i)
+  real :: i(2)
+end
+
+! CHECK-LABEL: func @_QPsub3(
+! CHECK-SAME: %{{.*}}: !fir.ref<!fir.array<2xf32>>{{.*}})
+
+integer function fct1(a, b)
+  integer, intent(in) :: a
+  logical :: b
+end
+
+! CHECK-LABEL: func @_QPfct1(
+! CHECK-SAME:    %{{.*}}: !fir.ref<i32> {fir.bindc_name = "a"}, %{{.*}}: !fir.ref<!fir.logical<4>> {fir.bindc_name = "b"}) -> i32
+
+real function fct2(i)
+  integer :: i(2, 5)
+end
+
+! CHECK-LABEL: func @_QPfct2(
+! CHECK-SAME:    %{{.*}}: !fir.ref<!fir.array<2x5xi32>> {fir.bindc_name = "i"}) -> f32
+
+function fct3(i)
+  real :: i(2)
+end
+
+! CHECK-LABEL: func @_QPfct3(
+! CHECK-SAME:    %{{.*}}: !fir.ref<!fir.array<2xf32>> {fir.bindc_name = "i"}) -> f32