[flang] folding array constructors
authorpeter klausler <pklausler@nvidia.com>
Tue, 29 Jan 2019 00:56:50 +0000 (16:56 -0800)
committerpeter klausler <pklausler@nvidia.com>
Thu, 31 Jan 2019 17:59:33 +0000 (09:59 -0800)
Original-commit: flang-compiler/f18@a4e045fc5a913c7c3814820bf662205cc7687f1e
Reviewed-on: https://github.com/flang-compiler/f18/pull/271
Tree-same-pre-rewrite: false

flang/lib/evaluate/common.h
flang/lib/evaluate/fold.cc
flang/lib/semantics/resolve-names.cc

index 67268bc..aed918f 100644 (file)
@@ -22,6 +22,7 @@
 #include "../parser/char-block.h"
 #include "../parser/message.h"
 #include <cinttypes>
+#include <map>
 
 namespace Fortran::semantics {
 class DerivedTypeSpec;
@@ -199,19 +200,20 @@ struct FoldingContext {
     : messages{m}, rounding{round}, flushSubnormalsToZero{flush} {}
   FoldingContext(const FoldingContext &that)
     : messages{that.messages}, rounding{that.rounding},
-      flushSubnormalsToZero{that.flushSubnormalsToZero}, pdtInstance{
-                                                           that.pdtInstance} {}
+      flushDenormalsToZero{that.flushDenormalsToZero},
+      pdtInstance{that.pdtInstance}, impliedDos{that.impliedDos} {}
   FoldingContext(
       const FoldingContext &that, const parser::ContextualMessages &m)
     : messages{m}, rounding{that.rounding},
-      flushSubnormalsToZero{that.flushSubnormalsToZero}, pdtInstance{
-                                                           that.pdtInstance} {}
+      flushDenormalsToZero{that.flushDenormalsToZero},
+      pdtInstance{that.pdtInstance}, impliedDos{that.impliedDos} {}
 
   parser::ContextualMessages messages;
   Rounding rounding{defaultRounding};
   bool flushSubnormalsToZero{false};
   bool bigEndian{false};
   const semantics::DerivedTypeSpec *pdtInstance{nullptr};
+  std::map<parser::CharBlock, std::int64_t> impliedDos;
 };
 
 void RealFlagWarnings(FoldingContext &, const RealFlags &, const char *op);
index 32e9fe8..bfc6114 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "fold.h"
 #include "common.h"
+#include "constant.h"
 #include "expression.h"
 #include "int-power.h"
 #include "tools.h"
@@ -54,6 +55,8 @@ template<typename T> Expr<T> FoldOperation(FoldingContext &, Designator<T> &&);
 template<int KIND>
 Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(
     FoldingContext &, TypeParamInquiry<KIND> &&);
+template<typename T>
+Expr<T> FoldOperation(FoldingContext &, ArrayConstructor<T> &&);
 
 // Overloads, instantiations, and specializations of FoldOperation().
 
@@ -209,6 +212,100 @@ Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) {
       std::move(designator.u));
 }
 
