[flang] Proper PDT handling
authorpeter klausler <pklausler@nvidia.com>
Thu, 8 Aug 2019 00:10:43 +0000 (17:10 -0700)
committerpeter klausler <pklausler@nvidia.com>
Fri, 9 Aug 2019 16:41:51 +0000 (09:41 -0700)
Original-commit: flang-compiler/f18@32256daa15d663cd6b24bccd15727cd1e4ddefe6
Reviewed-on: https://github.com/flang-compiler/f18/pull/638
Tree-same-pre-rewrite: false

flang/lib/semantics/resolve-names.cc

index 7d68c33..e27d0b8 100644 (file)
@@ -141,7 +141,7 @@ public:
   SemanticsContext &context() const { return *context_; }
   void set_context(SemanticsContext &);
   evaluate::FoldingContext &GetFoldingContext() const {
-    return context_->foldingContext();
+    return DEREF(context_).foldingContext();
   }
 
   // Make a placeholder symbol for a Name that otherwise wouldn't have one.
@@ -763,8 +763,7 @@ public:
       const parser::Name &, const parser::InitialDataTarget &);
   void PointerInitialization(
       const parser::Name &, const parser::ProcPointerInit &);
-  void CheckBindings(
-      const Scope &, const parser::TypeBoundProcedureStmt::WithoutInterface &);
+  void CheckBindings(const parser::TypeBoundProcedureStmt::WithoutInterface &);
 
 protected:
   bool BeginDecl();
@@ -1103,7 +1102,8 @@ private:
   void AddSubpNames(const ProgramTree &);
   bool BeginScope(const ProgramTree &);
   void FinishSpecificationParts(const ProgramTree &);
-  void FinishDerivedType(Scope &);
+  void FinishDerivedTypeDefinition(Scope &);
+  void FinishDerivedTypeInstantiation(Scope &);
   void SetPassArg(const Symbol &, const Symbol *, WithPassArg &);
   void ResolveExecutionParts(const ProgramTree &);
 };
@@ -3426,16 +3426,14 @@ void DeclarationVisitor::Post(
       SetPassNameOn(*s);
     }
   }
-  if (currScope().IsParameterizedDerivedType()) {
-    CheckBindings(currScope(), x);
-  }
 }
 
