[flang] Avoid crashing from recursion on very tall expression parse trees
authorPeter Klausler <pklausler@nvidia.com>
Thu, 19 Jan 2023 22:32:09 +0000 (14:32 -0800)
committerPeter Klausler <pklausler@nvidia.com>
Wed, 1 Feb 2023 22:09:07 +0000 (14:09 -0800)
In the parse tree visitation framework (Parser/parse-tree-visitor.h)
and in the semantic analyzer for expressions (Semantics/expression.cpp)
avoid crashing due to stack size limitations by using an iterative
traversal algorithm rather than straightforward recursive tree walking.
The iterative approach is the obvious one of building a work queue and
using it to (in the case of the parse tree visitor) call the visitor
object's Pre() and Post() routines on subexpressions in the same order
as they would have been called during a recursive traversal.

This change helps the compiler survive some artificial stress tests
and perhaps with future exposure to machine-generated source code.

Differential Revision: https://reviews.llvm.org/D142771

flang/include/flang/Parser/parse-tree-visitor.h
flang/include/flang/Semantics/expression.h
flang/lib/Semantics/expression.cpp
flang/test/Evaluate/big-expr-tree.F90 [new file with mode: 0644]

index 4e749d3..75466e6 100644 (file)
@@ -16,6 +16,7 @@
 #include <tuple>
 #include <utility>
 #include <variant>
+#include <vector>
 
 /// Parse tree visitor
 /// Call Walk(x, visitor) to visit x and, by default, each node under x.
@@ -483,20 +484,76 @@ template <typename M> void Walk(CommonStmt &x, M &mutator) {
     mutator.Post(x);
   }
 }
+
+// Expr traversal uses iteration rather than recursion to avoid
+// blowing out the stack on very deep expression parse trees.
+// It replaces implementations that looked like:
+//   template <typename V> void Walk(const Expr &x, V visitor) {
+//     if (visitor.Pre(x)) {      // Pre on the Expr
+//       Walk(x.source, visitor);
+//       // Pre on the operator, walk the operands, Post on operator
+//       Walk(x.u, visitor);
+//       visitor.Post(x);         // Post on the Expr
+//     }
+//   }
+template <typename A, typename V, typename UNARY, typename BINARY>
+static void IterativeWalk(A &start, V &visitor) {
+  struct ExprWorkList {
+    ExprWorkList(A &x) : expr(&x) {}
+    bool doPostExpr{false}, doPostOpr{false};
+    A *expr;
+  };
+  std::vector<ExprWorkList> stack;
+  stack.emplace_back(start);
+  do {
+    A &expr{*stack.back().expr};
+    if (stack.back().doPostOpr) {
+      stack.back().doPostOpr = false;
+      common::visit([&visitor](auto &y) { visitor.Post(y); }, expr.u);
+    } else if (stack.back().doPostExpr) {
+      visitor.Post(expr);
+      stack.pop_back();
+    } else if (!visitor.Pre(expr)) {
+      stack.pop_back();
+    } else {
+      stack.back().doPostExpr = true;
+      Walk(expr.source, visitor);
+      UNARY *unary{nullptr};
+      BINARY *binary{nullptr};
+      common::visit(
+          [&unary, &binary](auto &y) {
+            if constexpr (std::is_convertible_v<decltype(&y), UNARY *>) {
+              unary = &y;
+            } else if constexpr (std::is_convertible_v<decltype(&y),
+                                     BINARY *>) {
+              binary = &y;
+            }
+          },
+          expr.u);
+      if (!unary && !binary) {
+        Walk(expr.u, visitor);
+      } else if (common::visit(
+                     [&visitor](auto &y) { return visitor.Pre(y); }, expr.u)) {
+        stack.back().doPostOpr = true;
+        if (unary) {
+          stack.emplace_back(unary->v.value());
+        } else {
+          stack.emplace_back(std::get<1>(binary->t).value());
+          stack.emplace_back(std::get<0>(binary->t).value());
+        }
+      }
+    }
+  } while (!stack.empty());
+}
 template <typename V> void Walk(const Expr &x, V &visitor) {
-  if (visitor.Pre(x)) {
-    Walk(x.source, visitor);
-    Walk(x.u, visitor);
-    visitor.Post(x);
-  }
+  IterativeWalk<const Expr, V, const Expr::IntrinsicUnary,
+      const Expr::IntrinsicBinary>(x, visitor);
 }
 template <typename M> void Walk(Expr &x, M &mutator) {
-  if (mutator.Pre(x)) {
-    Walk(x.source, mutator);
-    Walk(x.u, mutator);
-    mutator.Post(x);
-  }
+  IterativeWalk<Expr, M, Expr::IntrinsicUnary, Expr::IntrinsicBinary>(
+      x, mutator);
 }
