[flang] Detect circularly defined interfaces of procedures
authorPeter Steinfeld <psteinfeld@nvidia.com>
Mon, 22 Feb 2021 16:59:15 +0000 (08:59 -0800)
committerPeter Steinfeld <psteinfeld@nvidia.com>
Fri, 26 Feb 2021 22:44:35 +0000 (14:44 -0800)
It's possible to define a procedure whose interface depends on a procedure
which has an interface that depends on the original procedure.  Such a circular
definition was causing the compiler to fall into an infinite loop when
resolving the name of the second procedure.  It's also possible to create
circular dependency chains of more than two procedures.

I fixed this by adding the function HasCycle() to the class DeclarationVisitor
and calling it from DeclareProcEntity() to detect procedures with such
circularly defined interfaces.  I marked the associated symbols of such
procedures by calling SetError() on them.  When processing subsequent
procedures, I called HasError() before attempting to analyze their interfaces.
Unfortunately, this did not work.

With help from Tim, we determined that the SymbolSet used to track the
erroneous symbols was instantiated using a "<" operator which was
defined using the name of the procedure.  But the procedure name was
being changed by a call to ReplaceName() between the times that the
calls to SetError() and HasError() were made.  This caused HasError() to
incorrectly report that a symbol was not in the set of erroneous
symbols.  I fixed this by making SymbolSet be an ordered set, which does
not use the "<" operator.

I also added tests that will crash the compiler without this change.
And I fixed the formatting on an error message from a previous update.

Differential Revision: https://reviews.llvm.org/D97201

flang/include/flang/Semantics/symbol.h
flang/lib/Evaluate/characteristics.cpp
flang/lib/Semantics/resolve-names.cpp
flang/test/Semantics/resolve102.f90

index f04b05a..7e83837 100644 (file)
@@ -17,7 +17,7 @@
 #include <array>
 #include <list>
 #include <optional>
-#include <set>
+#include <unordered_set>
 #include <vector>
 
 namespace llvm {
@@ -38,6 +38,12 @@ using SymbolRef = common::Reference<const Symbol>;
 using SymbolVector = std::vector<SymbolRef>;
 using MutableSymbolRef = common::Reference<Symbol>;
 using MutableSymbolVector = std::vector<MutableSymbolRef>;
+struct SymbolHash {
+  std::size_t operator()(SymbolRef symRef) const {
+    return (std::size_t)(&symRef.get());
+  }
+};
+using SymbolSet = std::unordered_set<SymbolRef, SymbolHash>;
 
 // A module or submodule.
 class ModuleDetails {
@@ -594,9 +600,10 @@ public:
 
   bool operator==(const Symbol &that) const { return this == &that; }
   bool operator!=(const Symbol &that) const { return !(*this == that); }
+  // For maps using symbols as keys and sorting symbols.  Collate them by their
+  // position in the cooked character stream
   bool operator<(const Symbol &that) const {
-    // For sets of symbols: collate them by source location
-    return name_.begin() < that.name_.begin();
+    return sortName_ < that.sortName_;
   }
 
   int Rank() const {
@@ -653,6 +660,7 @@ public:
 private:
   const Scope *owner_;
   SourceName name_;
+  const char *sortName_; // used in the "<" operator for sorting symbols
   Attrs attrs_;
   Flags flags_;
   Scope *scope_{nullptr};
@@ -687,6 +695,7 @@ public:
     Symbol &symbol = Get();
     symbol.owner_ = &owner;
     symbol.name_ = name;
+    symbol.sortName_ = name.begin();
     symbol.attrs_ = attrs;
     symbol.details_ = std::move(details);
     return symbol;
@@ -765,7 +774,6 @@ inline bool operator<(SymbolRef x, SymbolRef y) { return *x < *y; }
 inline bool operator<(MutableSymbolRef x, MutableSymbolRef y) {
   return *x < *y;
 }
-using SymbolSet = std::set<SymbolRef>;
 
 } // namespace Fortran::semantics
 
index 1e83709..9b15e3e 100644 (file)
@@ -369,7 +369,7 @@ static std::optional<Procedure> CharacterizeProcedure(
     std::string procsList{GetSeenProcs(seenProcs)};
     context.messages().Say(symbol.name(),
         "Procedure '%s' is recursively defined.  Procedures in the cycle:"
-        " '%s'"_err_en_US,
+        " %s"_err_en_US,
         symbol.name(), procsList);
     return std::nullopt;
   }
index 7f14121..7ace9fc 100644 (file)
@@ -1003,6 +1003,7 @@ private:
     context().SetError(symbol);
     return symbol;
   }
+  bool HasCycle(const Symbol &, const ProcInterface &);
 };
 
 // Resolve construct entities and statement entities.
