[flang] Rework expression constraint checking
authorpeter klausler <pklausler@nvidia.com>
Fri, 28 Dec 2018 23:58:17 +0000 (15:58 -0800)
committerpeter klausler <pklausler@nvidia.com>
Fri, 28 Dec 2018 23:58:17 +0000 (15:58 -0800)
Original-commit: flang-compiler/f18@7a31c1ed2ba3c57575d3bd98f6b13dfc94a07af0
Reviewed-on: https://github.com/flang-compiler/f18/pull/250
Tree-same-pre-rewrite: false

flang/lib/parser/parse-tree-visitor.h
flang/lib/semantics/dump-parse-tree.h
flang/lib/semantics/expression.cc
flang/lib/semantics/expression.h
flang/lib/semantics/resolve-names.cc

index 8f75975..7a9af80 100644 (file)
@@ -264,6 +264,7 @@ template<typename T, typename M> void Walk(DefaultChar<T> &x, M &mutator) {
 template<typename T, typename V> void Walk(const Statement<T> &x, V &visitor) {
   if (visitor.Pre(x)) {
     // N.B. the label is not traversed
+    Walk(x.source, visitor);
     Walk(x.statement, visitor);
     visitor.Post(x);
   }
@@ -271,6 +272,7 @@ template<typename T, typename V> void Walk(const Statement<T> &x, V &visitor) {
 template<typename T, typename M> void Walk(Statement<T> &x, M &mutator) {
   if (mutator.Pre(x)) {
     // N.B. the label is not traversed
+    Walk(x.source, mutator);
     Walk(x.statement, mutator);
     mutator.Post(x);
   }
@@ -278,11 +280,13 @@ template<typename T, typename M> void Walk(Statement<T> &x, M &mutator) {
 
 template<typename V> void Walk(const Name &x, V &visitor) {
   if (visitor.Pre(x)) {
+    Walk(x.source, visitor);
     visitor.Post(x);
   }
 }
 template<typename M> void Walk(Name &x, M &mutator) {
   if (mutator.Pre(x)) {
+    Walk(x.source, mutator);
     mutator.Post(x);
   }
 }
@@ -464,6 +468,20 @@ template<typename T, typename M> void Walk(LoopBounds<T> &x, M &mutator) {
     mutator.Post(x);
   }
 }
+template<typename V> void Walk(const Expr &x, V &visitor) {
+  if (visitor.Pre(x)) {
+    Walk(x.source, visitor);
+    Walk(x.u, visitor);
+    visitor.Post(x);
+  }
+}
+template<typename M> void Walk(Expr &x, M &mutator) {
+  if (mutator.Pre(x)) {
+    Walk(x.source, mutator);
+    Walk(x.u, mutator);
+    mutator.Post(x);
+  }
+}
 template<typename V> void Walk(const PartRef &x, V &visitor) {
   if (visitor.Pre(x)) {
     Walk(x.name, visitor);
@@ -498,6 +516,20 @@ template<typename M> void Walk(ReadStmt &x, M &mutator) {
     mutator.Post(x);
   }
 }
+template<typename V> void Walk(const SignedIntLiteralConstant &x, V &visitor) {
+  if (visitor.Pre(x)) {
+    Walk(x.source, visitor);
+    Walk(x.t, visitor);
+    visitor.Post(x);
+  }
+}
+template<typename M> void Walk(SignedIntLiteralConstant &x, M &mutator) {
+  if (mutator.Pre(x)) {
+    Walk(x.source, mutator);
+    Walk(x.t, mutator);
+    mutator.Post(x);
+  }
+}
 template<typename V> void Walk(const RealLiteralConstant &x, V &visitor) {
   if (visitor.Pre(x)) {
     Walk(x.real, visitor);
@@ -514,11 +546,13 @@ template<typename M> void Walk(RealLiteralConstant &x, M &mutator) {
 }
 template<typename V> void Walk(const RealLiteralConstant::Real &x, V &visitor) {
   if (visitor.Pre(x)) {
+    Walk(x.source, visitor);
     visitor.Post(x);
   }
 }
 template<typename M> void Walk(RealLiteralConstant::Real &x, M &mutator) {
   if (mutator.Pre(x)) {
+    Walk(x.source, mutator);
     mutator.Post(x);
   }
 }
@@ -694,6 +728,20 @@ void Walk(format::IntrinsicTypeDataEditDesc &x, M &mutator) {
     mutator.Post(x);
   }
 }
+template<typename V> void Walk(const CompilerDirective &x, V &visitor) {
+  if (visitor.Pre(x)) {
+    Walk(x.source, visitor);
+    Walk(x.u, visitor);
+    visitor.Post(x);
+  }
+}
+template<typename M> void Walk(CompilerDirective &x, M &mutator) {
+  if (mutator.Pre(x)) {
+    Walk(x.source, mutator);
+    Walk(x.u, mutator);
+    mutator.Post(x);
+  }
+}
 template<typename V>
 void Walk(const OmpLinearClause::WithModifier &x, V &visitor) {
   if (visitor.Pre(x)) {
index 2a0c083..1f8b99a 100644 (file)
@@ -779,6 +779,9 @@ public:
 
   // A few types we want to ignore
 
+  bool Pre(const parser::CharBlock &) { return true; }
+  void Post(const parser::CharBlock &) {}
+
   template<typename T> bool Pre(const parser::Statement<T> &) { return true; }
 
   template<typename T> void Post(const parser::Statement<T> &) {}
index 1013e32..c9f0058 100644 (file)
 #include <iostream>  // TODO pmk remove soon
 #include <optional>
 
-using namespace Fortran::parser::literals;
-
 // Typedef for optional generic expressions (ubiquitous in this file)
 using MaybeExpr =
     std::optional<Fortran::evaluate::Expr<Fortran::evaluate::SomeType>>;
 
+namespace Fortran::parser {
+bool SourceLocationFindingVisitor::Pre(const Expr &x) {
+  source = x.source;
+  return false;
+}
+void SourceLocationFindingVisitor::Post(const CharBlock &at) { source = at; }
+}
+
 // Much of the code that implements semantic analysis of expressions is
 // tightly coupled with their typed representations in lib/evaluate,
 // and appears here in namespace Fortran::evaluate for convenience.
@@ -40,61 +46,6 @@ namespace Fortran::evaluate {
 
 using common::TypeCategory;
 
-// Constraint checking
-void ExpressionAnalysisContext::CheckConstraints(MaybeExpr &expr) {
-  if (inner_ != nullptr) {
-    inner_->CheckConstraints(expr);
-  }
-  if (constraint_ != nullptr && expr.has_value()) {
-    if (!(this->*constraint_)(*expr)) {
-      expr.reset();
-    }
-  }
-}
-
-bool ExpressionAnalysisContext::ScalarConstraint(Expr<SomeType> &expr) {
-  int rank{expr.Rank()};
-  if (rank == 0) {
-    return true;
-  }
-  Say("expression must be scalar, but has rank %d"_err_en_US, rank);
-  return false;
-}
-
-bool ExpressionAnalysisContext::ConstantConstraint(Expr<SomeType> &expr) {
-  expr = Fold(context_.foldingContext(), std::move(expr));
-  if (IsConstant(expr)) {
-    return true;
-  }
-  Say("expression must be constant"_err_en_US);
-  return false;
-}
-
-bool ExpressionAnalysisContext::IntegerConstraint(Expr<SomeType> &expr) {
-  if (std::holds_alternative<Expr<SomeInteger>>(expr.u)) {
-    return true;
-  }
-  Say("expression must be INTEGER"_err_en_US);
-  return false;
-}
-
-bool ExpressionAnalysisContext::LogicalConstraint(Expr<SomeType> &expr) {
-  if (std::holds_alternative<Expr<SomeLogical>>(expr.u)) {
-    return true;
-  }
-  Say("expression must be LOGICAL"_err_en_US);
-  return false;
-}
-
-bool ExpressionAnalysisContext::DefaultCharConstraint(Expr<SomeType> &expr) {
-  if (auto *charExpr{std::get_if<Expr<SomeCharacter>>(&expr.u)}) {
-    return charExpr->GetKind() ==
-        context_.defaultKinds().GetDefaultKind(TypeCategory::Character);
-  }
-  Say("expression must be default CHARACTER"_err_en_US);
-  return false;
-}
-
 // If a generic expression simply wraps a DataRef, extract it.
 // TODO: put in tools.h?
 template<typename A> std::optional<DataRef> ExtractDataRef(A &&) {
@@ -1466,13 +1417,9 @@ MaybeExpr ExpressionAnalysisContext::Analyze(const parser::Expr &expr) {
     // Analyze the expression in a specified source position context for better
     // error reporting.
     auto save{context_.foldingContext().messages.SetLocation(expr.source)};
-    MaybeExpr result{AnalyzeExpr(*this, expr.u)};
-    CheckConstraints(result);
-    return result;
+    return AnalyzeExpr(*this, expr.u);
   } else {
-    MaybeExpr result{AnalyzeExpr(*this, expr.u)};
-    CheckConstraints(result);
-    return result;
+    return AnalyzeExpr(*this, expr.u);
   }
 }
 }
index fcace3f..639d734 100644 (file)
 #include "semantics.h"
 #include "../common/indirection.h"
 #include "../evaluate/expression.h"
+#include "../evaluate/tools.h"
 #include "../evaluate/type.h"
+#include "../parser/parse-tree-visitor.h"
+#include "../parser/parse-tree.h"
 #include <optional>
 #include <variant>
 
+using namespace Fortran::parser::literals;
+
 namespace Fortran::parser {
-struct Expr;
-struct Program;
-template<typename> struct Scalar;
-template<typename> struct Integer;
-template<typename> struct Constant;
-template<typename> struct Logical;
-template<typename> struct DefaultChar;
+struct SourceLocationFindingVisitor {
+  template<typename A> bool Pre(const A &) { return true; }
+  template<typename A> void Post(const A &) {}
+  bool Pre(const Expr &);
+  template<typename A> bool Pre(const Statement<A> &stmt) {
+    source = stmt.source;
+    return false;
+  }
+  void Post(const CharBlock &);
+
+  CharBlock source;
+};
+
+template<typename A> CharBlock FindSourceLocation(const A &x) {
+  SourceLocationFindingVisitor visitor;
+  Walk(x, visitor);
+  return visitor.source;
+}
 }
 
 // The expression semantic analysis code has its implementation in
@@ -45,21 +61,14 @@ template<typename> struct DefaultChar;
 // The ExpressionAnalysisContext wraps a SemanticsContext reference
 // and implements constraint checking on expressions using the
 // parse tree node wrappers that mirror the grammar annotations used
-// in the Fortran standard (i.e., scalar-, constant-, &c.).  These
-// constraint checks are performed in a deferred manner so that any
-// errors are reported on the most accurate source location available.
+// in the Fortran standard (i.e., scalar-, constant-, &c.).
 
 namespace Fortran::evaluate {
 class ExpressionAnalysisContext {
 public:
-  using ConstraintChecker = bool (ExpressionAnalysisContext::*)(
-      Expr<SomeType> &);
-
   ExpressionAnalysisContext(semantics::SemanticsContext &sc) : context_{sc} {}
   ExpressionAnalysisContext(ExpressionAnalysisContext &i)
     : context_{i.context_}, inner_{&i} {}
-  ExpressionAnalysisContext(ExpressionAnalysisContext &i, ConstraintChecker cc)
-    : context_{i.context_}, inner_{&i}, constraint_{cc} {}
 
   semantics::SemanticsContext &context() const { return context_; }
 
@@ -67,12 +76,10 @@ public:
     context_.foldingContext().messages.Say(std::forward<A>(args)...);
   }
 
-  void CheckConstraints(std::optional<Expr<SomeType>> &);
-  bool ScalarConstraint(Expr<SomeType> &);
-  bool ConstantConstraint(Expr<SomeType> &);
-  bool IntegerConstraint(Expr<SomeType> &);
-  bool LogicalConstraint(Expr<SomeType> &);
-  bool DefaultCharConstraint(Expr<SomeType> &);
+  template<typename T, typename... A> void SayAt(const T &parsed, A... args) {
+    context_.foldingContext().messages.Say(
+        parser::FindSourceLocation(parsed), std::forward<A>(args)...);
+  }
 
   std::optional<Expr<SomeType>> Analyze(const parser::Expr &);
 
@@ -81,7 +88,6 @@ protected:
 
 private:
   ExpressionAnalysisContext *inner_{nullptr};
-  ConstraintChecker constraint_{nullptr};
 };
 
 template<typename PARSED>
@@ -121,46 +127,72 @@ std::optional<Expr<SomeType>> AnalyzeExpr(
   return AnalyzeExpr(context, *x);
 }
 
-// These specializations create nested expression analysis contexts
-// to implement constraint checking.
+// These specializations implement constraint checking.
 
 template<typename A>
 std::optional<Expr<SomeType>> AnalyzeExpr(
-    ExpressionAnalysisContext &context, const parser::Scalar<A> &expr) {
-  ExpressionAnalysisContext withCheck{
-      context, &ExpressionAnalysisContext::ScalarConstraint};
-  return AnalyzeExpr(withCheck, expr.thing);
+    ExpressionAnalysisContext &context, const parser::Scalar<A> &x) {
+  auto result{AnalyzeExpr(context, x.thing)};
+  if (result.has_value()) {
+    if (int rank{result->Rank()}; rank != 0) {
+      context.SayAt(
+          x, "Must be a scalar value, but is a rank-%d array"_err_en_US);
+    }
+  }
+  return result;
 }
 
 template<typename A>
 std::optional<Expr<SomeType>> AnalyzeExpr(
-    ExpressionAnalysisContext &context, const parser::Constant<A> &expr) {
-  ExpressionAnalysisContext withCheck{
-      context, &ExpressionAnalysisContext::ConstantConstraint};
-  return AnalyzeExpr(withCheck, expr.thing);
+    ExpressionAnalysisContext &context, const parser::Constant<A> &x) {
+  auto result{AnalyzeExpr(context, x.thing)};
+  if (result.has_value()) {
+    *result = Fold(context.context().foldingContext(), std::move(*result));
+    if (!IsConstant(*result)) {
+      context.SayAt(x, "Must be a constant value"_err_en_US);
+    }
+  }
+  return result;
 }
 
 template<typename A>
 std::optional<Expr<SomeType>> AnalyzeExpr(
-    ExpressionAnalysisContext &context, const parser::Integer<A> &expr) {
-  ExpressionAnalysisContext withCheck{
-      context, &ExpressionAnalysisContext::IntegerConstraint};
-  return AnalyzeExpr(withCheck, expr.thing);
+    ExpressionAnalysisContext &context, const parser::Integer<A> &x) {
+  auto result{AnalyzeExpr(context, x.thing)};
+  if (result.has_value()) {
+    if (!std::holds_alternative<Expr<SomeInteger>>(result->u)) {
+      context.SayAt(x, "Must have INTEGER type"_err_en_US);
+    }
+  }
+  return result;
 }
 
 template<typename A>
 std::optional<Expr<SomeType>> AnalyzeExpr(
-    ExpressionAnalysisContext &context, const parser::Logical<A> &expr) {
-  ExpressionAnalysisContext withCheck{
-      context, &ExpressionAnalysisContext::LogicalConstraint};
-  return AnalyzeExpr(withCheck, expr.thing);
+    ExpressionAnalysisContext &context, const parser::Logical<A> &x) {
+  auto result{AnalyzeExpr(context, x.thing)};
+  if (result.has_value()) {
+    if (!std::holds_alternative<Expr<SomeLogical>>(result->u)) {
+      context.SayAt(x, "Must have LOGICAL type"_err_en_US);
+    }
+  }
+  return result;
 }
 template<typename A>
 std::optional<Expr<SomeType>> AnalyzeExpr(
-    ExpressionAnalysisContext &context, const parser::DefaultChar<A> &expr) {
-  ExpressionAnalysisContext withCheck{
-      context, &ExpressionAnalysisContext::DefaultCharConstraint};
-  return AnalyzeExpr(withCheck, expr.thing);
+    ExpressionAnalysisContext &context, const parser::DefaultChar<A> &x) {
+  auto result{AnalyzeExpr(context, x.thing)};
+  if (result.has_value()) {
+    if (auto *charExpr{std::get_if<Expr<SomeCharacter>>(&result->u)}) {
+      if (charExpr->GetKind() ==
+          context.context().defaultKinds().GetDefaultKind(
+              TypeCategory::Character)) {
+        return result;
+      }
+    }
+    context.SayAt(x, "Must have default CHARACTER type"_err_en_US);
+  }
+  return result;
 }
 }
 
index 6aea635..db35147 100644 (file)
@@ -1115,10 +1115,11 @@ int DeclTypeSpecVisitor::GetKindParamValue(
       common::visitors{
           [&](const parser::ScalarIntConstantExpr &x) -> int {
             if (auto maybeExpr{EvaluateExpr(x)}) {
-              return evaluate::ToInt64(*maybeExpr).value();
-            } else {
-              return 0;
+              if (auto intConst{evaluate::ToInt64(*maybeExpr)}) {
+                return *intConst;
+              }
             }
+            return 0;
           },
           [&](const parser::KindSelector::StarSize &x) -> int {
             std::uint64_t size{x.v};