[flang] Add more checks on WHERE and FORALL
authorTim Keith <tkeith@nvidia.com>
Thu, 20 Feb 2020 22:54:46 +0000 (14:54 -0800)
committerTim Keith <tkeith@nvidia.com>
Fri, 21 Feb 2020 23:47:01 +0000 (15:47 -0800)
Check that masks and LHS of assignments in WHERE statements and
constructs have consistent shapes. They must all have the same rank and
any extents that are compile-time constants must match.

Emit a warning for assignments in FORALL statements and constructs where
the LHS does not reference each of the index variables.

Original-commit: flang-compiler/f18@8b04dbebcf5621cfd571a8c45878cebcd1a1bfb0
Reviewed-on: https://github.com/flang-compiler/f18/pull/1009

flang/include/flang/semantics/semantics.h
flang/lib/semantics/assignment.cpp
flang/lib/semantics/assignment.h
flang/lib/semantics/check-do-forall.cpp
flang/lib/semantics/semantics.cpp
flang/test/semantics/assign01.f90
flang/test/semantics/forall01.f90

index b13f617..0e64e42 100644 (file)
@@ -159,10 +159,13 @@ public:
   void CheckIndexVarRedefine(const parser::Name &);
   void ActivateIndexVar(const parser::Name &, IndexVarKind);
   void DeactivateIndexVar(const parser::Name &);
+  SymbolVector GetIndexVars(IndexVarKind);
 
 private:
   void CheckIndexVarRedefine(
       const parser::CharBlock &, const Symbol &, parser::MessageFixedText &&);
+  bool CheckError(bool);
+
   const common::IntrinsicTypeDefaultKinds &defaultKinds_;
   const common::LanguageFeatureControl languageFeatures_;
   parser::AllSources &allSources_;
@@ -176,8 +179,6 @@ private:
   Scope globalScope_;
   parser::Messages messages_;
   evaluate::FoldingContext foldingContext_;