@@ -2132,7 +2133,7 @@ static bool NeedsType(const Symbol &symbol) {
 
 void ScopeHandler::ApplyImplicitRules(
     Symbol &symbol, bool allowForwardReference) {
-  if (!NeedsType(symbol)) {
+  if (context().HasError(symbol) || !NeedsType(symbol)) {
     return;
   }
   if (const DeclTypeSpec * type{GetImplicitType(symbol)}) {
@@ -2156,10 +2157,8 @@ void ScopeHandler::ApplyImplicitRules(
   if (allowForwardReference && ImplicitlyTypeForwardRef(symbol)) {
     return;
   }
-  if (!context().HasError(symbol)) {
-    Say(symbol.name(), "No explicit type declared for '%s'"_err_en_US);
-    context().SetError(symbol);
-  }
+  Say(symbol.name(), "No explicit type declared for '%s'"_err_en_US);
+  context().SetError(symbol);
 }
 
 // Extension: Allow forward references to scalar integer dummy arguments
@@ -3641,6 +3640,35 @@ Symbol &DeclarationVisitor::DeclareUnknownEntity(
   }
 }
 
+bool DeclarationVisitor::HasCycle(
+    const Symbol &procSymbol, const ProcInterface &interface) {
+  SymbolSet procsInCycle;
+  procsInCycle.insert(procSymbol);
+  const ProcInterface *thisInterface{&interface};
+  bool haveInterface{true};
+  while (haveInterface) {
+    haveInterface = false;
+    if (const Symbol * interfaceSymbol{thisInterface->symbol()}) {
+      if (procsInCycle.count(*interfaceSymbol) > 0) {
+        for (const auto procInCycle : procsInCycle) {
+          Say(procInCycle->name(),
+              "The interface for procedure '%s' is recursively "
+              "defined"_err_en_US,
+              procInCycle->name());
+          context().SetError(*procInCycle);
+        }
+        return true;
+      } else if (const auto *procDetails{
+                     interfaceSymbol->detailsIf<ProcEntityDetails>()}) {
+        haveInterface = true;
+        thisInterface = &procDetails->interface();
+        procsInCycle.insert(*interfaceSymbol);
+      }
+    }
+  }
+  return false;
+}
+
 Symbol &DeclarationVisitor::DeclareProcEntity(
     const parser::Name &name, Attrs attrs, const ProcInterface &interface) {
   Symbol &symbol{DeclareEntity<ProcEntityDetails>(name, attrs)};
@@ -3650,20 +3678,20 @@ Symbol &DeclarationVisitor::DeclareProcEntity(
           "The interface for procedure '%s' has already been "
           "declared"_err_en_US);
       context().SetError(symbol);
-    } else {
-      if (interface.type()) {
+    } else if (HasCycle(symbol, interface)) {
+      return symbol;
+    } else if (interface.type()) {
+      symbol.set(Symbol::Flag::Function);
+    } else if (interface.symbol()) {
+      if (interface.symbol()->test(Symbol::Flag::Function)) {
         symbol.set(Symbol::Flag::Function);
-      } else if (interface.symbol()) {
-        if (interface.symbol()->test(Symbol::Flag::Function)) {
-          symbol.set(Symbol::Flag::Function);
-        } else if (interface.symbol()->test(Symbol::Flag::Subroutine)) {
-          symbol.set(Symbol::Flag::Subroutine);
-        }
+      } else if (interface.symbol()->test(Symbol::Flag::Subroutine)) {
+        symbol.set(Symbol::Flag::Subroutine);
       }
-      details->set_interface(interface);
-      SetBindNameOn(symbol);
-      SetPassNameOn(symbol);
     }
+    details->set_interface(interface);
+    SetBindNameOn(symbol);
+    SetPassNameOn(symbol);
   }
   return symbol;
 }
@@ -5005,7 +5033,7 @@ Symbol *DeclarationVisitor::NoteInterfaceName(const parser::Name &name) {
 
 void DeclarationVisitor::CheckExplicitInterface(const parser::Name &name) {
   if (const Symbol * symbol{name.symbol}) {
-    if (!symbol->HasExplicitInterface()) {
+    if (!context().HasError(*symbol) && !symbol->HasExplicitInterface()) {
       Say(name,
           "'%s' must be an abstract interface or a procedure with "
           "an explicit interface"_err_en_US,
index d6894db..778323b 100644 (file)
@@ -1,7 +1,7 @@
 ! RUN: %S/test_errors.sh %s %t %f18
 
 ! Tests for circularly defined procedures
-!ERROR: Procedure 'sub' is recursively defined.  Procedures in the cycle: ''sub', 'p2''
+!ERROR: Procedure 'sub' is recursively defined.  Procedures in the cycle: 'p2', 'sub'
 subroutine sub(p2)
   PROCEDURE(sub) :: p2
 
@@ -9,7 +9,7 @@ subroutine sub(p2)
 end subroutine
 
 subroutine circular
-  !ERROR: Procedure 'p' is recursively defined.  Procedures in the cycle: ''p', 'sub', 'p2''
+  !ERROR: Procedure 'p' is recursively defined.  Procedures in the cycle: 'p2', 'p', 'sub'
   procedure(sub) :: p
 
   call p(sub)
@@ -21,7 +21,7 @@ subroutine circular
 end subroutine circular
 
 program iface
-  !ERROR: Procedure 'p' is recursively defined.  Procedures in the cycle: ''p', 'sub', 'p2''
+  !ERROR: Procedure 'p' is recursively defined.  Procedures in the cycle: 'p2', 'p', 'sub'
   procedure(sub) :: p
   interface
     subroutine sub(p2)
@@ -38,7 +38,7 @@ Program mutual
   Call p(sub)
 
   contains
-    !ERROR: Procedure 'sub1' is recursively defined.  Procedures in the cycle: ''p', 'sub1', 'arg''
+    !ERROR: Procedure 'sub1' is recursively defined.  Procedures in the cycle: 'arg', 'p', 'sub1'
     Subroutine sub1(arg)
       procedure(sub1) :: arg
     End Subroutine
@@ -54,7 +54,7 @@ Program mutual1
   Call p(sub)
 
   contains
-    !ERROR: Procedure 'sub1' is recursively defined.  Procedures in the cycle: ''p', 'sub1', 'arg', 'sub', 'p2''
+    !ERROR: Procedure 'sub1' is recursively defined.  Procedures in the cycle: 'p2', 'sub', 'arg', 'p', 'sub1'
     Subroutine sub1(arg)
       procedure(sub) :: arg
     End Subroutine
@@ -63,3 +63,24 @@ Program mutual1
       Procedure(sub1) :: p2
     End Subroutine
 End Program
+
+program twoCycle
+  !ERROR: The interface for procedure 'p1' is recursively defined
+  !ERROR: The interface for procedure 'p2' is recursively defined
+  procedure(p1) p2
+  procedure(p2) p1
+  call p1
+  call p2
+end program
+
+program threeCycle
+  !ERROR: The interface for procedure 'p1' is recursively defined
+  !ERROR: The interface for procedure 'p2' is recursively defined
+  procedure(p1) p2
+  !ERROR: The interface for procedure 'p3' is recursively defined
+  procedure(p2) p3
+  procedure(p3) p1
+  call p1
+  call p2
+  call p3
+end program