[flang] build up expressions implied by DO loop construct
authorEric Schweitz <eschweitz@nvidia.com>
Wed, 20 Mar 2019 23:26:07 +0000 (16:26 -0700)
committerEric <eschweitz@nvidia.com>
Sat, 23 Mar 2019 18:14:20 +0000 (11:14 -0700)
Original-commit: flang-compiler/f18@1e7b9adb62666223662b5445f9303d82be432299
Reviewed-on: https://github.com/flang-compiler/f18/pull/354
Tree-same-pre-rewrite: false

flang/lib/FIR/afforestation.cc
flang/lib/FIR/builder.h
flang/lib/FIR/flattened.cc
flang/lib/FIR/statements.h

index 7522fb6..115f8d3 100644 (file)
@@ -210,14 +210,14 @@ const parser::Format *FindReadWriteFormat(
 }
 
 static Expression AlwaysTrueExpression() {
-  using T = evaluate::Type<evaluate::TypeCategory::Logical, 1>;
-  return {evaluate::AsGenericExpr(evaluate::Constant<T>{true})};
+  using A = evaluate::Type<evaluate::TypeCategory::Logical, 1>;
+  return {evaluate::AsGenericExpr(evaluate::Constant<A>{true})};
 }
 
 // create an integer constant as an expression
 static Expression CreateConstant(int64_t value) {
-  using T = evaluate::SubscriptInteger;
-  return {evaluate::AsGenericExpr(evaluate::Constant<T>{value})};
+  using A = evaluate::SubscriptInteger;
+  return {evaluate::AsGenericExpr(evaluate::Constant<A>{value})};
 }
 
 static void CreateSwitchHelper(FIRBuilder *builder, Value condition,
@@ -237,6 +237,10 @@ static void CreateSwitchTypeHelper(FIRBuilder *builder, Value condition,
   builder->CreateSwitchType(condition, rest);
 }
 
+static Expression getApplyExpr(Statement *s) {
+  return GetApplyExpr(s)->expression();
+}
+
 class FortranIRLowering {
 public:
   using LabelMapType = std::map<flat::LabelRef, BasicBlock *>;
@@ -276,20 +280,33 @@ public:
 
   Program *program() { return fir_; }
 
+  // convert a parse tree data reference to an Expression
+  template<typename A> Expression ToExpression(const A &a) {
+    return {std::move(semantics::AnalyzeExpr(semanticsContext_, a).value())};
+  }
+
+  // build a simple arithmetic Expression
+  template<template<typename> class OPR>
+  Expression ConsExpr(Expression e1, Expression e2) {
+    evaluate::ExpressionAnalyzer context{semanticsContext_};
+    ConformabilityCheck(context.GetContextualMessages(), e1, e2);
+    return evaluate::NumericOperation<OPR>(context.GetContextualMessages(),
+        std::move(e1), std::move(e2),
+        context.GetDefaultKind(common::TypeCategory::Real))
+        .value();
+  }
+
   template<typename T>
   void ProcessRoutine(const T &here, const std::string &name) {
     CHECK(!fir_->containsProcedure(name));
     auto *subp{fir_->getOrInsertProcedure(name, nullptr, {})};
     builder_ = new FIRBuilder(*CreateBlock(subp->getLastRegion()));
     AnalysisData ad;
-#if 0
-    ControlFlowAnalyzer linearize{linearOperations_, ad};
-    Walk(here, linearize);
-#else
     CreateFlatIR(here, linearOperations_, ad);
-#endif
     if (debugLinearFIR_) {
+      DebugChannel() << "define @" << name << "(...) {\n";
       dump(linearOperations_);
+      DebugChannel() << "}\n";
     }
     ConstructFIR(ad);
     DrawRemainingArcs();
@@ -325,7 +342,7 @@ public:
     if (remap) {
       return remap;
     }
-    return builder_->CreateAddr(DataRefToExpression(dataRef));
+    return builder_->CreateAddr(ToExpression(dataRef));
   }
   Type CreateAllocationValue(const parser::Allocation *allocation,
       const parser::AllocateStmt *statement) {
@@ -464,7 +481,7 @@ public:
               return builder_->CreateExpr(ExprRef(e));
             },
             [&](const parser::Variable &v) {
-              return builder_->CreateExpr(VariableToExpression(v));
+              return builder_->CreateExpr(ToExpression(v));
             },
         },
         std::get<parser::Selector>(
@@ -555,29 +572,11 @@ public:
     return result;
   }
 
-  Expression VariableToExpression(const parser::Variable &var) {
-    evaluate::ExpressionAnalyzer analyzer{semanticsContext_};
-    return {std::move(analyzer.Analyze(var).value())};
-  }
-  Expression DataRefToExpression(const parser::DataRef &dr) {
-    evaluate::ExpressionAnalyzer analyzer{semanticsContext_};
-    return {std::move(analyzer.Analyze(dr).value())};
-  }
-  Expression NameToExpression(const parser::Name &name) {
-    evaluate::ExpressionAnalyzer analyzer{semanticsContext_};
-    return {std::move(analyzer.Analyze(name).value())};
-  }
-  Expression StructureComponentToExpression(
-      const parser::StructureComponent &sc) {
-    evaluate::ExpressionAnalyzer analyzer{semanticsContext_};
-    return {std::move(analyzer.Analyze(sc).value())};
-  }
-
   void handleIntrinsicAssignmentStmt(const parser::AssignmentStmt &stmt) {
     // TODO: check if allocation or reallocation should happen, etc.
     auto *value{builder_->CreateExpr(ExprRef(std::get<parser::Expr>(stmt.t)))};
-    auto *addr{builder_->CreateAddr(
-        VariableToExpression(std::get<parser::Variable>(stmt.t)))};
+    auto *addr{
+        builder_->CreateAddr(ToExpression(std::get<parser::Variable>(stmt.t)))};
     builder_->CreateStore(addr, value);
   }
   void handleDefinedAssignmentStmt(const parser::AssignmentStmt &stmt) {
@@ -614,10 +613,10 @@ public:
                 std::visit(
                     common::visitors{
                         [&](const parser::StatVariable &sv) {
-                          opts.stat = VariableToExpression(sv.v.thing.thing);
+                          opts.stat = ToExpression(sv.v.thing.thing);
                         },
                         [&](const parser::MsgVariable &mv) {
-                          opts.errmsg = VariableToExpression(mv.v.thing.thing);
+                          opts.errmsg = ToExpression(mv.v.thing.thing);
                         },
                     },
                     var.u);
@@ -719,12 +718,11 @@ public:
                 std::visit(
                     common::visitors{
                         [&](const parser::Name &n) {
-                          auto *s{builder_->CreateAddr(NameToExpression(n))};
+                          auto *s{builder_->CreateAddr(ToExpression(n))};
                           builder_->CreateNullify(s);
                         },
                         [&](const parser::StructureComponent &sc) {
-                          auto *s{builder_->CreateAddr(
-                              StructureComponentToExpression(sc))};
+                          auto *s{builder_->CreateAddr(ToExpression(sc))};
                           builder_->CreateNullify(s);
                         },
                     },
@@ -801,7 +799,7 @@ public:
             },
             [&](const common::Indirection<parser::AssignStmt> &s) {
               auto *addr{builder_->CreateAddr(
-                  NameToExpression(std::get<parser::Name>(s.value().t)))};
+                  ToExpression(std::get<parser::Name>(s.value().t)))};
               auto *block{blockMap_
                               .find(flat::FetchLabel(
                                   ad, std::get<parser::Label>(s.value().t))
@@ -871,14 +869,12 @@ public:
       auto &selector{std::get<parser::Selector>(assoc.t)};
       auto *expr{builder_->CreateExpr(std::visit(
           common::visitors{
-              [&](const parser::Variable &v) {
-                return VariableToExpression(v);
-              },
+              [&](const parser::Variable &v) { return ToExpression(v); },
               [](const parser::Expr &e) { return *ExprRef(e); },
           },
           selector.u))};
-      auto *name{builder_->CreateAddr(
-          NameToExpression(std::get<parser::Name>(assoc.t)))};
+      auto *name{
+          builder_->CreateAddr(ToExpression(std::get<parser::Name>(assoc.t)))};
       builder_->CreateStore(name, expr);
     }
   }
@@ -908,8 +904,8 @@ public:
       std::visit(
           common::visitors{
               [&](const parser::LoopBounds<parser::ScalarIntExpr> &bounds) {
-                auto *var = builder_->CreateAddr(
-                    NameToExpression(bounds.name.thing.thing));
+                auto *name{builder_->CreateAddr(
+                    ToExpression(bounds.name.thing.thing))};
                 // evaluate e1, e2 [, e3] ...
                 auto *e1{
                     builder_->CreateExpr(ExprRef(bounds.lower.thing.thing))};
@@ -921,8 +917,20 @@ public:
                 } else {
                   e3 = builder_->CreateExpr(CreateConstant(1));
                 }
-                builder_->CreateStore(var, e1);
-                PushDoContext(stmt, var, e1, e2, e3);
+                // name <- e1
+                builder_->CreateStore(name, e1);
+                auto *tripCounter{builder_->CreateLocal(nullptr)};
+                // totalTrips ::= iteration count = a
+                //   where a = (e2 - e1 + e3) / e3 if a > 0 and 0 otherwise
+                Expression tripExpr{ConsExpr<evaluate::Divide>(
+                    ConsExpr<evaluate::Add>(
+                        ConsExpr<evaluate::Subtract>(
+                            getApplyExpr(e2), getApplyExpr(e1)),
+                        getApplyExpr(e3)),
+                    getApplyExpr(e3))};
+                auto *totalTrips{builder_->CreateExpr(&tripExpr)};
+                builder_->CreateStore(tripCounter, totalTrips);
+                PushDoContext(stmt, name, nullptr, tripCounter, e3);
               },
               [&](const parser::ScalarLogicalExpr &whileExpr) {
                 // FIXME
index 15db40b..7a65c88 100644 (file)
@@ -44,6 +44,12 @@ struct FIRBuilder {
 
   BasicBlock *GetInsertionPoint() const { return cursorBlock_; }
 
+  Statement *CreateAddr(const Expression *e) {
+    return Insert(LocateExprStmt::Create(e));
+  }
+  Statement *CreateAddr(Expression &&e) {
+    return Insert(LocateExprStmt::Create(std::move(e)));
+  }
   Statement *CreateAlloc(Type type) {
     return Insert(AllocateInsn::Create(type));
   }
@@ -61,6 +67,9 @@ struct FIRBuilder {
   Statement *CreateDealloc(AllocateInsn *alloc) {
     return Insert(DeallocateInsn::Create(alloc));
   }
+  Statement *CreateDoCondition(Statement *dir, Statement *v1, Statement *v2) {
+    return Insert(DoConditionStmt::Create(dir, v1, v2));
+  }
   Statement *CreateExpr(const Expression *e) {
     return Insert(ApplyExprStmt::Create(e));
   }
@@ -70,32 +79,21 @@ struct FIRBuilder {
   ApplyExprStmt *MakeAsExpr(const Expression *e) {
     return GetApplyExpr(CreateExpr(e));
   }
-  Statement *CreateAddr(const Expression *e) {
-    return Insert(LocateExprStmt::Create(e));
-  }
-  Statement *CreateAddr(Expression &&e) {
-    return Insert(LocateExprStmt::Create(std::move(e)));
-  }
-  Statement *CreateLoad(Statement *addr) {
-    return Insert(LoadInsn::Create(addr));
-  }
-  Statement *CreateStore(Statement *addr, Statement *value) {
-    return Insert(StoreInsn::Create(addr, value));
-  }
-  Statement *CreateStore(Statement *addr, BasicBlock *value) {
-    return Insert(StoreInsn::Create(addr, value));
-  }
   Statement *CreateIncrement(Statement *v1, Statement *v2) {
     return Insert(IncrementStmt::Create(v1, v2));
   }
-  Statement *CreateDoCondition(Statement *dir, Statement *v1, Statement *v2) {
-    return Insert(DoConditionStmt::Create(dir, v1, v2));
+  Statement *CreateIndirectBr(Variable *v, const std::vector<BasicBlock *> &p) {
+    return InsertTerminator(IndirectBranchStmt::Create(v, p));
   }
   Statement *CreateIOCall(InputOutputCallType c, IOCallArguments &&a) {
     return Insert(IORuntimeStmt::Create(c, std::move(a)));
   }
-  Statement *CreateIndirectBr(Variable *v, const std::vector<BasicBlock *> &p) {
-    return InsertTerminator(IndirectBranchStmt::Create(v, p));
+  Statement *CreateLoad(Statement *addr) {
+    return Insert(LoadInsn::Create(addr));
+  }
+  Statement *CreateLocal(
+      Type type, int alignment = 0, Expression *expr = nullptr) {
+    return Insert(AllocateLocalInsn::Create(type, alignment, expr));
   }
   Statement *CreateNullify(Statement *s) {
     return Insert(DisassociateInsn::Create(s));
@@ -107,6 +105,12 @@ struct FIRBuilder {
       RuntimeCallType call, RuntimeCallArguments &&arguments) {
     return Insert(RuntimeStmt::Create(call, std::move(arguments)));
   }
+  Statement *CreateStore(Statement *addr, Statement *value) {
+    return Insert(StoreInsn::Create(addr, value));
+  }
+  Statement *CreateStore(Statement *addr, BasicBlock *value) {
+    return Insert(StoreInsn::Create(addr, value));
+  }
   Statement *CreateSwitch(
       Value cond, const SwitchStmt::ValueSuccPairListType &pairs) {
     return InsertTerminator(SwitchStmt::Create(cond, pairs));
index 6f3f54a..8e03d88 100644 (file)
@@ -271,12 +271,11 @@ template<typename A, typename B> std::string GetSource(const B *s) {
 void LabelOp::dump() const { DebugChannel() << "label_" << get() << ":\n"; }
 
 void GotoOp::dump() const {
-  DebugChannel() << "\tgoto " << target << " ["
-                 << std::visit(
-                        common::visitors{
-                            [](ArtificialJump) { return ""s; },
-                            [&](const auto *) { return GetSource(this); },
-                        },
+  DebugChannel() << "\tgoto %label_" << target << " ["
+                 << std::visit(common::visitors{
+                                   [](ArtificialJump) { return ""s; },
+                                   [&](auto *) { return GetSource(this); },
+                               },
                         u)
                  << "]\n";
 }
@@ -286,12 +285,12 @@ void ReturnOp::dump() const {
 }
 
 void ConditionalGotoOp::dump() const {
-  DebugChannel() << "\tcbranch .T.:" << trueLabel << " .F.:" << falseLabel
-                 << " ["
+  DebugChannel() << "\tcbranch .T.: %label_" << trueLabel << " .F.: %label_"
+                 << falseLabel << " ["
                  << std::visit(
                         common::visitors{
-                            [](const auto *s) { return GetSource(s); },
                             [](const parser::IfStmt *) { return "if-stmt"s; },
+                            [&](auto *s) { return GetSource(s); },
                         },
                         u)
                  << "]\n";
@@ -300,13 +299,13 @@ void ConditionalGotoOp::dump() const {
 void SwitchIOOp::dump() const {
   DebugChannel() << "\tio-call";
   if (errLabel.has_value()) {
-    DebugChannel() << " ERR:" << errLabel.value();
+    DebugChannel() << " ERR: %label_" << errLabel.value();
   }
   if (eorLabel.has_value()) {
-    DebugChannel() << " EOR:" << eorLabel.value();
+    DebugChannel() << " EOR: %label_" << eorLabel.value();
   }
   if (endLabel.has_value()) {
-    DebugChannel() << " END:" << endLabel.value();
+    DebugChannel() << " END: %label_" << endLabel.value();
   }
   DebugChannel() << " [" << GetSource(this) << "]\n";
 }
@@ -317,108 +316,59 @@ void SwitchOp::dump() const {
 
 void ActionOp::dump() const { DebugChannel() << '\t' << GetSource(v) << '\n'; }
 
+template<typename A> std::string dumpConstruct(const A &a) {
+  return std::visit(
+      common::visitors{
+          [](const parser::AssociateConstruct *c) {
+            return GetSource<parser::AssociateStmt>(c);
+          },
+          [](const parser::BlockConstruct *c) {
+            return GetSource<parser::BlockStmt>(c);
+          },
+          [](const parser::CaseConstruct *c) {
+            return GetSource<parser::SelectCaseStmt>(c);
+          },
+          [](const parser::ChangeTeamConstruct *c) {
+            return GetSource<parser::ChangeTeamStmt>(c);
+          },
+          [](const parser::CriticalConstruct *c) {
+            return GetSource<parser::CriticalStmt>(c);
+          },
+          [](const parser::DoConstruct *c) {
+            return GetSource<parser::NonLabelDoStmt>(c);
+          },
+          [](const parser::IfConstruct *c) {
+            return GetSource<parser::IfThenStmt>(c);
+          },
+          [](const parser::SelectRankConstruct *c) {
+            return GetSource<parser::SelectRankStmt>(c);
+          },
+          [](const parser::SelectTypeConstruct *c) {
+            return GetSource<parser::SelectTypeStmt>(c);
+          },
+          [](const parser::WhereConstruct *c) {
+            return GetSource<parser::WhereConstructStmt>(c);
+          },
+          [](const parser::ForallConstruct *c) {
+            return GetSource(
+                &std::get<parser::Statement<parser::ForallConstructStmt>>(
+                    c->t));
+          },
+          [](const parser::CompilerDirective *c) { return GetSource(c); },
+          [](const parser::OpenMPConstruct *) { return "openmp"s; },
+          [](const parser::OpenMPEndLoopDirective *) {
+            return "openmp end loop"s;
+          },
+      },
+      a);
+}
+
 void BeginOp::dump() const {
-  DebugChannel()
-      << "\t["
-      << std::visit(
-             common::visitors{
-                 [](const parser::AssociateConstruct *c) {
-                   return GetSource<parser::AssociateStmt>(c);
-                 },
-                 [](const parser::BlockConstruct *c) {
-                   return GetSource<parser::BlockStmt>(c);
-                 },
-                 [](const parser::CaseConstruct *c) {
-                   return GetSource<parser::SelectCaseStmt>(c);
-                 },
-                 [](const parser::ChangeTeamConstruct *c) {
-                   return GetSource<parser::ChangeTeamStmt>(c);
-                 },
-                 [](const parser::CriticalConstruct *c) {
-                   return GetSource<parser::CriticalStmt>(c);
-                 },
-                 [](const parser::DoConstruct *c) {
-                   return GetSource<parser::NonLabelDoStmt>(c);
-                 },
-                 [](const parser::IfConstruct *c) {
-                   return GetSource<parser::IfThenStmt>(c);
-                 },
-                 [](const parser::SelectRankConstruct *c) {
-                   return GetSource<parser::SelectRankStmt>(c);
-                 },
-                 [](const parser::SelectTypeConstruct *c) {
-                   return GetSource<parser::SelectTypeStmt>(c);
-                 },
-                 [](const parser::WhereConstruct *c) {
-                   return GetSource<parser::WhereConstructStmt>(c);
-                 },
-                 [](const parser::ForallConstruct *c) {
-                   return GetSource(
-                       &std::get<
-                           parser::Statement<parser::ForallConstructStmt>>(
-                           c->t));
-                 },
-                 [](const parser::CompilerDirective *c) {
-                   return GetSource(c);
-                 },
-                 [](const parser::OpenMPConstruct *) { return "openmp"s; },
-                 [](const parser::OpenMPEndLoopDirective *) {
-                   return "openmp end loop"s;
-                 },
-             },
-             u)
-      << "] :{\n";
+  DebugChannel() << "\t[" << dumpConstruct(u) << "] {\n";
 }
 
 void EndOp::dump() const {
-  DebugChannel() << "\t}: ["
-                 << std::visit(
-                        common::visitors{
-                            [](const parser::AssociateConstruct *c) {
-                              return GetSource<parser::EndAssociateStmt>(c);
-                            },
-                            [](const parser::BlockConstruct *c) {
-                              return GetSource<parser::EndBlockStmt>(c);
-                            },
-                            [](const parser::CaseConstruct *c) {
-                              return GetSource<parser::EndSelectStmt>(c);
-                            },
-                            [](const parser::ChangeTeamConstruct *c) {
-                              return GetSource<parser::EndChangeTeamStmt>(c);
-                            },
-                            [](const parser::CriticalConstruct *c) {
-                              return GetSource<parser::EndCriticalStmt>(c);
-                            },
-                            [](const parser::DoConstruct *c) {
-                              return GetSource<parser::EndDoStmt>(c);
-                            },
-                            [](const parser::IfConstruct *c) {
-                              return GetSource<parser::EndIfStmt>(c);
-                            },
-                            [](const parser::SelectRankConstruct *c) {
-                              return GetSource<parser::EndSelectStmt>(c);
-                            },
-                            [](const parser::SelectTypeConstruct *c) {
-                              return GetSource<parser::EndSelectStmt>(c);
-                            },
-                            [](const parser::WhereConstruct *c) {
-                              return GetSource<parser::EndWhereStmt>(c);
-                            },
-                            [](const parser::ForallConstruct *c) {
-                              return GetSource<parser::EndForallStmt>(c);
-                            },
-                            [](const parser::CompilerDirective *c) {
-                              return GetSource(c);
-                            },
-                            [](const parser::OpenMPConstruct *) {
-                              return "openmp"s;
-                            },
-                            [](const parser::OpenMPEndLoopDirective *) {
-                              return "openmp end loop"s;
-                            },
-                        },
-                        u)
-                 << "]\n";
+  DebugChannel() << "\t} [" << dumpConstruct(u) << "]\n";
 }
 
 void IndirectGotoOp::dump() const {
index e8c340b..40cc0aa 100644 (file)
@@ -296,7 +296,7 @@ public:
   static ApplyExprStmt Create(const Expression *e) { return ApplyExprStmt{*e}; }
   static ApplyExprStmt Create(Expression &&e) { return ApplyExprStmt{e}; }
 
-  const Expression &expression() const { return expression_; }
+  Expression expression() const { return expression_; }
 
 private:
   explicit ApplyExprStmt(const Expression &e) : expression_{e} {}
@@ -369,7 +369,7 @@ private:
 class AllocateLocalInsn : public Addressable_impl, public MemoryStmt_impl {
 public:
   static AllocateLocalInsn Create(
-      Type type, int alignment = 0, const Expression *expr = nullptr) {
+      Type type, int alignment = 0, Expression *expr = nullptr) {
     if (expr != nullptr) {
       return AllocateLocalInsn{type, alignment, *expr};
     }