+// Array constructor folding
+
+Expr<ImpliedDoIndex::Result> FoldOperation(
+    FoldingContext &context, ImpliedDoIndex &&iDo) {
+  auto iter{context.impliedDos.find(iDo.name)};
+  CHECK(iter != context.impliedDos.end());
+  return Expr<ImpliedDoIndex::Result>{iter->second};
+}
+
+template<typename T> class ArrayConstructorFolder {
+public:
+  explicit ArrayConstructorFolder(const FoldingContext &c) : context_{c} {
+    context_.impliedDos.clear();
+  }
+  Expr<T> FoldArray(ArrayConstructor<T> &&array) {
+    if (FoldArray(array.values)) {
+      std::int64_t n = elements_.size();
+      return Expr<T>{
+          Constant<T>{std::move(elements_), std::vector<std::int64_t>{n}}};
+    } else {
+      return Expr<T>{std::move(array)};
+    }
+  }
+
+private:
+  bool FoldArray(const CopyableIndirection<Expr<T>> &expr) {
+    Expr<T> folded{Fold(context_, Expr<T>{*expr})};
+    if (auto *c{UnwrapExpr<Constant<T>>(folded)}) {
+      // Copy elements in Fortran array element order
+      std::vector<std::int64_t> shape{c->shape()};
+      int rank{c->Rank()};
+      std::vector<std::int64_t> index(shape.size(), 1);
+      for (std::size_t n{c->size()}; n-- > 0;) {
+        elements_.push_back(c->At(index));
+        for (int d{0}; d < rank && ++index[d] <= shape[d]; ++d) {
+          index[d] = 1;
+        }
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+  bool FoldArray(const ImpliedDo<T> &iDo) {
+    Expr<SubscriptInteger> lower{
+        Fold(context_, Expr<SubscriptInteger>{*iDo.lower})};
+    Expr<SubscriptInteger> upper{
+        Fold(context_, Expr<SubscriptInteger>{*iDo.upper})};
+    Expr<SubscriptInteger> stride{
+        Fold(context_, Expr<SubscriptInteger>{*iDo.stride})};
+    std::optional<std::int64_t> start{ToInt64(lower)}, end{ToInt64(upper)},
+        step{ToInt64(stride)};
+    if (start.has_value() && end.has_value() && step.has_value()) {
+      auto pair{context_.impliedDos.insert(
+          std::make_pair(iDo.controlVariableName, *start))};
+      CHECK(pair.second);
+      bool result{true};
+      for (std::int64_t &j{pair.first->second}; j <= *end; j += *step) {
+        result &= FoldArray(*iDo.values);
+      }
+      context_.impliedDos.erase(pair.first);
+      return result;
+    } else {
+      return false;
+    }
+  }
+  bool FoldArray(const ArrayConstructorValue<T> &x) {
+    return std::visit([&](const auto &y) { return FoldArray(y); }, x.u);
+  }
+  bool FoldArray(const ArrayConstructorValues<T> &xs) {
+    for (const auto &x : xs.values) {
+      if (!FoldArray(x)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  FoldingContext context_;
+  std::vector<Scalar<T>> elements_;
+};
+
+template<typename T>
+Expr<T> FoldOperation(FoldingContext &context, ArrayConstructor<T> &&array) {
+  ArrayConstructorFolder<T> folder{context};
+  return folder.FoldArray(std::move(array));
+}
+
+Expr<SomeDerived> FoldOperation(
+    FoldingContext &context, ArrayConstructor<SomeDerived> &&array) {
+  // TODO pmk: derived type array constructor folding (no Scalar<T> to use)
+  return Expr<SomeDerived>{std::move(array)};
+}
+
 // Substitute a bare type parameter reference with its value if it has one now
 template<int KIND>
 Expr<Type<TypeCategory::Integer, KIND>> FoldOperation(
index 4cd025c..f21e194 100644 (file)
@@ -694,7 +694,8 @@ protected:
   // Declare a statement entity (e.g., an implied DO loop index).
   // If there isn't a type specified, implicit rules apply.
   // Return pointer to the new symbol, or nullptr on error.
-  Symbol *DeclareStatementEntity(const parser::Name &);
+  Symbol *DeclareStatementEntity(
+      const parser::Name &, const std::optional<parser::IntegerTypeSpec> &);
   bool CheckUseError(const parser::Name &);
   void CheckAccessibility(const parser::Name &, bool, const Symbol &);
 
@@ -3003,7 +3004,8 @@ Symbol *DeclarationVisitor::DeclareConstructEntity(const parser::Name &name) {
   return &symbol;
 }
 
-Symbol *DeclarationVisitor::DeclareStatementEntity(const parser::Name &name) {
+Symbol *DeclarationVisitor::DeclareStatementEntity(const parser::Name &name,
+    const std::optional<parser::IntegerTypeSpec> &type) {
   if (auto *prev{FindSymbol(name)}) {
     if (prev->owner() == currScope()) {
       SayAlreadyDeclared(name, *prev);
@@ -3013,8 +3015,15 @@ Symbol *DeclarationVisitor::DeclareStatementEntity(const parser::Name &name) {
   }
   Symbol &symbol{DeclareEntity<ObjectEntityDetails>(name, {})};
   if (symbol.has<ObjectEntityDetails>()) {
-    if (auto *type{GetDeclTypeSpec()}) {
-      SetType(name, *type);
+    const DeclTypeSpec *declTypeSpec{nullptr};
+    if (type.has_value()) {
+      BeginDeclTypeSpec();
+      DeclarationVisitor::Post(*type);
+      declTypeSpec = GetDeclTypeSpec();
+      EndDeclTypeSpec();
+    }
+    if (declTypeSpec != nullptr) {
+      SetType(name, *declTypeSpec);
     } else {
       ApplyImplicitRules(symbol);
     }
@@ -3229,16 +3238,9 @@ bool ConstructVisitor::Pre(const parser::AcImpliedDo &x) {
   auto &control{std::get<parser::AcImpliedDoControl>(x.t)};
   auto &type{std::get<std::optional<parser::IntegerTypeSpec>>(control.t)};
   auto &bounds{std::get<parser::LoopBounds<parser::ScalarIntExpr>>(control.t)};
-  if (type) {
-    BeginDeclTypeSpec();
-    DeclarationVisitor::Post(*type);
-  }
-  if (auto *symbol{DeclareStatementEntity(bounds.name.thing.thing)}) {
+  if (auto *symbol{DeclareStatementEntity(bounds.name.thing.thing, type)}) {
     CheckScalarIntegerType(*symbol);
   }
-  if (type) {
-    EndDeclTypeSpec();
-  }
   Walk(bounds);
   Walk(values);
   return false;
@@ -3249,16 +3251,9 @@ bool ConstructVisitor::Pre(const parser::DataImpliedDo &x) {
   auto &type{std::get<std::optional<parser::IntegerTypeSpec>>(x.t)};
   auto &bounds{
       std::get<parser::LoopBounds<parser::ScalarIntConstantExpr>>(x.t)};
-  if (type) {
-    BeginDeclTypeSpec();
-    DeclarationVisitor::Post(*type);
-  }
-  if (auto *symbol{DeclareStatementEntity(bounds.name.thing.thing)}) {
+  if (auto *symbol{DeclareStatementEntity(bounds.name.thing.thing, type)}) {
     CheckScalarIntegerType(*symbol);
   }
-  if (type) {
-    EndDeclTypeSpec();
-  }
   Walk(bounds);
   Walk(objects);
   return false;