[flang] add QualifiedStmt class
authorEric Schweitz <eschweitz@nvidia.com>
Fri, 5 Apr 2019 19:22:08 +0000 (12:22 -0700)
committerEric Schweitz <eschweitz@nvidia.com>
Tue, 9 Apr 2019 18:52:04 +0000 (11:52 -0700)
Original-commit: flang-compiler/f18@6bc660c355a97dae0d3ee3edac00cc820b67b4b4
Reviewed-on: https://github.com/flang-compiler/f18/pull/399
Tree-same-pre-rewrite: false

flang/lib/FIR/afforestation.cc
flang/lib/FIR/builder.h
flang/lib/FIR/flattened.cc
flang/lib/FIR/statements.cc
flang/lib/FIR/statements.h

index d002ab3..059ff83 100644 (file)
@@ -340,7 +340,7 @@ public:
         name, {}, std::move(details))};
     return {name, sym};
   }
-  Statement *CreateTemp(TypeRep &&spec) {
+  QualifiedStmt<Addressable_impl> CreateTemp(TypeRep &&spec) {
     TypeRep declSpec{std::move(spec)};
     auto temp{MakeTemp(&declSpec)};
     auto expr{ToExpression(temp)};
@@ -403,11 +403,11 @@ public:
     // TODO: build an expression for the allocation
     return nullptr;
   }
-  AllocateInsn *CreateDeallocationValue(
+  QualifiedStmt<AllocateInsn> CreateDeallocationValue(
       const parser::AllocateObject *allocateObject,
       const parser::DeallocateStmt *statement) {
     // TODO: build an expression for the deallocation
-    return nullptr;
+    return QualifiedStmt<AllocateInsn>{nullptr};
   }
 
   // IO argument translations ...
@@ -627,7 +627,7 @@ public:
   void handleIntrinsicAssignmentStmt(const parser::AssignmentStmt &stmt) {
     // TODO: check if allocation or reallocation should happen, etc.
     auto *value{builder_->CreateExpr(ExprRef(std::get<parser::Expr>(stmt.t)))};
-    auto *addr{
+    auto addr{
         builder_->CreateAddr(ToExpression(std::get<parser::Variable>(stmt.t)))};
     builder_->CreateStore(addr, value);
   }
@@ -770,11 +770,11 @@ public:
                 std::visit(
                     common::visitors{
                         [&](const parser::Name &n) {
-                          auto *s{builder_->CreateAddr(ToExpression(n))};
+                          auto s{builder_->CreateAddr(ToExpression(n))};
                           builder_->CreateNullify(s);
                         },
                         [&](const parser::StructureComponent &sc) {
-                          auto *s{builder_->CreateAddr(ToExpression(sc))};
+                          auto s{builder_->CreateAddr(ToExpression(sc))};
                           builder_->CreateNullify(s);
                         },
                     },
@@ -787,7 +787,7 @@ public:
             },
             [&](const common::Indirection<parser::PointerAssignmentStmt> &s) {
               auto *value{CreatePointerValue(s.value())};
-              auto *addr{builder_->CreateAddr(
+              auto addr{builder_->CreateAddr(
                   ExprRef(std::get<parser::Expr>(s.value().t)))};
               builder_->CreateStore(addr, value);
             },
@@ -850,7 +850,7 @@ public:
               WRONG_PATH();
             },
             [&](const common::Indirection<parser::AssignStmt> &s) {
-              auto *addr{builder_->CreateAddr(
+              auto addr{builder_->CreateAddr(
                   ToExpression(std::get<parser::Name>(s.value().t)))};
               auto *block{blockMap_
                               .find(flat::FetchLabel(
@@ -875,13 +875,14 @@ public:
 
   // DO loop handlers
   struct DoBoundsInfo {
-    Statement *doVariable;
-    Statement *counter;
+    QualifiedStmt<Addressable_impl> doVariable;
+    QualifiedStmt<Addressable_impl> counter;
     Statement *stepExpr;
     Statement *condition;
   };
   void PushDoContext(const parser::NonLabelDoStmt *doStmt,
-      Statement *doVar = nullptr, Statement *counter = nullptr,
+      QualifiedStmt<Addressable_impl> doVar = nullptr,
+      QualifiedStmt<Addressable_impl> counter = nullptr,
       Statement *stepExp = nullptr) {
     doMap_.emplace(doStmt, DoBoundsInfo{doVar, counter, stepExp});
   }
@@ -938,7 +939,7 @@ public:
               [](const parser::Expr &e) { return *ExprRef(e); },
           },
           selector.u))};
-      auto *name{
+      auto name{
           builder_->CreateAddr(ToExpression(std::get<parser::Name>(assoc.t)))};
       builder_->CreateStore(name, expr);
     }
@@ -969,7 +970,7 @@ public:
       std::visit(
           common::visitors{
               [&](const parser::LoopBounds<parser::ScalarIntExpr> &bounds) {
-                auto *name{builder_->CreateAddr(
+                auto name{builder_->CreateAddr(
                     ToExpression(bounds.name.thing.thing))};
                 // evaluate e1, e2 [, e3] ...
                 auto *e1{
@@ -984,7 +985,7 @@ public:
                 }
                 // name <- e1
                 builder_->CreateStore(name, e1);
-                auto *tripCounter{CreateTemp(GetDefaultIntegerType())};
+                auto tripCounter{CreateTemp(GetDefaultIntegerType())};
                 // See 11.1.7.4.1, para. 1, item (3)
                 // totalTrips ::= iteration count = a
                 //   where a = (e2 - e1 + e3) / e3 if a > 0 and 0 otherwise
@@ -1124,11 +1125,12 @@ public:
                         [&](const parser::ReturnStmt *s) {
                           // alt-return
                           if (s->v) {
-                            auto *app{builder_->CreateExpr(
-                                ExprRef(s->v->thing.thing))};
+                            auto *exp{ExprRef(s->v->thing.thing)};
+                            auto app{builder_->QualifiedCreateExpr(exp)};
                             builder_->CreateReturn(app);
                           } else {
-                            auto *zero{builder_->CreateExpr(CreateConstant(0))};
+                            auto zero{builder_->QualifiedCreateExpr(
+                                CreateConstant(0))};
                             builder_->CreateReturn(zero);
                           }
                         },
index ad470a3..0a4b41c 100644 (file)
 
 namespace Fortran::FIR {
 
+/// Helper class for building FIR statements
 struct FIRBuilder {
   explicit FIRBuilder(BasicBlock &block)
     : cursorRegion_{block.getParent()}, cursorBlock_{&block} {}
+
   template<typename A> Statement *Insert(A &&s) {
     CHECK(GetInsertionPoint());
-    auto *statement{new Statement(GetInsertionPoint(), s)};
+    auto *statement{new Statement(GetInsertionPoint(), std::forward<A>(s))};
     return statement;
   }
+
+  template<typename A, typename B> QualifiedStmt<A> QualifiedInsert(B &&s) {
+    CHECK(GetInsertionPoint());
+    auto *statement{new Statement(GetInsertionPoint(), std::forward<B>(s))};
+    return QualifiedStmt<A>{statement, s};
+  }
+
   template<typename A> Statement *InsertTerminator(A &&s) {
-    auto *stmt{Insert(s)};
+    auto *stmt{Insert(std::forward<A>(s))};
     for (auto *block : s.succ_blocks()) {
       block->addPred(GetInsertionPoint());
     }
     return stmt;
   }
+
+  // manage the insertion point
   void SetInsertionPoint(BasicBlock *bb) {
     cursorBlock_ = bb;
     cursorRegion_ = bb->getParent();
@@ -44,14 +55,16 @@ struct FIRBuilder {
 
   BasicBlock *GetInsertionPoint() const { return cursorBlock_; }
 
-  Statement *CreateAddr(const Expression *e) {
-    return Insert(LocateExprStmt::Create(e));
+  // create the various statements
+  QualifiedStmt<Addressable_impl> CreateAddr(const Expression *e) {
+    return QualifiedInsert<Addressable_impl>(LocateExprStmt::Create(e));
   }
-  Statement *CreateAddr(Expression &&e) {
-    return Insert(LocateExprStmt::Create(std::move(e)));
+  QualifiedStmt<Addressable_impl> CreateAddr(Expression &&e) {
+    return QualifiedInsert<Addressable_impl>(
+        LocateExprStmt::Create(std::move(e)));
   }
-  Statement *CreateAlloc(Type type) {
-    return Insert(AllocateInsn::Create(type));
+  QualifiedStmt<AllocateInsn> CreateAlloc(Type type) {
+    return QualifiedInsert<AllocateInsn>(AllocateInsn::Create(type));
   }
   Statement *CreateBranch(BasicBlock *block) {
     return InsertTerminator(BranchStmt::Create(block));
@@ -64,7 +77,7 @@ struct FIRBuilder {
       Statement *cond, BasicBlock *trueBlock, BasicBlock *falseBlock) {
     return InsertTerminator(BranchStmt::Create(cond, trueBlock, falseBlock));
   }
-  Statement *CreateDealloc(AllocateInsn *alloc) {
+  Statement *CreateDealloc(QualifiedStmt<AllocateInsn> alloc) {
     return Insert(DeallocateInsn::Create(alloc));
   }
   Statement *CreateExpr(const Expression *e) {
@@ -76,6 +89,12 @@ struct FIRBuilder {
   ApplyExprStmt *MakeAsExpr(const Expression *e) {
     return GetApplyExpr(CreateExpr(e));
   }
+  QualifiedStmt<ApplyExprStmt> QualifiedCreateExpr(const Expression *e) {
+    return QualifiedInsert<ApplyExprStmt>(ApplyExprStmt::Create(e));
+  }
+  QualifiedStmt<ApplyExprStmt> QualifiedCreateExpr(Expression &&e) {
+    return QualifiedInsert<ApplyExprStmt>(ApplyExprStmt::Create(std::move(e)));
+  }
   Statement *CreateIndirectBr(Variable *v, const std::vector<BasicBlock *> &p) {
     return InsertTerminator(IndirectBranchStmt::Create(v, p));
   }
@@ -85,23 +104,27 @@ struct FIRBuilder {
   Statement *CreateLoad(Statement *addr) {
     return Insert(LoadInsn::Create(addr));
   }
-  Statement *CreateLocal(Type type, const Expression &expr, int alignment = 0) {
-    return Insert(AllocateLocalInsn::Create(type, expr, alignment));
+  QualifiedStmt<Addressable_impl> CreateLocal(
+      Type type, const Expression &expr, int alignment = 0) {
+    return QualifiedInsert<Addressable_impl>(
+        AllocateLocalInsn::Create(type, expr, alignment));
   }
   Statement *CreateNullify(Statement *s) {
     return Insert(DisassociateInsn::Create(s));
   }
-  Statement *CreateReturn(Statement *expr) {
+  Statement *CreateReturn(QualifiedStmt<ApplyExprStmt> expr) {
     return InsertTerminator(ReturnStmt::Create(expr));
   }
   Statement *CreateRuntimeCall(
       RuntimeCallType call, RuntimeCallArguments &&arguments) {
     return Insert(RuntimeStmt::Create(call, std::move(arguments)));
   }
-  Statement *CreateStore(Statement *addr, Statement *value) {
+  Statement *CreateStore(
+      QualifiedStmt<Addressable_impl> addr, Statement *value) {
     return Insert(StoreInsn::Create(addr, value));
   }
-  Statement *CreateStore(Statement *addr, BasicBlock *value) {
+  Statement *CreateStore(
+      QualifiedStmt<Addressable_impl> addr, BasicBlock *value) {
     return Insert(StoreInsn::Create(addr, value));
   }
   Statement *CreateSwitch(
index 1ad0f69..3c1ae7b 100644 (file)
@@ -272,10 +272,11 @@ void LabelOp::dump() const { DebugChannel() << "label_" << get() << ":\n"; }
 
 void GotoOp::dump() const {
   DebugChannel() << "\tgoto %label_" << target << " ["
-                 << std::visit(common::visitors{
-                                   [](ArtificialJump) { return ""s; },
-                                   [&](auto *) { return GetSource(this); },
-                               },
+                 << std::visit(
+                        common::visitors{
+                            [](ArtificialJump) { return ""s; },
+                            [&](auto *) { return GetSource(this); },
+                        },
                         u)
                  << "]\n";
 }
index b3349e6..72ca2aa 100644 (file)
@@ -54,9 +54,8 @@ static std::list<BasicBlock *> SuccBlocks(
   return result.second;
 }
 
-ReturnStmt::ReturnStmt(Statement *exp) : value_{GetApplyExpr(exp)} {
-  CHECK(value_);
-}
+ReturnStmt::ReturnStmt(QualifiedStmt<ApplyExprStmt> exp) : value_{exp} {}
+ReturnStmt::ReturnStmt() : value_{QualifiedStmt<ApplyExprStmt>{nullptr}} {}
 
 SwitchStmt::SwitchStmt(const Value &cond, const ValueSuccPairListType &args)
   : condition_{cond} {
@@ -126,19 +125,25 @@ LoadInsn::LoadInsn(Statement *addr) : address_{addr} {
   CHECK(GetAddressable(addr));
 }
 
-StoreInsn::StoreInsn(Statement *addr, Statement *val)
-  : address_{GetAddressable(addr)} {
+// Store ctors
+StoreInsn::StoreInsn(QualifiedStmt<Addressable_impl> addr, BasicBlock *val)
+  : address_{addr}, value_{val} {
   CHECK(address_);
-  if (auto *value{GetAddressable(val)}) {
-    value_ = value;
-  } else {
-    auto *expr{GetApplyExpr(val)};
-    CHECK(expr);
-    value_ = expr;
-  }
+  CHECK(val);
+}
+StoreInsn::StoreInsn(QualifiedStmt<Addressable_impl> addr, Value val)
+  : address_{addr}, value_{val} {
+  CHECK(address_);
+}
+StoreInsn::StoreInsn(
+    QualifiedStmt<Addressable_impl> addr, QualifiedStmt<ApplyExprStmt> val)
+  : address_{addr}, value_{val} {
+  CHECK(address_);
+  CHECK(val);
 }
-StoreInsn::StoreInsn(Statement *addr, BasicBlock *val)
-  : address_{GetAddressable(addr)}, value_{val} {
+StoreInsn::StoreInsn(
+    QualifiedStmt<Addressable_impl> addr, QualifiedStmt<Addressable_impl> val)
+  : address_{addr}, value_{val} {
   CHECK(address_);
   CHECK(val);
 }
index eab6a55..21917c0 100644 (file)
@@ -54,6 +54,23 @@ public:
   using StatementTrait = std::true_type;
 };
 
+// Some uses of a Statement should be constrained.  These contraints are imposed
+// at compile time.
+template<typename A = Stmt_impl> class QualifiedStmt {
+public:
+  QualifiedStmt() = delete;
+  template<typename B, std::enable_if_t<std::is_base_of_v<A, B>, int> = 0>
+  QualifiedStmt(Statement *stmt, const B &) : stmt{stmt} {}
+
+  // create a stub, where stmt == nullptr
+  QualifiedStmt(std::nullptr_t) : stmt{nullptr} {}
+  operator Statement *() const { return stmt; }
+  operator bool() const { return stmt; }
+  operator A *() const;
+
+  Statement *stmt;
+};
+
 // Every basic block must end in a terminator
 class TerminatorStmt_impl : virtual public Stmt_impl {
 public:
@@ -65,15 +82,18 @@ public:
 // Transfer control out of the current procedure
 class ReturnStmt : public TerminatorStmt_impl {
 public:
-  static ReturnStmt Create(Statement *stmt) { return ReturnStmt{stmt}; }
-  static ReturnStmt Create() { return ReturnStmt{nullptr}; }
+  static ReturnStmt Create(QualifiedStmt<ApplyExprStmt> stmt) {
+    return ReturnStmt{stmt};
+  }
+  static ReturnStmt Create() { return ReturnStmt{}; }
   std::list<BasicBlock *> succ_blocks() const override { return {}; }
   bool has_value() const { return value_; }
-  Statement *value() const;
+  Statement *value() const { return value_; }
 
 private:
-  ApplyExprStmt *value_;
-  explicit ReturnStmt(Statement *exp);
+  QualifiedStmt<ApplyExprStmt> value_;
+  explicit ReturnStmt(QualifiedStmt<ApplyExprStmt> exp);
+  explicit ReturnStmt();
 };
 
 // Encodes two-way conditional branch and one-way absolute branch
@@ -265,12 +285,15 @@ protected:
 class ApplyExprStmt : public ActionStmt_impl {
 public:
   static ApplyExprStmt Create(const Expression *e) { return ApplyExprStmt{*e}; }
-  static ApplyExprStmt Create(Expression &&e) { return ApplyExprStmt{e}; }
+  static ApplyExprStmt Create(Expression &&e) {
+    return ApplyExprStmt{std::move(e)};
+  }
 
   Expression expression() const { return expression_; }
 
 private:
   explicit ApplyExprStmt(const Expression &e) : expression_{e} {}
+  explicit ApplyExprStmt(Expression &&e) : expression_{e} {}
 
   Expression expression_;
 };
@@ -327,15 +350,15 @@ private:
 // Deallocate storage (per DEALLOCATE)
 class DeallocateInsn : public MemoryStmt_impl {
 public:
-  static DeallocateInsn Create(AllocateInsn *alloc) {
+  static DeallocateInsn Create(QualifiedStmt<AllocateInsn> alloc) {
     return DeallocateInsn{alloc};
   }
 
-  Statement *alloc() const;
+  Statement *alloc() const { return alloc_; }
 
 private:
-  explicit DeallocateInsn(AllocateInsn *alloc) : alloc_{alloc} {}
-  AllocateInsn *alloc_;
+  explicit DeallocateInsn(QualifiedStmt<AllocateInsn> alloc) : alloc_{alloc} {}
+  QualifiedStmt<AllocateInsn> alloc_;
 };
 
 // Allocate space for a temporary by its Type. The lifetime of the temporary
@@ -378,12 +401,11 @@ private:
 // Store value(s) from an applied expression to a location
 class StoreInsn : public MemoryStmt_impl {
 public:
-  using ValueType =
-      std::variant<Value, ApplyExprStmt *, Addressable_impl *, BasicBlock *>;
-  template<typename T> static StoreInsn Create(T *addr, T *value) {
-    return StoreInsn{addr, value};
-  }
-  template<typename T> static StoreInsn Create(T *addr, BasicBlock *value) {
+  using ValueType = std::variant<Value, QualifiedStmt<ApplyExprStmt>,
+      QualifiedStmt<Addressable_impl>, BasicBlock *>;
+
+  template<typename A>
+  static StoreInsn Create(QualifiedStmt<Addressable_impl> addr, A value) {
     return StoreInsn{addr, value};
   }
 
@@ -391,12 +413,14 @@ public:
   ValueType value() const { return value_; }
 
 private:
-  explicit StoreInsn(Value addr, Value val);
-  explicit StoreInsn(Value addr, BasicBlock *val);
-  explicit StoreInsn(Statement *addr, Statement *val);
-  explicit StoreInsn(Statement *addr, BasicBlock *val);
-
-  Addressable_impl *address_;
+  explicit StoreInsn(QualifiedStmt<Addressable_impl> addr, Value val);
+  explicit StoreInsn(
+      QualifiedStmt<Addressable_impl> addr, QualifiedStmt<ApplyExprStmt> val);
+  explicit StoreInsn(QualifiedStmt<Addressable_impl> addr,
+      QualifiedStmt<Addressable_impl> val);
+  explicit StoreInsn(QualifiedStmt<Addressable_impl> addr, BasicBlock *val);
+
+  QualifiedStmt<Addressable_impl> address_;
   ValueType value_;
 };
 
@@ -557,21 +581,12 @@ public:
     parent->insertBefore(this);
   }
   std::string dump() const;
-
-  // g++/clang++ will optimize this to a simple register copy
-  // Every Stmt_impl is wrapped in and the first data member of a Statement;
-  // therefore, a pointer to one or the other is bitwise identical.
-  // This checks that this assumption is, in fact, true.
-  static Statement *From(Stmt_impl *stmt) {
-    static Statement s{nullptr, UnreachableStmt::Create()};
-    auto *result{reinterpret_cast<Statement *>(reinterpret_cast<char *>(stmt) -
-        (reinterpret_cast<char *>(&s.u) - reinterpret_cast<char *>(&s)))};
-    CHECK(result == reinterpret_cast<Statement *>(stmt) &&
-        "expecting pointers to be equal");
-    return result;
-  }
 };
 
+template<typename A> inline QualifiedStmt<A>::operator A *() const {
+  return reinterpret_cast<A *>(&stmt->u);
+}
+
 inline std::list<BasicBlock *> succ_list(BasicBlock &block) {
   if (auto *terminator{block.terminator()}) {
     return reinterpret_cast<const TerminatorStmt_impl *>(&terminator->u)
@@ -581,11 +596,6 @@ inline std::list<BasicBlock *> succ_list(BasicBlock &block) {
   return {};
 }
 
-inline Statement *ReturnStmt::value() const { return Statement::From(value_); }
-inline Statement *DeallocateInsn::alloc() const {
-  return Statement::From(alloc_);
-}
-
 inline ApplyExprStmt *GetApplyExpr(Statement *stmt) {
   return std::get_if<ApplyExprStmt>(&stmt->u);
 }