-
-  bool CheckError(bool);
   ConstructStack constructStack_;
   struct IndexVarInfo {
     parser::CharBlock location;
index aee651e..b286f65 100644 (file)
@@ -29,194 +29,62 @@ using namespace Fortran::parser::literals;
 
 namespace Fortran::semantics {
 
-using ControlExpr = evaluate::Expr<evaluate::SubscriptInteger>;
-using MaskExpr = evaluate::Expr<evaluate::LogicalResult>;
-
-// The context tracks some number of active FORALL statements/constructs
-// and some number of active WHERE statements/constructs.  WHERE can nest
-// in FORALL but not vice versa.  Pointer assignments are allowed in
-// FORALL but not in WHERE.  These constraints are manifest in the grammar
-// and don't need to be rechecked here, since errors cannot appear in the
-// parse tree.
-struct Control {
-  Symbol *name;
-  ControlExpr lower, upper, step;
-};
-
-struct ForallContext {
-  explicit ForallContext(const ForallContext *that) : outer{that} {}
-
-  const ForallContext *outer{nullptr};
-  std::optional<parser::CharBlock> constructName;
-  std::vector<Control> control;
-  std::optional<MaskExpr> maskExpr;
-  std::set<parser::CharBlock> activeNames;
-};
-
-struct WhereContext {
-  WhereContext(MaskExpr &&x, const WhereContext *o, const ForallContext *f)
-    : outer{o}, forall{f}, thisMaskExpr{std::move(x)} {}
-  const WhereContext *outer{nullptr};
-  const ForallContext *forall{nullptr};  // innermost enclosing FORALL
-  std::optional<parser::CharBlock> constructName;
-  MaskExpr thisMaskExpr;  // independent of outer WHERE, if any
-  MaskExpr cumulativeMaskExpr{thisMaskExpr};
-};
-
 class AssignmentContext {
 public:
-  explicit AssignmentContext(SemanticsContext &c) : context_{c} {}
-  AssignmentContext(const AssignmentContext &c, WhereContext &w)
-    : context_{c.context_}, where_{&w} {}
-  AssignmentContext(const AssignmentContext &c, ForallContext &f)
-    : context_{c.context_}, forall_{&f} {}
-
+  explicit AssignmentContext(SemanticsContext &context) : context_{context} {}
+  AssignmentContext(AssignmentContext &&) = default;
+  AssignmentContext(const AssignmentContext &) = delete;
   bool operator==(const AssignmentContext &x) const { return this == &x; }
 
+  template<typename A> void PushWhereContext(const A &);
+  void PopWhereContext();
   void Analyze(const parser::AssignmentStmt &);
   void Analyze(const parser::PointerAssignmentStmt &);
-  void Analyze(const parser::WhereStmt &);
-  void Analyze(const parser::WhereConstruct &);
-  void Analyze(const parser::ForallConstruct &);
-
-  template<typename A> void Analyze(const parser::UnlabeledStatement<A> &stmt) {
-    context_.set_location(stmt.source);
-    Analyze(stmt.statement);
-  }
-  template<typename A> void Analyze(const common::Indirection<A> &x) {
-    Analyze(x.value());
-  }
-  template<typename A> std::enable_if_t<UnionTrait<A>> Analyze(const A &x) {
-    std::visit([&](const auto &y) { Analyze(y); }, x.u);
-  }
-  template<typename A> void Analyze(const std::list<A> &list) {
-    for (const auto &elem : list) {
-      Analyze(elem);
-    }
-  }
-  template<typename A> void Analyze(const std::optional<A> &x) {
-    if (x) {
-      Analyze(*x);
-    }
-  }
+  void Analyze(const parser::ConcurrentControl &);
 
 private:
-  void Analyze(const parser::WhereConstruct::MaskedElsewhere &);
-  void Analyze(const parser::MaskedElsewhereStmt &);
-  void Analyze(const parser::WhereConstruct::Elsewhere &);
-
   void CheckForPureContext(const SomeExpr &lhs, const SomeExpr &rhs,
       parser::CharBlock rhsSource, bool isPointerAssignment);
-
-  MaskExpr GetMask(const parser::LogicalExpr &, bool defaultValue = true);
-
+  void CheckShape(parser::CharBlock, const SomeExpr *);
   template<typename... A>
   parser::Message *Say(parser::CharBlock at, A &&... args) {
     return &context_.Say(at, std::forward<A>(args)...);
   }
+  evaluate::FoldingContext &foldingContext() {
+    return context_.foldingContext();
+  }
 
   SemanticsContext &context_;
-  WhereContext *where_{nullptr};
-  ForallContext *forall_{nullptr};
+  int whereDepth_{0};  // number of WHEREs currently nested in
+  // shape of masks in LHS of assignments in current WHERE:
+  std::vector<std::optional<std::int64_t>> whereExtents_;
 };
 
 void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
-  // Assignment statement analysis is in expression.cpp where user-defined
-  // assignments can be recognized and replaced.
   if (const evaluate::Assignment * assignment{GetAssignment(stmt)}) {
-    if (forall_) {
-      // TODO: Warn if some name in forall_->activeNames or its outer
-      // contexts does not appear on LHS
+    const SomeExpr &lhs{assignment->lhs};
+    const SomeExpr &rhs{assignment->rhs};
+    auto lhsLoc{std::get<parser::Variable>(stmt.t).GetSource()};
+    auto rhsLoc{std::get<parser::Expr>(stmt.t).source};
+    if (whereDepth_ > 0) {
+      CheckShape(lhsLoc, &lhs);
     }
-    CheckForPureContext(assignment->lhs, assignment->rhs,
-        std::get<parser::Expr>(stmt.t).source, false /* not => */);
+    CheckForPureContext(lhs, rhs, rhsLoc, false);
   }
-  // TODO: Fortran 2003 ALLOCATABLE assignment semantics (automatic
-  // (re)allocation of LHS array when unallocated or nonconformable)
 }
 
 void AssignmentContext::Analyze(const parser::PointerAssignmentStmt &stmt) {
-  CHECK(!where_);
-  const evaluate::Assignment *assignment{GetAssignment(stmt)};
-  if (!assignment) {
-    return;
-  }
-  const SomeExpr &lhs{assignment->lhs};
-  const SomeExpr &rhs{assignment->rhs};
-  if (forall_) {
-    // TODO: Warn if some name in forall_->activeNames or its outer
-    // contexts does not appear on LHS
-  }
-  CheckForPureContext(lhs, rhs, std::get<parser::Expr>(stmt.t).source,
-      true /* isPointerAssignment */);
-  auto restorer{context_.foldingContext().messages().SetLocation(
-      context_.location().value())};
-  CheckPointerAssignment(context_.foldingContext(), *assignment);
-}
-
-void AssignmentContext::Analyze(const parser::WhereStmt &stmt) {
-  WhereContext where{
-      GetMask(std::get<parser::LogicalExpr>(stmt.t)), where_, forall_};
-  AssignmentContext nested{*this, where};
-  nested.Analyze(std::get<parser::AssignmentStmt>(stmt.t));
-}
-
-// N.B. Construct name matching is checked during label resolution.
-void AssignmentContext::Analyze(const parser::WhereConstruct &construct) {
-  const auto &whereStmt{
-      std::get<parser::Statement<parser::WhereConstructStmt>>(construct.t)};
-  WhereContext where{
-      GetMask(std::get<parser::LogicalExpr>(whereStmt.statement.t)), where_,
-      forall_};
-  if (const auto &name{
-          std::get<std::optional<parser::Name>>(whereStmt.statement.t)}) {
-    where.constructName = name->source;
-  }
-  AssignmentContext nested{*this, where};
-  nested.Analyze(std::get<std::list<parser::WhereBodyConstruct>>(construct.t));
-  nested.Analyze(std::get<std::list<parser::WhereConstruct::MaskedElsewhere>>(
-      construct.t));
-  nested.Analyze(
-      std::get<std::optional<parser::WhereConstruct::Elsewhere>>(construct.t));
-}
-
-void AssignmentContext::Analyze(
-    const parser::WhereConstruct::MaskedElsewhere &elsewhere) {
-  CHECK(where_);
-  Analyze(
-      std::get<parser::Statement<parser::MaskedElsewhereStmt>>(elsewhere.t));
-  Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
-}
-
-void AssignmentContext::Analyze(const parser::MaskedElsewhereStmt &elsewhere) {
-  MaskExpr mask{GetMask(std::get<parser::LogicalExpr>(elsewhere.t))};
-  MaskExpr copyCumulative{where_->cumulativeMaskExpr};
-  MaskExpr notOldMask{evaluate::LogicalNegation(std::move(copyCumulative))};
-  if (!evaluate::AreConformable(notOldMask, mask)) {
-    context_.Say("mask of ELSEWHERE statement is not conformable with "
-                 "the prior mask(s) in its WHERE construct"_err_en_US);
-  }
-  MaskExpr copyMask{mask};
-  where_->cumulativeMaskExpr =
-      evaluate::BinaryLogicalOperation(evaluate::LogicalOperator::Or,
-          std::move(where_->cumulativeMaskExpr), std::move(copyMask));
-  where_->thisMaskExpr = evaluate::BinaryLogicalOperation(
-      evaluate::LogicalOperator::And, std::move(notOldMask), std::move(mask));
-  if (where_->outer &&
-      !evaluate::AreConformable(
-          where_->outer->thisMaskExpr, where_->thisMaskExpr)) {
-    context_.Say("effective mask of ELSEWHERE statement is not conformable "
-                 "with the mask of the surrounding WHERE construct"_err_en_US);
+  CHECK(whereDepth_ == 0);
+  if (const evaluate::Assignment * assignment{GetAssignment(stmt)}) {
+    const SomeExpr &lhs{assignment->lhs};
+    const SomeExpr &rhs{assignment->rhs};
+    CheckForPureContext(lhs, rhs, std::get<parser::Expr>(stmt.t).source, true);
+    auto restorer{
+        foldingContext().messages().SetLocation(context_.location().value())};
+    CheckPointerAssignment(foldingContext(), *assignment);
   }
 }
 
-void AssignmentContext::Analyze(
-    const parser::WhereConstruct::Elsewhere &elsewhere) {
-  MaskExpr copyCumulative{DEREF(where_).cumulativeMaskExpr};
-  where_->thisMaskExpr = evaluate::LogicalNegation(std::move(copyCumulative));
-  Analyze(std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t));
-}
-
 // C1594 checks
 static bool IsPointerDummyOfPureFunction(const Symbol &x) {
   return IsPointerDummy(x) && FindPureProcedureContaining(x.owner()) &&
@@ -333,14 +201,45 @@ void AssignmentContext::CheckForPureContext(const SomeExpr &lhs,
   }
 }
 
-MaskExpr AssignmentContext::GetMask(
-    const parser::LogicalExpr &logicalExpr, bool defaultValue) {
-  MaskExpr mask{defaultValue};
-  if (const SomeExpr * expr{GetExpr(logicalExpr)}) {
-    auto *logical{std::get_if<evaluate::Expr<evaluate::SomeLogical>>(&expr->u)};
-    mask = evaluate::ConvertTo(mask, common::Clone(DEREF(logical)));
+// 10.2.3.1(2) The masks and LHS of assignments must all have the same shape
+void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) {
+  if (auto shape{evaluate::GetShape(foldingContext(), expr)}) {
+    std::size_t size{shape->size()};
+    if (whereDepth_ == 0) {
+      whereExtents_.resize(size);
+    } else if (whereExtents_.size() != size) {
+      Say(at,
+          "Must have rank %zd to match prior mask or assignment of"
+          " WHERE construct"_err_en_US,
+          whereExtents_.size());
+      return;
+    }
+    for (std::size_t i{0}; i < size; ++i) {
+      if (std::optional<std::int64_t> extent{evaluate::ToInt64((*shape)[i])}) {
+        if (!whereExtents_[i]) {
+          whereExtents_[i] = *extent;
+        } else if (*whereExtents_[i] != *extent) {
+          Say(at,
+              "Dimension %d must have extent %jd to match prior mask or"
+              " assignment of WHERE construct"_err_en_US,
+              i + 1, static_cast<std::intmax_t>(*whereExtents_[i]));
+        }
+      }
+    }
+  }
+}
+
+template<typename A> void AssignmentContext::PushWhereContext(const A &x) {
+  const auto &expr{std::get<parser::LogicalExpr>(x.t)};
+  CheckShape(expr.thing.value().source, GetExpr(expr));
+  ++whereDepth_;
+}
+
+void AssignmentContext::PopWhereContext() {
+  --whereDepth_;
+  if (whereDepth_ == 0) {
+    whereExtents_.clear();
   }
-  return mask;
 }
 
 AssignmentChecker::~AssignmentChecker() {}
@@ -354,10 +253,22 @@ void AssignmentChecker::Enter(const parser::PointerAssignmentStmt &x) {
   context_.value().Analyze(x);
 }
 void AssignmentChecker::Enter(const parser::WhereStmt &x) {
-  context_.value().Analyze(x);
+  context_.value().PushWhereContext(x);
 }
-void AssignmentChecker::Enter(const parser::WhereConstruct &x) {
-  context_.value().Analyze(x);
+void AssignmentChecker::Leave(const parser::WhereStmt &) {
+  context_.value().PopWhereContext();
+}
+void AssignmentChecker::Enter(const parser::WhereConstructStmt &x) {
+  context_.value().PushWhereContext(x);
+}
+void AssignmentChecker::Leave(const parser::EndWhereStmt &) {
+  context_.value().PopWhereContext();
+}
+void AssignmentChecker::Enter(const parser::MaskedElsewhereStmt &x) {
+  context_.value().PushWhereContext(x);
+}
+void AssignmentChecker::Leave(const parser::MaskedElsewhereStmt &) {
+  context_.value().PopWhereContext();
 }
 
 }
index d86bd45..51b7c17 100644 (file)
 namespace Fortran::parser {
 class ContextualMessages;
 struct AssignmentStmt;
+struct EndWhereStmt;
+struct MaskedElsewhereStmt;
 struct PointerAssignmentStmt;
+struct WhereConstructStmt;
 struct WhereStmt;
-struct WhereConstruct;
 }
 
 namespace Fortran::semantics {
@@ -41,7 +43,11 @@ public:
   void Enter(const parser::AssignmentStmt &);
   void Enter(const parser::PointerAssignmentStmt &);
   void Enter(const parser::WhereStmt &);
-  void Enter(const parser::WhereConstruct &);
+  void Leave(const parser::WhereStmt &);
+  void Enter(const parser::WhereConstructStmt &);
+  void Leave(const parser::EndWhereStmt &);
+  void Enter(const parser::MaskedElsewhereStmt &);
+  void Leave(const parser::MaskedElsewhereStmt &);
 
 private:
   common::Indirection<AssignmentContext> context_;
index 1b1abfd..071a873 100644 (file)
@@ -452,6 +452,7 @@ public:
         common::visitors{[&](const auto &x) { return GetAssignment(x); }},
         stmt.u)};
     if (assignment) {
+      CheckForallIndexesUsed(*assignment);
       CheckForImpureCall(assignment->lhs);
       CheckForImpureCall(assignment->rhs);
       if (const auto *proc{
@@ -753,6 +754,38 @@ private:
     }
   }
 
+  // Each index should be used on the LHS of each assignment in a FORALL
+  void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
+    SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
+    if (!indexVars.empty()) {
+      SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
+      std::visit(
+          common::visitors{
+              [&](const evaluate::Assignment::BoundsSpec &spec) {
+                for (const auto &bound : spec) {
+                  symbols.merge(evaluate::CollectSymbols(bound));
+                }
+              },
+              [&](const evaluate::Assignment::BoundsRemapping &remapping) {
+                for (const auto &bounds : remapping) {
+                  symbols.merge(evaluate::CollectSymbols(bounds.first));
+                  symbols.merge(evaluate::CollectSymbols(bounds.second));
+                }
+              },
+              [](const auto &) {},
+          },
+          assignment.u);
+      for (const Symbol &index : indexVars) {
+        if (symbols.count(index) == 0) {
+          context_.Say(
+              "Warning: FORALL index variable '%s' not used on left-hand side"
+              " of assignment"_en_US,
+              index.name());
+        }
+      }
+    }
+  }
+
   // For messages where the DO loop must be DO CONCURRENT, make that explicit.
   const char *LoopKindName() const {
     return kind_ == IndexVarKind::DO ? "DO CONCURRENT" : "FORALL";
index d64ba63..16d2eba 100644 (file)
@@ -123,7 +123,8 @@ static bool PerformStatementSemantics(
   RewriteParseTree(context, program);
   CheckDeclarations(context);
   StatementSemanticsPass1{context}.Walk(program);
-  return StatementSemanticsPass2{context}.Walk(program);
+  StatementSemanticsPass2{context}.Walk(program);
+  return !context.AnyFatalError();
 }
 
 SemanticsContext::SemanticsContext(
@@ -262,6 +263,16 @@ void SemanticsContext::DeactivateIndexVar(const parser::Name &name) {
   }
 }
 
+SymbolVector SemanticsContext::GetIndexVars(IndexVarKind kind) {
+  SymbolVector result;
+  for (const auto &[symbol, info] : activeIndexVars_) {
+    if (info.kind == kind) {
+      result.push_back(symbol);
+    }
+  }
+  return result;
+}
+
 bool Semantics::Perform() {
   return ValidateLabels(context_, program_) &&
       parser::CanonicalizeDo(program_) &&  // force line break
index c2ab99c..b125da8 100644 (file)
@@ -1,14 +1,53 @@
-integer :: a1(10), a2(10)
-logical :: m1(10), m2(5,5)
-m1 = .true.
-m2 = .false.
-a1 = [((i),i=1,10)]
-where (m1)
-  a2 = 1
-!ERROR: mask of ELSEWHERE statement is not conformable with the prior mask(s) in its WHERE construct
-elsewhere (m2)
-  a2 = 2
-elsewhere
-  a2 = 3
-end where
+! 10.2.3.1(2) All masks and LHS of assignments in a WHERE must conform
+
+subroutine s1
+  integer :: a1(10), a2(10)
+  logical :: m1(10), m2(5,5)
+  m1 = .true.
+  m2 = .false.
+  a1 = [((i),i=1,10)]
+  where (m1)
+    a2 = 1
+  !ERROR: Must have rank 1 to match prior mask or assignment of WHERE construct
+  elsewhere (m2)
+    a2 = 2
+  elsewhere
+    a2 = 3
+  end where
+end
+
+subroutine s2
+  logical, allocatable :: m1(:), m4(:,:)
+  logical :: m2(2), m3(3)
+  where(m1)
+    where(m2)
+    end where
+    !ERROR: Dimension 1 must have extent 2 to match prior mask or assignment of WHERE construct
+    where(m3)
+    end where
+    !ERROR: Must have rank 1 to match prior mask or assignment of WHERE construct
+    where(m4)
+    end where
+  endwhere
+  where(m1)
+    where(m3)
+    end where
+  !ERROR: Dimension 1 must have extent 3 to match prior mask or assignment of WHERE construct
+  elsewhere(m2)
+  end where
+end
+
+subroutine s3
+  logical, allocatable :: m1(:,:)
+  logical :: m2(4,2)
+  real :: x(4,4), y(4,4)
+  real :: a(4,5), b(4,5)
+  where(m1)
+    x = y
+    !ERROR: Dimension 2 must have extent 4 to match prior mask or assignment of WHERE construct
+    a = b
+    !ERROR: Dimension 2 must have extent 4 to match prior mask or assignment of WHERE construct
+    where(m2)
+    end where
+  end where
 end
index bd665e2..e90a17f 100644 (file)
@@ -16,7 +16,6 @@ subroutine forall1
   end forall
 end
 
-
 subroutine forall2
   integer, pointer :: a(:)
   integer, target :: b(10,10)
@@ -73,3 +72,34 @@ subroutine forall4
   !ERROR: FORALL step expression may not be zero
   forall(i=1:10:zero) a(i) = i
 end
+
+! Note: this gets warnings but not errors
+subroutine forall5
+  real, target :: x(10), y(10)
+  forall(i=1:10)
+    x(i) = y(i)
+  end forall
+  forall(i=1:10)
+    x = y  ! warning: i not used on LHS
+    forall(j=1:10)
+      x(i) = y(i)  ! warning: j not used on LHS
+      x(j) = y(j)  ! warning: i not used on LHS
+    endforall
+  endforall
+  do concurrent(i=1:10)
+    x = y
+    forall(i=1:10) x = y
+  end do
+end
+
+subroutine forall6
+  type t
+    real, pointer :: p
+  end type
+  type(t) :: a(10)
+  real, target :: b(10)
+  forall(i=1:10)
+    a(i)%p => b(i)
+    a(1)%p => b(i)  ! warning: i not used on LHS
+  end forall
+end