-void DeclarationVisitor::CheckBindings(const Scope &typeScope,
+void DeclarationVisitor::CheckBindings(
     const parser::TypeBoundProcedureStmt::WithoutInterface &tbps) {
+  CHECK(currScope().kind() == Scope::Kind::DerivedType);
   for (auto &declaration : tbps.declarations) {
     auto &bindingName{std::get<parser::Name>(declaration.t)};
-    if (Symbol * binding{FindInScope(typeScope, bindingName)}) {
+    if (Symbol * binding{FindInScope(currScope(), bindingName)}) {
       if (auto *details{binding->detailsIf<ProcBindingDetails>()}) {
         const Symbol &procedure{details->symbol().GetUltimate()};
         if (!CanBeTypeBoundProc(procedure)) {
@@ -3713,8 +3711,8 @@ void DeclarationVisitor::CheckSaveStmts() {
               " a COMMON statement"_err_en_US);
         } else {  // C1108
           Say(name,
-             "SAVE statement in BLOCK construct may not contain a"
-             " common block name '%s'"_err_en_US);
+              "SAVE statement in BLOCK construct may not contain a"
+              " common block name '%s'"_err_en_US);
         }
       } else {
         for (Symbol *object : symbol->get<CommonBlockDetails>().objects()) {
@@ -4928,8 +4926,7 @@ void DeclarationVisitor::Initialization(const parser::Name &name,
   if (name.symbol == nullptr) {
     return;
   }
-  if (std::holds_alternative<parser::InitialDataTarget>(init.u) &&
-      !currScope().IsParameterizedDerivedType()) {
+  if (std::holds_alternative<parser::InitialDataTarget>(init.u)) {
     // Defer analysis to the end of the specification parts so that forward
     // references work better.
     return;
@@ -5533,8 +5530,8 @@ bool ResolveNamesVisitor::BeginScope(const ProgramTree &node) {
 }
 
 // The processing of initializers of pointers is deferred until all of
-// the pertinent specification parts have been visited.  This deferral
-// enables the use of forward references in those initializers.
+// the pertinent specification parts have been visited.  This deferred
+// processing enables the use of forward references in those initializers.
 class DeferredPointerInitializationVisitor {
 public:
   explicit DeferredPointerInitializationVisitor(ResolveNamesVisitor &resolver)
@@ -5549,16 +5546,21 @@ public:
 
   void Post(const parser::DerivedTypeStmt &x) {
     auto &name{std::get<parser::Name>(x.t)};
-    if (const Symbol * symbol{name.symbol}) {
-      if (const Scope * scope{symbol->scope()}) {
-        if (scope->kind() == Scope::Kind::DerivedType &&
-            !scope->IsParameterizedDerivedType()) {
-          derivedTypeScope_ = scope;
+    if (Symbol * symbol{name.symbol}) {
+      if (Scope * scope{symbol->scope()}) {
+        if (scope->kind() == Scope::Kind::DerivedType) {
+          resolver_.PushScope(*scope);
+          pushedScope_ = true;
         }
       }
     }
   }
-  void Post(const parser::EndTypeStmt &) { derivedTypeScope_ = nullptr; }
+  void Post(const parser::EndTypeStmt &) {
+    if (pushedScope_) {
+      resolver_.PopScope();
+      pushedScope_ = false;
+    }
+  }
 
   bool Pre(const parser::EntityDecl &decl) {
     Init(std::get<parser::Name>(decl.t),
@@ -5578,8 +5580,8 @@ public:
     return false;
   }
   void Post(const parser::TypeBoundProcedureStmt::WithoutInterface &tbps) {
-    if (derivedTypeScope_ != nullptr) {
-      resolver_.CheckBindings(*derivedTypeScope_, tbps);
+    if (pushedScope_) {
+      resolver_.CheckBindings(tbps);
     }
   }
 
@@ -5595,7 +5597,7 @@ private:
   }
 
   ResolveNamesVisitor &resolver_;
-  const Scope *derivedTypeScope_{nullptr};
+  bool pushedScope_{false};
 };
 
 // Perform checks that need to happen after all of the specification parts
@@ -5616,7 +5618,12 @@ void ResolveNamesVisitor::FinishSpecificationParts(const ProgramTree &node) {
   }
   for (Scope &childScope : currScope().children()) {
     if (childScope.IsDerivedType() && childScope.symbol()) {
-      FinishDerivedType(childScope);
+      FinishDerivedTypeDefinition(childScope);
+    }
+  }
+  for (Scope &childScope : currScope().children()) {
+    if (childScope.IsDerivedType() && !childScope.symbol()) {
+      FinishDerivedTypeInstantiation(childScope);
     }
   }
   for (const auto &child : node.children()) {
@@ -5635,8 +5642,8 @@ static int FindIndexOfName(
 }
 
 // Perform checks on procedure bindings of this type
-void ResolveNamesVisitor::FinishDerivedType(Scope &scope) {
-  CHECK(scope.IsDerivedType());
+void ResolveNamesVisitor::FinishDerivedTypeDefinition(Scope &scope) {
+  CHECK(scope.IsDerivedType() && scope.symbol());
   for (auto &pair : scope) {
     Symbol &comp{*pair.second};
     std::visit(
@@ -5661,6 +5668,53 @@ void ResolveNamesVisitor::FinishDerivedType(Scope &scope) {
   }
 }
 
+// Fold object pointer initializer designators with the actual
+// type parameter values of a particular instantiation.
+void ResolveNamesVisitor::FinishDerivedTypeInstantiation(Scope &scope) {
+  CHECK(scope.IsDerivedType() && !scope.symbol());
+  if (const DerivedTypeSpec * spec{scope.derivedTypeSpec()}) {
+    const Symbol &origTypeSymbol{spec->typeSymbol()};
+    if (const Scope * origTypeScope{origTypeSymbol.scope()}) {
+      CHECK(origTypeScope->IsDerivedType() &&
+          origTypeScope->symbol() == &origTypeSymbol);
+      auto &foldingContext{GetFoldingContext()};
+      auto restorer{foldingContext.WithPDTInstance(*spec)};
+      for (auto &pair : scope) {
+        Symbol &comp{*pair.second};
+        const Symbol &origComp{DEREF(FindInScope(*origTypeScope, comp.name()))};
+        std::visit(
+            common::visitors{
+                [&](ObjectEntityDetails &x) {
+                  if (IsPointer(comp)) {
+                    auto origDetails{origComp.get<ObjectEntityDetails>()};
+                    if (const MaybeExpr & init{origDetails.init()}) {
+                      SomeExpr newInit{*init};
+                      MaybeExpr folded{
+                          evaluate::Fold(foldingContext, std::move(newInit))};
+                      x.set_init(std::move(folded));
+                    }
+                  }
+                },
+                [&](ProcEntityDetails &x) {
+                  auto origDetails{origComp.get<ProcEntityDetails>()};
+                  if (auto pi{origDetails.passIndex()}) {
+                    x.set_passIndex(*pi);
+                  }
+                },
+                [&](ProcBindingDetails &x) {
+                  auto origDetails{origComp.get<ProcBindingDetails>()};
+                  if (auto pi{origDetails.passIndex()}) {
+                    x.set_passIndex(*pi);
+                  }
+                },
+                [](auto &) {},
+            },
+            comp.details());
+      }
+    }
+  }
+}
+
 // Check C760, constraints on the passed-object dummy argument
 // If they all pass, set the passIndex in details.
 void ResolveNamesVisitor::SetPassArg(