[flang] Rework .mod file writing for subprogram interfaces
authorTim Keith <tkeith@nvidia.com>
Fri, 29 Mar 2019 22:04:17 +0000 (15:04 -0700)
committerTim Keith <tkeith@nvidia.com>
Fri, 29 Mar 2019 22:04:17 +0000 (15:04 -0700)
A subprogram interface in a `.mod` file requires all of the symbols
needed to declare the function return value and dummy arguments.
Some of those were missing.

`SubprogramSymbolCollector` recursively discovers all such symbols,
including symbols used in type parameters, array bounds, character
lengths, parent types.

Common blocks require special handling: If any of the symbols that
are need appear in a common block, we have to include that common block
and all other symbols in it. To make that easier to figure out, add the
`commonBlock` property to `ObjectEntityDetails` to map the entity to
the common block it is in, if any.

Original-commit: flang-compiler/f18@08709f8e88c3b37bdf8461f34340dc345dfbb085
Reviewed-on: https://github.com/flang-compiler/f18/pull/368
Tree-same-pre-rewrite: false

flang/lib/semantics/mod-file.cc
flang/lib/semantics/resolve-names.cc
flang/lib/semantics/symbol.h
flang/test/semantics/CMakeLists.txt
flang/test/semantics/modfile04.f90
flang/test/semantics/modfile21.f90
flang/test/semantics/modfile23.f90 [new file with mode: 0644]

index ffa522a..1cac339 100644 (file)
@@ -16,6 +16,7 @@
 #include "scope.h"
 #include "semantics.h"
 #include "symbol.h"
+#include "../evaluate/traversal.h"
 #include "../parser/parsing.h"
 #include <algorithm>
 #include <cerrno>
