[flang] Clean up AssignmentContext
authorTim Keith <tkeith@nvidia.com>
Sat, 4 Jan 2020 20:38:35 +0000 (12:38 -0800)
committerTim Keith <tkeith@nvidia.com>
Tue, 14 Jan 2020 21:02:55 +0000 (13:02 -0800)
Eliminate `at_` and use location from `SemanticsContext` instead.

Add and use Analyze functions for `std::optional` and `std::list`.

Original-commit: flang-compiler/f18@e171029ccdb9efe874cad3a3d91adcfa200a6550
Reviewed-on: https://github.com/flang-compiler/f18/pull/928
Tree-same-pre-rewrite: false

flang/lib/semantics/assignment.cc
flang/lib/semantics/assignment.h

index 162c714..a829cbc 100644 (file)
@@ -243,6 +243,7 @@ void CheckPointerAssignment(FoldingContext &context, parser::CharBlock source,
       lhs.attrs.test(characteristics::DummyDataObject::Attr::Contiguous)}
       .Check(rhs);
 }
+
 }
 
 namespace Fortran::semantics {
@@ -297,41 +298,46 @@ class AssignmentContext {
 public:
   explicit AssignmentContext(SemanticsContext &c) : context_{c} {}
   AssignmentContext(const AssignmentContext &c, WhereContext &w)
-    : context_{c.context_}, at_{c.at_}, where_{&w} {}
+    : context_{c.context_}, where_{&w} {}
   AssignmentContext(const AssignmentContext &c, ForallContext &f)
-    : context_{c.context_}, at_{c.at_}, forall_{&f} {}
+    : context_{c.context_}, forall_{&f} {}
 
   bool operator==(const AssignmentContext &x) const { return this == &x; }
 
-  void set_at(parser::CharBlock at) {
-    at_ = at;
-    context_.set_location(at_);
-  }
-
   void Analyze(const parser::AssignmentStmt &);
   void Analyze(const parser::PointerAssignmentStmt &);
   void Analyze(const parser::WhereStmt &);
   void Analyze(const parser::WhereConstruct &);
   void Analyze(const parser::ForallStmt &);
   void Analyze(const parser::ForallConstruct &);
+  void Analyze(const parser::ForallConstructStmt &);
   void Analyze(const parser::ConcurrentHeader &);
 
-  template<typename A> void Analyze(const parser::Statement<A> &stmt) {
-    set_at(stmt.source);
+  template<typename A> void Analyze(const parser::UnlabeledStatement<A> &stmt) {
+    context_.set_location(stmt.source);
     Analyze(stmt.statement);
   }
   template<typename A> void Analyze(const common::Indirection<A> &x) {
     Analyze(x.value());
   }
-  template<typename... As> void Analyze(const std::variant<As...> &u) {
-    std::visit([&](const auto &x) { Analyze(x); }, u);
+  template<typename A> std::enable_if_t<UnionTrait<A>> Analyze(const A &x) {
+    std::visit([&](const auto &y) { Analyze(y); }, x.u);
+  }
+  template<typename A> void Analyze(const std::list<A> &list) {
+    for (const auto &elem : list) {
+      Analyze(elem);
+    }
+  }
+  template<typename A> void Analyze(const std::optional<A> &x) {
+    if (x) {
+      Analyze(*x);
+    }
   }
 
 private:
-  void Analyze(const parser::WhereBodyConstruct &constr) { Analyze(constr.u); }
   void Analyze(const parser::WhereConstruct::MaskedElsewhere &);
+  void Analyze(const parser::MaskedElsewhereStmt &);
   void Analyze(const parser::WhereConstruct::Elsewhere &);
-  void Analyze(const parser::ForallAssignmentStmt &stmt) { Analyze(stmt.u); }
 
   int GetIntegerKind(const std::optional<parser::IntegerTypeSpec> &);
   void CheckForImpureCall(const SomeExpr &);
@@ -347,7 +353,6 @@ private:
   }
 
   SemanticsContext &context_;
-  parser::CharBlock at_;
   WhereContext *where_{nullptr};
   ForallContext *forall_{nullptr};
 };
@@ -439,19 +444,11 @@ void AssignmentContext::Analyze(const parser::WhereConstruct &construct) {
     where.constructName = name->source;
   }
   AssignmentContext nested{*this, where};
-  for (const auto &x :
-      std::get<std::list<parser::WhereBodyConstruct>>(construct.t)) {
-    nested.Analyze(x);
-  }
-  for (const auto &x :
-      std::get<std::list<parser::WhereConstruct::MaskedElsewhere>>(
-          construct.t)) {
-    nested.Analyze(x);
-  }
-  if (const auto &x{std::get<std::optional<parser::WhereConstruct::Elsewhere>>(
-          construct.t)}) {
-    nested.Analyze(*x);
-  }
+  nested.Analyze(std::get<std::list<parser::WhereBodyConstruct>>(construct.t));
+  nested.Analyze(std::get<std::list<parser::WhereConstruct::MaskedElsewhere>>(
+      construct.t));
+  nested.Analyze(
+      std::get<std::optional<parser::WhereConstruct::Elsewhere>>(construct.t));
 }
 
 void AssignmentContext::Analyze(const parser::ForallStmt &stmt) {
@@ -460,11 +457,9 @@ void AssignmentContext::Analyze(const parser::ForallStmt &stmt) {
   AssignmentContext nested{*this, forall};
   nested.Analyze(
       std::get<common::Indirection<parser::ConcurrentHeader>>(stmt.t));
-  const auto &assign{
+  nested.Analyze(
       std::get<parser::UnlabeledStatement<parser::ForallAssignmentStmt>>(
-          stmt.t)};
-  nested.set_at(assign.source);
-  nested.Analyze(assign.statement);
+          stmt.t));
 }
 
 // N.B. Construct name matching is checked during label resolution;
@@ -473,31 +468,30 @@ void AssignmentContext::Analyze(const parser::ForallConstruct &construct) {
   CHECK(!where_);
   ForallContext forall{forall_};
   AssignmentContext nested{*this, forall};
-  const auto &forallStmt{
-      std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t)};
-  nested.set_at(forallStmt.source);
-  nested.Analyze(std::get<common::Indirection<parser::ConcurrentHeader>>(
-      forallStmt.statement.t));
-  for (const auto &body :
-      std::get<std::list<parser::ForallBodyConstruct>>(construct.t)) {
-    nested.Analyze(body.u);
-  }
+  nested.Analyze(
+      std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t));
+  nested.Analyze(std::get<std::list<parser::ForallBodyConstruct>>(construct.t));
+}
+
+void AssignmentContext::Analyze(const parser::ForallConstructStmt &stmt) {
+  Analyze(std::get<common::Indirection<parser::ConcurrentHeader>>(stmt.t));
 }
 
 void AssignmentContext::Analyze(
     const parser::WhereConstruct::MaskedElsewhere &elsewhere) {
   CHECK(where_);
-  const auto &elsewhereStmt{
-      std::get<parser::Statement<parser::MaskedElsewhereStmt>>(elsewhere.t)};
-  set_at(elsewhereStmt.source);
-  MaskExpr mask{
-      GetMask(std::get<parser::LogicalExpr>(elsewhereStmt.statement.t))};
+  Analyze(
+      std::get<parser::Statement<parser::MaskedElsewhereStmt>>(elsewhere.t));
+  Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
+}
+
+void AssignmentContext::Analyze(const parser::MaskedElsewhereStmt &elsewhere) {
+  MaskExpr mask{GetMask(std::get<parser::LogicalExpr>(elsewhere.t))};
   MaskExpr copyCumulative{where_->cumulativeMaskExpr};
   MaskExpr notOldMask{evaluate::LogicalNegation(std::move(copyCumulative))};
   if (!evaluate::AreConformable(notOldMask, mask)) {
-    Say(elsewhereStmt.source,
-        "mask of ELSEWHERE statement is not conformable with "
-        "the prior mask(s) in its WHERE construct"_err_en_US);
+    context_.Say("mask of ELSEWHERE statement is not conformable with "
+                 "the prior mask(s) in its WHERE construct"_err_en_US);
   }
   MaskExpr copyMask{mask};
   where_->cumulativeMaskExpr =
@@ -508,13 +502,8 @@ void AssignmentContext::Analyze(
   if (where_->outer &&
       !evaluate::AreConformable(
           where_->outer->thisMaskExpr, where_->thisMaskExpr)) {
-    Say(elsewhereStmt.source,
-        "effective mask of ELSEWHERE statement is not conformable "
-        "with the mask of the surrounding WHERE construct"_err_en_US);
-  }
-  for (const auto &x :
-      std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t)) {
-    Analyze(x);
+    context_.Say("effective mask of ELSEWHERE statement is not conformable "
+                 "with the mask of the surrounding WHERE construct"_err_en_US);
   }
 }
 