+
 template <typename V> void Walk(const Designator &x, V &visitor) {
   if (visitor.Pre(x)) {
     Walk(x.source, visitor);
index 1e56dde..e8c313b 100644 (file)
@@ -381,6 +381,8 @@ private:
   bool CheckIsValidForwardReference(const semantics::DerivedTypeSpec &);
   MaybeExpr AnalyzeComplex(MaybeExpr &&re, MaybeExpr &&im, const char *what);
 
+  MaybeExpr IterativelyAnalyzeSubexpressions(const parser::Expr &);
+
   semantics::SemanticsContext &context_;
   FoldingContext &foldingContext_{context_.foldingContext()};
   std::map<parser::CharBlock, int> impliedDos_; // values are INTEGER kinds
@@ -391,6 +393,7 @@ private:
   bool inDataStmtObject_{false};
   bool inDataStmtConstant_{false};
   bool inStmtFunctionDefinition_{false};
+  bool iterativelyAnalyzingSubexpressions_{false};
   friend class ArgumentAnalyzer;
 };
 
index af6cce3..b61b97a 100644 (file)
@@ -29,6 +29,7 @@
 #include <functional>
 #include <optional>
 #include <set>
+#include <vector>
 
 // Typedef for optional generic expressions (ubiquitous in this file)
 using MaybeExpr =
@@ -3326,6 +3327,12 @@ MaybeExpr ExpressionAnalyzer::ExprOrVariable(
     result = Analyze(x.u);
   }
   if (result) {
+    if constexpr (std::is_same_v<PARSED, parser::Expr>) {
+      if (!isNullPointerOk_ && IsNullPointer(*result)) {
+        Say(source,
+            "NULL() may not be used as an expression in this context"_err_en_US);
+      }
+    }
     SetExpr(x, Fold(std::move(*result)));
     return x.typedExpr->v;
   } else {
@@ -3341,15 +3348,76 @@ MaybeExpr ExpressionAnalyzer::ExprOrVariable(
   }
 }
 
+// This is an optional preliminary pass over parser::Expr subtrees.
+// Given an expression tree, iteratively traverse it in a bottom-up order
+// to analyze all of its subexpressions.  A later normal top-down analysis
+// will then be able to use the results that will have been saved in the
+// parse tree without having to recurse deeply.  This technique keeps
+// absurdly deep expression parse trees from causing the analyzer to overflow
+// its stack.
+MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions(
+    const parser::Expr &top) {
+  std::vector<const parser::Expr *> queue, finish;
+  queue.push_back(&top);
+  do {
+    const parser::Expr &expr{*queue.back()};
+    queue.pop_back();
+    if (!expr.typedExpr) {
+      const parser::Expr::IntrinsicUnary *unary{nullptr};
+      const parser::Expr::IntrinsicBinary *binary{nullptr};
+      common::visit(
+          [&unary, &binary](auto &y) {
+            if constexpr (std::is_convertible_v<decltype(&y),
+                              decltype(unary)>) {
+              // Don't evaluate a constant operand to Negate
+              if (!std::holds_alternative<parser::LiteralConstant>(
+                      y.v.value().u)) {
+                unary = &y;
+              }
+            } else if constexpr (std::is_convertible_v<decltype(&y),
+                                     decltype(binary)>) {
+              binary = &y;
+            }
+          },
+          expr.u);
+      if (unary) {
+        queue.push_back(&unary->v.value());
+      } else if (binary) {
+        queue.push_back(&std::get<0>(binary->t).value());
+        queue.push_back(&std::get<1>(binary->t).value());
+      }
+      finish.push_back(&expr);
+    }
+  } while (!queue.empty());
+  // Analyze the collected subexpressions in bottom-up order.
+  // On an error, bail out and leave partial results in place.
+  MaybeExpr result;
+  for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
+    const parser::Expr &expr{**riter};
+    result = ExprOrVariable(expr, expr.source);
+    if (!result) {
+      return result;
+    }
+  }
+  return result; // last value was from analysis of "top"
+}
+
 MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) {
-  if (useSavedTypedExprs_ && expr.typedExpr) {
-    return expr.typedExpr->v;
+  bool wasIterativelyAnalyzing{iterativelyAnalyzingSubexpressions_};
+  MaybeExpr result;
+  if (useSavedTypedExprs_) {
+    if (expr.typedExpr) {
+      return expr.typedExpr->v;
+    }
+    if (!wasIterativelyAnalyzing) {
+      iterativelyAnalyzingSubexpressions_ = true;
+      result = IterativelyAnalyzeSubexpressions(expr);
+    }
   }
-  MaybeExpr result{ExprOrVariable(expr, expr.source)};
-  if (!isNullPointerOk_ && result && IsNullPointer(*result)) {
-    Say(expr.source,
-        "NULL() may not be used as an expression in this context"_err_en_US);
+  if (!result) {
+    result = ExprOrVariable(expr, expr.source);
   }
+  iterativelyAnalyzingSubexpressions_ = wasIterativelyAnalyzing;
   return result;
 }
 
@@ -4017,7 +4085,7 @@ std::optional<ActualArgument> ArgumentAnalyzer::AnalyzeExpr(
     const parser::Expr &expr) {
   source_.ExtendToCover(expr.source);
   if (const Symbol *assumedTypeDummy{AssumedTypeDummy(expr)}) {
-    expr.typedExpr.Reset(new GenericExprWrapper{}, GenericExprWrapper::Deleter);
+    ResetExpr(expr);
     if (isProcedureCall_) {
       ActualArgument arg{ActualArgument::AssumedType{*assumedTypeDummy}};
       SetArgSourceLocation(arg, expr.source);
diff --git a/flang/test/Evaluate/big-expr-tree.F90 b/flang/test/Evaluate/big-expr-tree.F90
new file mode 100644 (file)
index 0000000..feaa298
--- /dev/null
@@ -0,0 +1,8 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Exercise parsing, expression analysis, and folding on a very tall expression tree
+! 32*32 = 1024 repetitions
+#define M0(x) x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x
+#define M1(x) x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x
+module m
+  logical, parameter :: test_1 = 32**2 .EQ. M1(M0(1))
+end module