[flang] remove the DO stub instructions
authorEric Schweitz <eschweitz@nvidia.com>
Thu, 21 Mar 2019 21:40:10 +0000 (14:40 -0700)
committerEric <eschweitz@nvidia.com>
Sat, 23 Mar 2019 18:14:20 +0000 (11:14 -0700)
Original-commit: flang-compiler/f18@f1ddcd8d76cce67aa988da8eb356c8d8d93c79a8
Reviewed-on: https://github.com/flang-compiler/f18/pull/354
Tree-same-pre-rewrite: false

flang/lib/FIR/afforestation.cc
flang/lib/FIR/builder.h
flang/lib/FIR/statements.cc
flang/lib/FIR/statements.h

index 115f8d3..864d89c 100644 (file)
@@ -240,6 +240,9 @@ static void CreateSwitchTypeHelper(FIRBuilder *builder, Value condition,
 static Expression getApplyExpr(Statement *s) {
   return GetApplyExpr(s)->expression();
 }
+static Expression getLocalVariable(Statement *s) {
+  return GetLocal(s)->variable();
+}
 
 class FortranIRLowering {
 public:
@@ -295,6 +298,13 @@ public:
         context.GetDefaultKind(common::TypeCategory::Real))
         .value();
   }
+  Expression ConsExpr(
+      common::RelationalOperator op, Expression e1, Expression e2) {
+    evaluate::ExpressionAnalyzer context{semanticsContext_};
+    return evaluate::AsGenericExpr(evaluate::Relate(
+        context.GetContextualMessages(), op, std::move(e1), std::move(e2))
+                                       .value());
+  }
 
   template<typename T>
   void ProcessRoutine(const T &here, const std::string &name) {
@@ -824,14 +834,13 @@ public:
   // DO loop handlers
   struct DoBoundsInfo {
     Statement *doVariable;
-    Statement *lowerBound;
-    Statement *upperBound;
+    Statement *counter;
     Statement *stepExpr;
     Statement *condition;
   };
   void PushDoContext(const parser::NonLabelDoStmt *doStmt, Statement *doVar,
-      Statement *lowBound, Statement *upBound, Statement *stepExp) {
-    doMap_.emplace(doStmt, DoBoundsInfo{doVar, lowBound, upBound, stepExp});
+      Statement *counter, Statement *stepExp) {
+    doMap_.emplace(doStmt, DoBoundsInfo{doVar, counter, stepExp});
   }
   void PopDoContext(const parser::NonLabelDoStmt *doStmt) {
     doMap_.erase(doStmt);
@@ -847,19 +856,24 @@ public:
     return nullptr;
   }
 
-  // do_var = do_var + e3
+  // evaluate: do_var = do_var + e3; counter--
   void handleLinearDoIncrement(const flat::DoIncrementOp &inc) {
     auto *info{GetBoundsInfo(inc)};
-    auto *var{builder_->CreateLoad(info->doVariable)};
-    builder_->CreateIncrement(var, info->stepExpr);
+    auto *incremented{builder_->CreateExpr(std::move(
+        ConsExpr<evaluate::Add>(GetAddressable(info->doVariable)->address(),
+            GetApplyExpr(info->stepExpr)->expression())))};
+    builder_->CreateStore(info->doVariable, incremented);
+    auto *decremented{builder_->CreateExpr(ConsExpr<evaluate::Subtract>(
+        GetAddressable(info->counter)->address(), CreateConstant(1)))};
+    builder_->CreateStore(info->counter, decremented);
   }
 
-  // (e3 > 0 && do_var <= e2) || (e3 < 0 && do_var >= e2)
+  // is (counter > 0)?
   void handleLinearDoCompare(const flat::DoCompareOp &cmp) {
     auto *info{GetBoundsInfo(cmp)};
-    auto *var{builder_->CreateLoad(info->doVariable)};
-    auto *cond{
-        builder_->CreateDoCondition(info->stepExpr, var, info->upperBound)};
+    Expression compare{ConsExpr(common::RelationalOperator::GT,
+        getLocalVariable(info->counter), CreateConstant(0))};
+    auto *cond{builder_->CreateExpr(&compare)};
     info->condition = cond;
   }
 
@@ -930,7 +944,7 @@ public:
                     getApplyExpr(e3))};
                 auto *totalTrips{builder_->CreateExpr(&tripExpr)};
                 builder_->CreateStore(tripCounter, totalTrips);