@@ -522,10 +511,7 @@ void AssignmentContext::Analyze(
     const parser::WhereConstruct::Elsewhere &elsewhere) {
   MaskExpr copyCumulative{DEREF(where_).cumulativeMaskExpr};
   where_->thisMaskExpr = evaluate::LogicalNegation(std::move(copyCumulative));
-  for (const auto &x :
-      std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t)) {
-    Analyze(x);
-  }
+  Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
 }
 
 void AssignmentContext::Analyze(const parser::ConcurrentHeader &header) {
@@ -556,7 +542,7 @@ int AssignmentContext::GetIntegerKind(
   if (auto value{evaluate::ToInt64(kind)}) {
     return static_cast<int>(*value);
   } else {
-    Say(at_, "Kind of INTEGER type must be a constant value"_err_en_US);
+    context_.Say("Kind of INTEGER type must be a constant value"_err_en_US);
     return context_.GetDefaultKind(TypeCategory::Integer);
   }
 }
@@ -565,7 +551,7 @@ void AssignmentContext::CheckForImpureCall(const SomeExpr &expr) {
   if (forall_) {
     const auto &intrinsics{context_.foldingContext().intrinsics()};
     if (auto bad{FindImpureCall(intrinsics, expr)}) {
-      Say(at_,
+      context_.Say(
           "Impure procedure '%s' may not be referenced in a FORALL"_err_en_US,
           *bad);
     }
@@ -646,7 +632,8 @@ void AssignmentContext::CheckForPureContext(const SomeExpr &lhs,
     const SomeExpr &rhs, parser::CharBlock source, bool isPointerAssignment) {
   const Scope &scope{context_.FindScope(source)};
   if (const Scope * pure{FindPureProcedureContaining(scope)}) {
-    parser::ContextualMessages messages{at_, &context_.messages()};
+    parser::ContextualMessages messages{
+        context_.location().value(), &context_.messages()};
     if (evaluate::ExtractCoarrayRef(lhs)) {
       messages.Say(
           "A pure subprogram may not define a coindexed object"_err_en_US);
@@ -675,7 +662,7 @@ void AssignmentContext::CheckForPureContext(const SomeExpr &lhs,
         // C1596 checks for polymorphic deallocation in a pure subprogram
         // due to automatic reallocation on assignment
         if (type->IsPolymorphic()) {
-          Say(at_,
+          context_.Say(
               "Deallocation of polymorphic object is not permitted in a pure subprogram"_err_en_US);
         }
         if (const DerivedTypeSpec * derived{GetDerivedTypeSpec(type)}) {
@@ -714,29 +701,24 @@ AssignmentChecker::~AssignmentChecker() {}
 AssignmentChecker::AssignmentChecker(SemanticsContext &context)
   : context_{new AssignmentContext{context}} {}
 void AssignmentChecker::Enter(const parser::AssignmentStmt &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::PointerAssignmentStmt &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::WhereStmt &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::WhereConstruct &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::ForallStmt &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::ForallConstruct &x) {
-  context_.value().set_at(at_);
   context_.value().Analyze(x);
 }
+
 }
 template class Fortran::common::Indirection<
     Fortran::semantics::AssignmentContext>;
index c1c36fb..bc8a16a 100644 (file)
@@ -59,9 +59,6 @@ class AssignmentChecker : public virtual BaseChecker {
 public:
   explicit AssignmentChecker(SemanticsContext &);
   ~AssignmentChecker();
-  template<typename A> void Enter(const parser::Statement<A> &stmt) {
-    at_ = stmt.source;
-  }
   void Enter(const parser::AssignmentStmt &);
   void Enter(const parser::PointerAssignmentStmt &);
   void Enter(const parser::WhereStmt &);
@@ -71,7 +68,6 @@ public:
 
 private:
   common::Indirection<AssignmentContext> context_;
-  parser::CharBlock at_;
 };
 
 // Semantic analysis of an assignment statement or WHERE/FORALL construct.