@@ -59,6 +60,47 @@ static bool FileContentsMatch(
 static std::string GetHeader(const std::string &);
 static std::size_t GetFileSize(const std::string &);
 
+// Collect symbols needed for a subprogram interface
+class SubprogramSymbolCollector {
+public:
+  using SymbolSet = std::set<const Symbol *>;
+
+  SubprogramSymbolCollector(const Symbol &symbol)
+    : symbol_{symbol}, scope_{*symbol.scope()} {}
+  SymbolList symbols() const { return need_; }
+  SymbolSet imports() const { return imports_; }
+  void Collect();
+
+private:
+  const Symbol &symbol_;
+  const Scope &scope_;
+  bool isInterface_{false};
+  SymbolList need_;  // symbols that are needed
+  SymbolSet needSet_;  // symbols already in need_
+  SymbolSet useSet_;  // use-associations that might be needed
+  SymbolSet imports_;  // imports from host that are needed
+
+  void DoSymbol(const Symbol &);
+  void DoType(const DeclTypeSpec *);
+  void DoBound(const Bound &);
+  void DoParamValue(const ParamValue &);
+
+  using SymbolVector = std::vector<const Symbol *>;
+  struct SymbolVisitor : public virtual evaluate::VisitorBase<SymbolVector> {
+    explicit SymbolVisitor(int) {}
+    void Handle(const Symbol *symbol) { result().push_back(symbol); }
+  };
+
+  template<typename T> void DoExpr(evaluate::Expr<T> expr) {
+    evaluate::Visitor<SymbolVector, SymbolVisitor> visitor{0};
+    for (const Symbol *symbol : visitor.Traverse(expr)) {
+      CHECK(symbol && "bad symbol from Traverse");
+      DoSymbol(*symbol);
+    }
+  }
+
+};
+
 bool ModFileWriter::WriteAll() {
   WriteAll(context_.globalScope());
   return !context_.AnyFatalError();
@@ -131,6 +173,7 @@ void ModFileWriter::PutSymbols(const Scope &scope) {
     PutSymbol(typeBindings, symbol);
   }
   if (auto str{typeBindings.str()}; !str.empty()) {
+    CHECK(scope.kind() == Scope::Kind::DerivedType);
     decls_ << "contains\n" << str;
   }
 }
@@ -250,7 +293,9 @@ void ModFileWriter::PutSubprogram(const Symbol &symbol) {
   PutLower(os, symbol) << '(';
   int n = 0;
   for (const auto &dummy : details.dummyArgs()) {
-    if (n++ > 0) os << ',';
+    if (n++ > 0) {
+      os << ',';
+    }
     PutLower(os, *dummy);
   }
   os << ')';
@@ -260,14 +305,23 @@ void ModFileWriter::PutSubprogram(const Symbol &symbol) {
     if (result.name() != symbol.name()) {
       PutLower(os << " result(", result) << ')';
     }
-    os << '\n';
-    PutEntity(os, details.result());
-  } else {
-    os << '\n';
   }
-  for (const auto &dummy : details.dummyArgs()) {
-    PutEntity(os, *dummy);
+  os << '\n';
+
+  // walk symbols, collect ones needed
+  ModFileWriter writer{context_};
+  std::stringstream typeBindings;
+  SubprogramSymbolCollector collector{symbol};
+  collector.Collect();
+  for (const Symbol *need : collector.symbols()) {
+    writer.PutSymbol(typeBindings, need);
   }
+  CHECK(typeBindings.str().empty());
+  os << writer.uses_.str();
+  for (const Symbol *import : collector.imports()) {
+    decls_ << "import::" << import->name().ToString() << "\n";
+  }
+  os << writer.decls_.str();
   os << "end\n";
   if (isInterface) {
     os << "end interface\n";
@@ -320,30 +374,39 @@ void ModFileWriter::PutUseExtraAttr(
 std::vector<const Symbol *> CollectSymbols(const Scope &scope) {
   std::set<const Symbol *> symbols;  // to prevent duplicates
   std::vector<const Symbol *> sorted;
-  sorted.reserve(scope.size());
+  std::vector<const Symbol *> namelist;
+  std::vector<const Symbol *> common;
+  sorted.reserve(scope.size() + scope.commonBlocks().size());
   for (const auto &pair : scope) {
     auto *symbol{pair.second};
     if (!symbol->test(Symbol::Flag::ParentComp)) {
       if (symbols.insert(symbol).second) {
-        sorted.push_back(symbol);
+        if (symbol->has<NamelistDetails>()) {
+          namelist.push_back(symbol);
+        } else {
+          sorted.push_back(symbol);
+        }
       }
     }
   }
   for (const auto &pair : scope.commonBlocks()) {
-    auto *symbol{pair.second};
+    const Symbol *symbol{pair.second};
+    SourceName name{pair.first};
     if (symbols.insert(symbol).second) {
-      sorted.push_back(symbol);
+      common.push_back(symbol);
     }
   }
-  std::sort(sorted.begin(), sorted.end(), [](const Symbol *x, const Symbol *y) {
-    bool xIsNml{x->has<NamelistDetails>()};
-    bool yIsNml{y->has<NamelistDetails>()};
-    if (xIsNml != yIsNml) {
-      return xIsNml < yIsNml;
-    } else {
-      return x->name().begin() < y->name().begin();
-    }
-  });
+  // sort normal symbols, then namelists, then common blocks:
+  auto compareByOrder = [](const Symbol *x, const Symbol *y) {
+    CHECK(x != nullptr);
+    return x->name().begin() < y->name().begin();
+  };
+  auto cursor{sorted.begin()};
+  std::sort(cursor, sorted.end(), compareByOrder);
+  cursor = sorted.insert(sorted.end(), namelist.begin(), namelist.end());
+  std::sort(cursor, sorted.end(), compareByOrder);
+  cursor = sorted.insert(sorted.end(), common.begin(), common.end());
+  std::sort(cursor, sorted.end(), compareByOrder);
   return sorted;
 }
 
@@ -713,4 +776,105 @@ static std::string ModFilePath(const std::string &dir, const SourceName &name,
   PutLower(path, name.ToString()) << extension;
   return path.str();
 }
+
+void SubprogramSymbolCollector::Collect() {
+  const auto &details{symbol_.get<SubprogramDetails>()};
+  isInterface_ = details.isInterface();
+  if (details.isFunction()) {
+    DoSymbol(details.result());
+  }
+  for (const Symbol *dummyArg : details.dummyArgs()) {
+    CHECK(dummyArg);
+    DoSymbol(*dummyArg);
+  }
+  for (const auto &pair : scope_) {
+    const Symbol *symbol{pair.second};
+    if (const auto *useDetails{symbol->detailsIf<UseDetails>()}) {
+      if (useSet_.count(&useDetails->symbol()) > 0) {
+        need_.push_back(symbol);
+      }
+    }
+  }
+}
+
+// Do symbols this one depends on; then add to need_
+void SubprogramSymbolCollector::DoSymbol(const Symbol &symbol) {
+  const auto &scope{symbol.owner()};
+  if (scope != scope_ && scope.kind() != Scope::Kind::DerivedType) {
+    if (scope != scope_.parent()) {
+      useSet_.insert(&symbol);
+    } else if (isInterface_) {
+      imports_.insert(&symbol);
+    }
+    return;
+  }
+  if (!needSet_.insert(&symbol).second) {
+    return;  // already done
+  }
+  std::visit(
+      common::visitors{
+          [this](const ObjectEntityDetails &details) {
+            for (const ShapeSpec &spec : details.shape()) {
+              DoBound(spec.lbound());
+              DoBound(spec.ubound());
+            }
+            if (const Symbol * commonBlock{details.commonBlock()}) {
+              DoSymbol(*commonBlock);
+            }
+          },
+          [this](const CommonBlockDetails &details) {
+            for (const Symbol *object : details.objects()) {
+              DoSymbol(*object);
+            }
+          },
+          [](const auto &) {},
+      },
+      symbol.details());
+  if (!symbol.has<UseDetails>()) {
+    DoType(symbol.GetType());
+  }
+  if (scope.kind() != Scope::Kind::DerivedType) {
+    need_.push_back(&symbol);
+  }
+}
+
+void SubprogramSymbolCollector::DoType(const DeclTypeSpec *type) {
+  if (!type) {
+    return;
+  }
+  switch (type->category()) {
+  case DeclTypeSpec::Numeric:
+  case DeclTypeSpec::Logical: break;  // nothing to do
+  case DeclTypeSpec::Character:
+    DoParamValue(type->characterTypeSpec().length());
+    break;
+  default:
+    if (const DerivedTypeSpec * derived{type->AsDerived()}) {
+      const auto &typeSymbol{derived->typeSymbol()};
+      if (const DerivedTypeSpec * extends{typeSymbol.GetParentTypeSpec()}) {
+        DoSymbol(extends->typeSymbol());
+      }
+      for (const auto pair : derived->parameters()) {
+        DoParamValue(pair.second);
+      }
+      for (const auto pair : *typeSymbol.scope()) {
+        const auto &comp{*pair.second};
+        DoSymbol(comp);
+      }
+      DoSymbol(typeSymbol);
+    }
+  }
+}
+
+void SubprogramSymbolCollector::DoBound(const Bound &bound) {
+  if (const MaybeSubscriptIntExpr & expr{bound.GetExplicit()}) {
+    DoExpr(*expr);
+  }
+}
+void SubprogramSymbolCollector::DoParamValue(const ParamValue &paramValue) {
+  if (const auto &expr{paramValue.GetExplicit()}) {
+    DoExpr(*expr);
+  }
+}
+
 }
index f6e50fd..c05c3ad 100644 (file)
@@ -3414,11 +3414,12 @@ void DeclarationVisitor::Post(const parser::CommonBlockObject &x) {
   const auto &name{std::get<parser::Name>(x.t)};
   auto &symbol{DeclareObjectEntity(name, Attrs{})};
   ClearArraySpec();
-  if (!symbol.has<ObjectEntityDetails>()) {
+  auto *details{symbol.detailsIf<ObjectEntityDetails>()};
+  if (!details) {
     return;  // error was reported
   }
   commonBlockInfo_.curr->get<CommonBlockDetails>().add_object(symbol);
-  if (!IsExplicit(symbol.get<ObjectEntityDetails>().shape())) {
+  if (!IsExplicit(details->shape())) {
     Say(name,
         "The shape of common block object '%s' must be explicit"_err_en_US);
     return;
@@ -3430,6 +3431,7 @@ void DeclarationVisitor::Post(const parser::CommonBlockObject &x) {
         "Previous occurrence of '%s' in a COMMON block"_en_US);
     return;
   }
+  details->set_commonBlock(*commonBlockInfo_.curr);
 }
 
 bool DeclarationVisitor::Pre(const parser::SaveStmt &x) {
index 1356605..f94fd9f 100644 (file)
@@ -151,6 +151,8 @@ public:
   ArraySpec &shape() { return shape_; }
   const ArraySpec &shape() const { return shape_; }
   void set_shape(const ArraySpec &shape);
+  const Symbol *commonBlock() const { return commonBlock_; }    
+  void set_commonBlock(const Symbol &commonBlock) { commonBlock_ = &commonBlock; }
   bool IsArray() const { return !shape_.empty(); }
   bool IsAssumedShape() const {
     return isDummy() && IsArray() && shape_.back().ubound().isDeferred() &&
@@ -172,6 +174,7 @@ public:
 private:
   MaybeExpr init_;
   ArraySpec shape_;
+  const Symbol *commonBlock_{nullptr};  // common block this object is in
   friend std::ostream &operator<<(std::ostream &, const ObjectEntityDetails &);
 };
 
index 2a4ae11..32c4ad6 100644 (file)
@@ -132,6 +132,7 @@ set(MODFILE_TESTS
   modfile20.f90
   modfile21.f90
   modfile22.f90
+  modfile23.f90
 )
 
 set(LABEL_TESTS
index 2b82e39..caab30f 100644 (file)
@@ -38,8 +38,14 @@ end
 
 module m2
 contains
-  type(t) function f3()
+  type(t) function f3(x)
     use m1
+    integer, parameter :: a = 2
+    type t2(b)
+      integer, kind :: b = a
+      integer :: y
+    end type
+    type(t2) :: x
   end
   function f4() result(x)
     implicit complex(x)
@@ -67,8 +73,14 @@ end
 !Expect: m2.mod
 !module m2
 !contains
-!function f3()
-!type(t)::f3
+!function f3(x)
+! use m1,only:t
+! type(t)::f3
+! type::t2(b)
+!  integer(4),kind::b=2_4
+!  integer(4)::y
+! end type
+! type(t2)::x
 !end
 !function f4() result(x)
 !complex(4)::x
index 04a5ce9..194de9b 100644 (file)
@@ -29,12 +29,7 @@ end
 !Expect: m.mod
 !module m
 !  logical(4)::b(1_8:4_8,1_8:4_8)
-!  common/cb2/a,b,c
-!  bind(c)::/cb2/
-!  common//t,w,u,v
 !  real(4)::t
-!  common/cb/x,y,z
-!  bind(c, name=1_"CB")::/cb/
 !  real(4)::x(2_8:10_8)
 !  real(4)::a
 !  real(4)::c
@@ -43,6 +38,11 @@ end
 !  complex(4)::w
 !  real(4)::u
 !  real(4)::v
-!  common/b/cb
 !  real(4)::cb
+!  common/cb2/a,b,c
+!  bind(c)::/cb2/
+!  common//t,w,u,v
+!  common/cb/x,y,z
+!  bind(c, name=1_"CB")::/cb/
+!  common/b/cb
 !end
diff --git a/flang/test/semantics/modfile23.f90 b/flang/test/semantics/modfile23.f90
new file mode 100644 (file)
index 0000000..584ca42
--- /dev/null
@@ -0,0 +1,184 @@
+! Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+!
+! Licensed under the Apache License, Version 2.0 (the "License");
+! you may not use this file except in compliance with the License.
+! You may obtain a copy of the License at
+!
+!     http://www.apache.org/licenses/LICENSE-2.0
+!
+! Unless required by applicable law or agreed to in writing, software
+! distributed under the License is distributed on an "AS IS" BASIS,
+! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+! See the License for the specific language governing permissions and
+! limitations under the License.
+
+! Test that subprogram interfaces get all of the symbols that they need.
+
+module m1
+  integer(8) :: i
+  type t1
+    sequence
+  end type
+  type t2
+  end type
+end
+!Expect: m1.mod
+!module m1
+! integer(8)::i
+! type::t1
+!  sequence
+! end type
+! type::t2
+! end type
+!end
+
+module m2
+  integer(8) :: k
+contains
+  subroutine s(a, j)
+    use m1
+    integer(8) :: j
+    real :: a(i:j,1:k)  ! need i from m1
+  end
+end
+!Expect: m2.mod
+!module m2
+! integer(8)::k
+!contains
+! subroutine s(a,j)
+!  use m1,only:i
+!  integer(8)::j
+!  real(4)::a(i:j,1_8:k)
+! end
+!end
+
+module m3
+  implicit none
+contains
+  subroutine s(b, n)
+    type t2
+    end type
+    type t4(l)
+      integer, len :: l
+      type(t2) :: x  ! need t2
+    end type
+    integer :: n
+    type(t4(n)) :: b
+  end
+end module
+!Expect: m3.mod
+!module m3
+!contains
+! subroutine s(b,n)
+!  integer(4)::n
+!  type::t2
+!  end type
+!  type::t4(l)
+!   integer(4),len::l
+!   type(t2)::x
+!  end type
+!  type(t4(l=n))::b
+! end
+!end
+
+module m4
+contains
+  subroutine s1(a)
+    use m1
+    common /c/x,n  ! x is needed
+    integer(8) :: n
+    real :: a(n)
+    type(t1) :: x
+  end
+end
+!Expect: m4.mod
+!module m4
+!contains
+! subroutine s1(a)
+!  use m1,only:t1
+!  type(t1)::x
+!  common/c/x,n
+!  integer(8)::n
+!  real(4)::a(1_8:n)
+! end
+!end
+
+module m5
+  type t5
+  end type
+  interface
+    subroutine s(x1,x5)
+      use m1
+      import :: t5
+      type(t1) :: x1
+      type(t5) :: x5
+    end subroutine
+  end interface
+end
+!Expect: m5.mod
+!module m5
+! type::t5
+! end type
+! interface
+!  subroutine s(x1,x5)
+!   use m1,only:t1
+!   import::t5
+!   type(t1)::x1
+!   type(t5)::x5
+!  end
+! end interface
+!end
+
+module m6
+contains
+  subroutine s(x)
+    use m1
+    type, extends(t2) :: t6
+    end type
+    type, extends(t6) :: t7
+    end type
+    type(t7) :: x
+  end
+end
+!Expect: m6.mod
+!module m6
+!contains
+! subroutine s(x)
+!  use m1,only:t2
+!  type,extends(t2)::t6
+!  end type
+!  type,extends(t6)::t7
+!  end type
+!  type(t7)::x
+! end
+!end
+
+module m7
+  type :: t5(l)
+    integer, len :: l
+  end type
+contains
+  subroutine s1(x)
+    use m1
+    type(t5(i)) :: x
+  end subroutine
+  subroutine s2(x)
+    use m1
+    character(i) :: x
+  end subroutine
+end
+!Expect: m7.mod
+!module m7
+! type::t5(l)
+!  integer(4),len::l
+! end type
+!contains
+! subroutine s1(x)
+!  use m1,only:i
+!  type(t5(l=i))::x
+! end
+! subroutine s2(x)
+!  use m1,only:i
+!  character(i,1)::x
+! end
+!end