-                PushDoContext(stmt, name, nullptr, tripCounter, e3);
+                PushDoContext(stmt, name, tripCounter, e3);
               },
               [&](const parser::ScalarLogicalExpr &whileExpr) {
                 // FIXME
index 7a65c88..2362e8f 100644 (file)
@@ -67,9 +67,6 @@ struct FIRBuilder {
   Statement *CreateDealloc(AllocateInsn *alloc) {
     return Insert(DeallocateInsn::Create(alloc));
   }
-  Statement *CreateDoCondition(Statement *dir, Statement *v1, Statement *v2) {
-    return Insert(DoConditionStmt::Create(dir, v1, v2));
-  }
   Statement *CreateExpr(const Expression *e) {
     return Insert(ApplyExprStmt::Create(e));
   }
@@ -79,9 +76,6 @@ struct FIRBuilder {
   ApplyExprStmt *MakeAsExpr(const Expression *e) {
     return GetApplyExpr(CreateExpr(e));
   }
-  Statement *CreateIncrement(Statement *v1, Statement *v2) {
-    return Insert(IncrementStmt::Create(v1, v2));
-  }
   Statement *CreateIndirectBr(Variable *v, const std::vector<BasicBlock *> &p) {
     return InsertTerminator(IndirectBranchStmt::Create(v, p));
   }
index 39f8c5f..a56217b 100644 (file)
@@ -141,11 +141,6 @@ StoreInsn::StoreInsn(Statement *addr, BasicBlock *val)
   CHECK(val);
 }
 
-IncrementStmt::IncrementStmt(Value v1, Value v2) : value_{v1, v2} {}
-
-DoConditionStmt::DoConditionStmt(Value dir, Value v1, Value v2)
-  : value_{dir, v1, v2} {}
-
 std::string Statement::dump() const {
   return std::visit(
       common::visitors{
@@ -182,8 +177,6 @@ std::string Statement::dump() const {
           },
           [](const IndirectBranchStmt &) { return "ibranch"s; },
           [](const UnreachableStmt &) { return "unreachable"s; },
-          [](const IncrementStmt &) { return "increment"s; },
-          [](const DoConditionStmt &) { return "compare"s; },
           [](const ApplyExprStmt &e) { return FIR::dump(e.expression()); },
           [](const LocateExprStmt &e) {
             return "&" + FIR::dump(e.expression());
index 40cc0aa..d130a43 100644 (file)
@@ -28,8 +28,6 @@ class SwitchTypeStmt;
 class SwitchRankStmt;
 class IndirectBranchStmt;
 class UnreachableStmt;
-class IncrementStmt;
-class DoConditionStmt;
 class ApplyExprStmt;
 class LocateExprStmt;
 class AllocateInsn;
@@ -263,33 +261,6 @@ protected:
   std::optional<evaluate::DynamicType> type;
 };
 
-class IncrementStmt : public ActionStmt_impl {
-public:
-  static IncrementStmt Create(Value v1, Value v2) {
-    return IncrementStmt(v1, v2);
-  }
-  Value leftValue() const { return value_[0]; }
-  Value rightValue() const { return value_[1]; }
-
-private:
-  explicit IncrementStmt(Value v1, Value v2);
-  Value value_[2];
-};
-
-class DoConditionStmt : public ActionStmt_impl {
-public:
-  static DoConditionStmt Create(Value dir, Value left, Value right) {
-    return DoConditionStmt(dir, left, right);
-  }
-  Value direction() const { return value_[0]; }
-  Value leftValue() const { return value_[1]; }
-  Value rightValue() const { return value_[2]; }
-
-private:
-  explicit DoConditionStmt(Value dir, Value left, Value right);
-  Value value_[3];
-};
-
 // Compute the value of an expression
 class ApplyExprStmt : public ActionStmt_impl {
 public:
@@ -306,6 +277,9 @@ private:
 
 // Base class of all addressable statements
 class Addressable_impl : public ActionStmt_impl {
+public:
+  Expression address() const { return addrExpr_.value(); }
+
 protected:
   Addressable_impl() : addrExpr_{std::nullopt} {}
   explicit Addressable_impl(const Expression &ae) : addrExpr_{ae} {}
@@ -320,7 +294,7 @@ public:
   }
   static LocateExprStmt Create(Expression &&e) { return LocateExprStmt(e); }
 
-  const Expression &expression() const { return *addrExpr_; }
+  const Expression &expression() const { return addrExpr_.value(); }
 
 private:
   explicit LocateExprStmt(const Expression &e) : Addressable_impl{e} {}
@@ -378,6 +352,7 @@ public:
 
   Type type() const { return type_; }
   int alignment() const { return alignment_; }
+  Expression variable() { return addrExpr_.value(); }
 
 private:
   explicit AllocateLocalInsn(Type type, int alignment, const Expression &expr)
@@ -555,8 +530,6 @@ class Statement : public SumTypeMixin<ReturnStmt,  //
                       SwitchRankStmt,  //
                       IndirectBranchStmt,  //
                       UnreachableStmt,  //
-                      IncrementStmt,  //
-                      DoConditionStmt,  //
                       ApplyExprStmt,  //
                       LocateExprStmt,  //
                       AllocateInsn,  //
@@ -608,6 +581,10 @@ inline ApplyExprStmt *GetApplyExpr(Statement *stmt) {
   return std::get_if<ApplyExprStmt>(&stmt->u);
 }
 
+inline AllocateLocalInsn *GetLocal(Statement *stmt) {
+  return std::get_if<AllocateLocalInsn>(&stmt->u);
+}
+
 Addressable_impl *GetAddressable(Statement *stmt);
 }