[flang] temporary symbol creation and improve dump routine
authorEric Schweitz <eschweitz@nvidia.com>
Fri, 22 Mar 2019 15:21:01 +0000 (08:21 -0700)
committerEric Schweitz <eschweitz@nvidia.com>
Mon, 25 Mar 2019 16:29:01 +0000 (09:29 -0700)
check for labels on construct delimiting statements
ifdef out code to avoid warning

Original-commit: flang-compiler/f18@667674aaa08a90f65dd9f0023a157c337f749892
Reviewed-on: https://github.com/flang-compiler/f18/pull/354
Tree-same-pre-rewrite: false

flang/lib/FIR/afforestation.cc
flang/lib/FIR/builder.h
flang/lib/FIR/common.h
flang/lib/FIR/flattened.cc
flang/lib/FIR/graph-writer.cc
flang/lib/FIR/region.h
flang/lib/FIR/statements.cc
flang/lib/FIR/statements.h
flang/lib/FIR/value.h

index 864d89c..bc66211 100644 (file)
@@ -159,13 +159,12 @@ static std::vector<SwitchTypeStmt::ValueType> populateSwitchValues(
         std::get<parser::Statement<parser::TypeGuardStmt>>(v.t).statement.t)};
     std::visit(
         common::visitors{
-            [&](const parser::TypeSpec &typeSpec) {
-              result.emplace_back(
-                  SwitchTypeStmt::TypeSpec{typeSpec.declTypeSpec});
+            [&](const parser::TypeSpec &spec) {
+              result.emplace_back(SwitchTypeStmt::TypeSpec{spec.declTypeSpec});
             },
-            [&](const parser::DerivedTypeSpec &derivedTypeSpec) {
+            [&](const parser::DerivedTypeSpec &spec) {
               result.emplace_back(
-                  SwitchTypeStmt::DerivedTypeSpec{nullptr /*FIXME*/});
+                  SwitchTypeStmt::DerivedTypeSpec{nullptr /* FIXME */});
             },
             [&](const parser::Default &) {
               result.emplace_back(SwitchTypeStmt::Default{});
@@ -244,6 +243,33 @@ static Expression getLocalVariable(Statement *s) {
   return GetLocal(s)->variable();
 }
 
+// create a new temporary name (as heap garbage)
+static parser::CharBlock NewTemporaryName() {
+  constexpr int SizeMagicValue{32};
+  static int counter;
+  char cache[SizeMagicValue];
+  int bytesWritten{snprintf(cache, SizeMagicValue, ".t%d", counter++)};
+  CHECK(bytesWritten < SizeMagicValue);
+  auto len{strlen(cache)};
+  char *name{new char[len]};  // XXX: add these to a pool?
+  memcpy(name, cache, len);
+  return {name, name + len};
+}
+
+static TypeRep GetDefaultIntegerType() {
+  return {semantics::NumericTypeSpec{evaluate::SubscriptInteger::category,
+      AsExpr(evaluate::Constant<evaluate::SubscriptInteger>{
+          evaluate::SubscriptInteger::kind})}};
+}
+
+#if 0
+static TypeRep GetDefaultLogicalType() {
+  return {semantics::LogicalTypeSpec{
+      AsExpr(evaluate::Constant<evaluate::SubscriptInteger>{
+          evaluate::LogicalResult::kind})}};
+}
+#endif
+
 class FortranIRLowering {
 public:
   using LabelMapType = std::map<flat::LabelRef, BasicBlock *>;
@@ -305,6 +331,21 @@ public:
         context.GetContextualMessages(), op, std::move(e1), std::move(e2))
                                        .value());
   }
+  parser::Name MakeTemp(Type tempType) {
+    auto name{NewTemporaryName()};
+    auto details{semantics::ObjectEntityDetails{true}};
+    details.set_type(std::move(*tempType));
+    auto *sym{&semanticsContext_.globalScope().MakeSymbol(
+        name, {}, std::move(details))};
+    return {name, sym};
+  }
+  Statement *CreateTemp(TypeRep &&spec) {
+    TypeRep declSpec{std::move(spec)};
+    auto temp{MakeTemp(&declSpec)};
+    auto expr{ToExpression(temp)};
+    auto *localType{temp.symbol->get<semantics::ObjectEntityDetails>().type()};
+    return builder_->CreateLocal(localType, expr);
+  }
 
   template<typename T>
   void ProcessRoutine(const T &here, const std::string &name) {
@@ -838,8 +879,9 @@ public:
     Statement *stepExpr;
     Statement *condition;
   };
-  void PushDoContext(const parser::NonLabelDoStmt *doStmt, Statement *doVar,
-      Statement *counter, Statement *stepExp) {
+  void PushDoContext(const parser::NonLabelDoStmt *doStmt,
+      Statement *doVar = nullptr, Statement *counter = nullptr,
+      Statement *stepExp = nullptr) {
     doMap_.emplace(doStmt, DoBoundsInfo{doVar, counter, stepExp});
   }
   void PopDoContext(const parser::NonLabelDoStmt *doStmt) {
@@ -856,25 +898,33 @@ public:
     return nullptr;
   }
 
-  // evaluate: do_var = do_var + e3; counter--
   void handleLinearDoIncrement(const flat::DoIncrementOp &inc) {
     auto *info{GetBoundsInfo(inc)};
-    auto *incremented{builder_->CreateExpr(std::move(
-        ConsExpr<evaluate::Add>(GetAddressable(info->doVariable)->address(),
-            GetApplyExpr(info->stepExpr)->expression())))};
-    builder_->CreateStore(info->doVariable, incremented);
-    auto *decremented{builder_->CreateExpr(ConsExpr<evaluate::Subtract>(
-        GetAddressable(info->counter)->address(), CreateConstant(1)))};
-    builder_->CreateStore(info->counter, decremented);
+    if (info->doVariable) {
+      if (info->stepExpr) {
+        // evaluate: do_var = do_var + e3; counter--
+        auto *incremented{builder_->CreateExpr(
+            ConsExpr<evaluate::Add>(GetAddressable(info->doVariable)->address(),
+                GetApplyExpr(info->stepExpr)->expression()))};
+        builder_->CreateStore(info->doVariable, incremented);
+        auto *decremented{builder_->CreateExpr(ConsExpr<evaluate::Subtract>(
+            GetAddressable(info->counter)->address(), CreateConstant(1)))};
+        builder_->CreateStore(info->counter, decremented);
+      }
+    }
   }
 
   // is (counter > 0)?
   void handleLinearDoCompare(const flat::DoCompareOp &cmp) {
     auto *info{GetBoundsInfo(cmp)};
-    Expression compare{ConsExpr(common::RelationalOperator::GT,
-        getLocalVariable(info->counter), CreateConstant(0))};
-    auto *cond{builder_->CreateExpr(&compare)};
-    info->condition = cond;
+    if (info->doVariable) {
+      if (info->stepExpr) {
+        Expression compare{ConsExpr(common::RelationalOperator::GT,
+            getLocalVariable(info->counter), CreateConstant(0))};
+        auto *cond{builder_->CreateExpr(&compare)};
+        info->condition = cond;
+      }
+    }
   }
 
   // InitiateConstruct - many constructs require some initial setup
@@ -933,7 +983,8 @@ public:
                 }
                 // name <- e1
                 builder_->CreateStore(name, e1);
-                auto *tripCounter{builder_->CreateLocal(nullptr)};
+                auto *tripCounter{CreateTemp(GetDefaultIntegerType())};
+                // See 11.1.7.4.1, para. 1, item (3)
                 // totalTrips ::= iteration count = a
                 //   where a = (e2 - e1 + e3) / e3 if a > 0 and 0 otherwise
                 Expression tripExpr{ConsExpr<evaluate::Divide>(
@@ -946,16 +997,20 @@ public:
                 builder_->CreateStore(tripCounter, totalTrips);
                 PushDoContext(stmt, name, tripCounter, e3);
               },
-              [&](const parser::ScalarLogicalExpr &whileExpr) {
-                // FIXME
+              [&](const parser::ScalarLogicalExpr &expr) {
+                // See 11.1.7.4.1, para. 2
+                // See BuildLoopLatchExpression()
+                PushDoContext(stmt);
               },
               [&](const parser::LoopControl::Concurrent &cc) {
+                // See 11.1.7.4.2
                 // FIXME
               },
           },
           ctrl->u);
     } else {
-      // loop forever
+      // loop forever (See 11.1.7.4.1, para. 2)
+      PushDoContext(stmt);
     }
   }
 
@@ -1299,7 +1354,7 @@ public:
           [](FIRBuilder *builder, BasicBlock *block, flat::LabelRef dest,
               const LabelMapType &map) {
             builder->SetInsertionPoint(block);
-            CHECK(map.find(dest) != map.end());
+            CHECK(map.find(dest) != map.end() && "no destination");
             builder->CreateBranch(map.find(dest)->second);
           },
           builder_, builder_->GetInsertionPoint(), dest, _1));
index 2362e8f..ad470a3 100644 (file)
@@ -85,9 +85,8 @@ struct FIRBuilder {
   Statement *CreateLoad(Statement *addr) {
     return Insert(LoadInsn::Create(addr));
   }
-  Statement *CreateLocal(
-      Type type, int alignment = 0, Expression *expr = nullptr) {
-    return Insert(AllocateLocalInsn::Create(type, alignment, expr));
+  Statement *CreateLocal(Type type, const Expression &expr, int alignment = 0) {
+    return Insert(AllocateLocalInsn::Create(type, expr, alignment));
   }
   Statement *CreateNullify(Statement *s) {
     return Insert(DisassociateInsn::Create(s));
index 332e209..91c9c2b 100644 (file)
@@ -73,7 +73,8 @@ using PathVariable = const parser::Variable;
 using Scope = const semantics::Scope;
 using PHIPair = std::pair<Value, BasicBlock *>;
 using CallArguments = std::vector<Expression>;
-using Type = const semantics::DeclTypeSpec *;  // FIXME
+using TypeRep = semantics::DeclTypeSpec;  // FIXME
+using Type = const TypeRep *;
 
 enum InputOutputCallType {
   InputOutputCallBackspace = 11,
index 8e03d88..1ad0f69 100644 (file)
@@ -540,6 +540,12 @@ struct ControlFlowAnalyzer {
     }
     return true;
   }
+  template<typename A>
+  void appendIfLabeled(const parser::Statement<A> &stmt, std::list<Op> &ops) {
+    if (stmt.label) {
+      ops.emplace_back(findLabel(*stmt.label));
+    }
+  }
 
   // named constructs
   template<typename A> bool linearConstruct(const A &construct) {
@@ -547,10 +553,12 @@ struct ControlFlowAnalyzer {
     LabelOp label{buildNewLabel()};
     const parser::Name *name{getName(construct)};
     ad.nameStack.emplace_back(name, GetLabelRef(label), unspecifiedLabel);
+    appendIfLabeled(std::get<0>(construct.t), ops);
     ops.emplace_back(BeginOp{construct});
     ControlFlowAnalyzer cfa{ops, ad};
     Walk(std::get<parser::Block>(construct.t), cfa);
     ops.emplace_back(label);
+    appendIfLabeled(std::get<2>(construct.t), ops);
     ops.emplace_back(EndOp{construct});
     linearOps.splice(linearOps.end(), ops);
     ad.nameStack.pop_back();
@@ -569,9 +577,13 @@ struct ControlFlowAnalyzer {
             .statement.v};
     const parser::Name *name{optName ? &*optName : nullptr};
     ad.nameStack.emplace_back(name, GetLabelRef(label), unspecifiedLabel);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::BlockStmt>>(construct.t), ops);
     ops.emplace_back(BeginOp{construct});
     ControlFlowAnalyzer cfa{ops, ad};
     Walk(std::get<parser::Block>(construct.t), cfa);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::EndBlockStmt>>(construct.t), ops);
     ops.emplace_back(EndOp{construct});
     ops.emplace_back(label);
     linearOps.splice(linearOps.end(), ops);
@@ -588,6 +600,8 @@ struct ControlFlowAnalyzer {
     const parser::Name *name{getName(construct)};
     LabelRef exitOpRef{GetLabelRef(exitLab)};
     ad.nameStack.emplace_back(name, exitOpRef, GetLabelRef(incrementLab));
+    appendIfLabeled(
+        std::get<parser::Statement<parser::NonLabelDoStmt>>(construct.t), ops);
     ops.emplace_back(BeginOp{construct});
     ops.emplace_back(GotoOp{GetLabelRef(backedgeLab)});
     ops.emplace_back(incrementLab);
@@ -600,6 +614,8 @@ struct ControlFlowAnalyzer {
     ops.push_back(entryLab);
     ControlFlowAnalyzer cfa{ops, ad};
     Walk(std::get<parser::Block>(construct.t), cfa);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::EndDoStmt>>(construct.t), ops);
     ops.emplace_back(GotoOp{GetLabelRef(incrementLab)});
     ops.emplace_back(EndOp{construct});
     ops.emplace_back(exitLab);
@@ -615,6 +631,8 @@ struct ControlFlowAnalyzer {
     LabelOp exitLab{buildNewLabel()};
     const parser::Name *name{getName(construct)};
     ad.nameStack.emplace_back(name, GetLabelRef(exitLab), unspecifiedLabel);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::IfThenStmt>>(construct.t), ops);
     ops.emplace_back(BeginOp{construct});
     ops.emplace_back(ConditionalGotoOp{
         std::get<parser::Statement<parser::IfThenStmt>>(construct.t),
@@ -626,6 +644,8 @@ struct ControlFlowAnalyzer {
     ops.emplace_back(GotoOp{exitOpRef});
     for (const auto &elseIfBlock :
         std::get<std::list<parser::IfConstruct::ElseIfBlock>>(construct.t)) {
+      appendIfLabeled(
+          std::get<parser::Statement<parser::ElseIfStmt>>(elseIfBlock.t), ops);
       ops.emplace_back(elseLab);
       LabelOp newThenLab{buildNewLabel()};
       LabelOp newElseLab{buildNewLabel()};
@@ -641,10 +661,14 @@ struct ControlFlowAnalyzer {
     if (const auto &optElseBlock{
             std::get<std::optional<parser::IfConstruct::ElseBlock>>(
                 construct.t)}) {
+      appendIfLabeled(
+          std::get<parser::Statement<parser::ElseStmt>>(optElseBlock->t), ops);
       Walk(std::get<parser::Block>(optElseBlock->t), cfa);
     }
     ops.emplace_back(GotoOp{exitOpRef});
     ops.emplace_back(exitLab);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::EndIfStmt>>(construct.t), ops);
     ops.emplace_back(EndOp{construct});
     linearOps.splice(linearOps.end(), ops);
     ad.nameStack.pop_back();
@@ -657,6 +681,7 @@ struct ControlFlowAnalyzer {
     LabelOp exitLab{buildNewLabel()};
     const parser::Name *name{getName(construct)};
     ad.nameStack.emplace_back(name, GetLabelRef(exitLab), unspecifiedLabel);
+    appendIfLabeled(std::get<0>(construct.t), ops);
     ops.emplace_back(BeginOp{construct});
     const auto N{std::get<std::list<B>>(construct.t).size()};
     LabelRef exitOpRef{GetLabelRef(exitLab)};
@@ -676,11 +701,13 @@ struct ControlFlowAnalyzer {
       i = 0;
       for (const auto &caseBlock : std::get<std::list<B>>(construct.t)) {
         ops.emplace_back(toLabels[i++]);
+        appendIfLabeled(std::get<0>(caseBlock.t), ops);
         Walk(std::get<parser::Block>(caseBlock.t), cfa);
         ops.emplace_back(GotoOp{exitOpRef});
       }
     }
     ops.emplace_back(exitLab);
+    appendIfLabeled(std::get<2>(construct.t), ops);
     ops.emplace_back(EndOp{construct});
     linearOps.splice(linearOps.end(), ops);
     ad.nameStack.pop_back();
@@ -696,6 +723,8 @@ struct ControlFlowAnalyzer {
     LabelOp label{buildNewLabel()};
     const parser::Name *name{getName(c)};
     ad.nameStack.emplace_back(name, GetLabelRef(label), unspecifiedLabel);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::WhereConstructStmt>>(c.t), ops);
     ops.emplace_back(BeginOp{c});
     ControlFlowAnalyzer cfa{ops, ad};
     Walk(std::get<std::list<parser::WhereBodyConstruct>>(c.t), cfa);
@@ -703,6 +732,8 @@ struct ControlFlowAnalyzer {
         std::get<std::list<parser::WhereConstruct::MaskedElsewhere>>(c.t), cfa);
     Walk(std::get<std::optional<parser::WhereConstruct::Elsewhere>>(c.t), cfa);
     ops.emplace_back(label);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::EndWhereStmt>>(c.t), ops);
     ops.emplace_back(EndOp{c});
     linearOps.splice(linearOps.end(), ops);
     ad.nameStack.pop_back();
@@ -714,10 +745,15 @@ struct ControlFlowAnalyzer {
     LabelOp label{buildNewLabel()};
     const parser::Name *name{getName(construct)};
     ad.nameStack.emplace_back(name, GetLabelRef(label), unspecifiedLabel);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t),
+        ops);
     ops.emplace_back(BeginOp{construct});
     ControlFlowAnalyzer cfa{ops, ad};
     Walk(std::get<std::list<parser::ForallBodyConstruct>>(construct.t), cfa);
     ops.emplace_back(label);
+    appendIfLabeled(
+        std::get<parser::Statement<parser::EndForallStmt>>(construct.t), ops);
     ops.emplace_back(EndOp{construct});
     linearOps.splice(linearOps.end(), ops);
     ad.nameStack.pop_back();
index f642b33..a2c5bf7 100644 (file)
@@ -98,8 +98,7 @@ void GraphWriter::dump(BasicBlock &block, std::optional<const char *> color) {
   if (isEntry_) {
     output_ << "<<ENTRY>>\\n";
   }
-  output_ << block_id(block) << '(' << reinterpret_cast<std::intptr_t>(&block)
-          << ")\\n";
+  output_ << block_id(block) << '(' << ToString(&block) << ")\\n";
   for (auto &action : block.getSublist(static_cast<Statement *>(nullptr))) {
     output_ << action.dump() << "\\n";
   }
index 25c9a05..6ac03a2 100644 (file)
@@ -47,7 +47,7 @@ public:
   iterator end() { return subregionList_.end(); }
   const_iterator end() const { return subregionList_.end(); }
   Region *GetEnclosing() const { return enclosingRegion_; }
-  bool IsOutermost() const { return GetEnclosing() == nullptr; }
+  bool IsOutermost() const { return !GetEnclosing(); }
   static Region *Create(Procedure *procedure, Scope *scope = nullptr,
       Region *inRegion = nullptr, Region *insertBefore = nullptr) {
     return new Region(procedure, scope, inRegion, insertBefore);
index a56217b..a41d097 100644 (file)
@@ -141,53 +141,74 @@ StoreInsn::StoreInsn(Statement *addr, BasicBlock *val)
   CHECK(val);
 }
 
+static std::string dumpStoreValue(const StoreInsn::ValueType &v) {
+  return std::visit(
+      common::visitors{
+          [](const Value &v) { return v.dump(); },
+          [](const ApplyExprStmt *e) { return FIR::dump(e->expression()); },
+          [](const Addressable_impl *e) { return FIR::dump(e->address()); },
+          [](const BasicBlock *bb) { return ToString(bb); },
+      },
+      v);
+}
+
+// dump is intended for debugging rather than idiomatic FIR output
 std::string Statement::dump() const {
   return std::visit(
       common::visitors{
-          [](const ReturnStmt &) { return "return"s; },
-          [](const BranchStmt &branch) {
-            if (branch.hasCondition()) {
-              std::string cond{"???"};
-              return "branch (" + cond + ") " +
-                  std::to_string(
-                      reinterpret_cast<std::intptr_t>(branch.getTrueSucc())) +
-                  ' ' +
-                  std::to_string(
-                      reinterpret_cast<std::intptr_t>(branch.getFalseSucc()));
+          [](const ReturnStmt &s) { return "return " + ToString(s.value()); },
+          [](const BranchStmt &s) {
+            if (s.hasCondition()) {
+              return "cgoto (" + s.getCond().dump() + ") " +
+                  ToString(s.getTrueSucc()) + ", " + ToString(s.getFalseSucc());
             }
-            return "goto " +
-                std::to_string(
-                    reinterpret_cast<std::intptr_t>(branch.getTrueSucc()));
+            return "goto " + ToString(s.getTrueSucc());
           },
-          [](const SwitchStmt &stmt) {
-            // return "switch(" + stmt.getCond().dump() + ")";
-            return "switch(?)"s;
+          [](const SwitchStmt &s) {
+            return "switch (" + s.getCond().dump() + ")";
           },
-          [](const SwitchCaseStmt &switchCaseStmt) {
-            // return "switch-case(" + switchCaseStmt.getCond().dump() + ")";
-            return "switch-case(?)"s;
+          [](const SwitchCaseStmt &s) {
+            return "switch-case (" + s.getCond().dump() + ")";
           },
-          [](const SwitchTypeStmt &switchTypeStmt) {
-            // return "switch-type(" + switchTypeStmt.getCond().dump() + ")";
-            return "switch-type(?)"s;
+          [](const SwitchTypeStmt &s) {
+            return "switch-type (" + s.getCond().dump() + ")";
           },
-          [](const SwitchRankStmt &switchRankStmt) {
-            // return "switch-rank(" + switchRankStmt.getCond().dump() + ")";
-            return "switch-rank(?)"s;
+          [](const SwitchRankStmt &s) {
+            return "switch-rank (" + s.getCond().dump() + ")";
+          },
+          [](const IndirectBranchStmt &s) {
+            std::string targets;
+            for (auto *b : s.succ_blocks()) {
+              targets += " " + ToString(b);
+            }
+            return "igoto (" + ToString(s.variable()) + ")" + targets;
           },
-          [](const IndirectBranchStmt &) { return "ibranch"s; },
           [](const UnreachableStmt &) { return "unreachable"s; },
-          [](const ApplyExprStmt &e) { return FIR::dump(e.expression()); },
-          [](const LocateExprStmt &e) {
-            return "&" + FIR::dump(e.expression());
+          [&](const ApplyExprStmt &e) {
+            return '%' + ToString(&u) + ": eval " + FIR::dump(e.expression());
+          },
+          [&](const LocateExprStmt &e) {
+            return '%' + ToString(&u) + ": addr-of " +
+                FIR::dump(e.expression());
           },
           [](const AllocateInsn &) { return "alloc"s; },
-          [](const DeallocateInsn &) { return "dealloc"s; },
-          [](const AllocateLocalInsn &) { return "alloca"s; },
-          [](const LoadInsn &) { return "load"s; },
-          [](const StoreInsn &) { return "store"s; },
+          [](const DeallocateInsn &s) {
+            return "dealloc (" + ToString(s.alloc()) + ")";
+          },
+          [&](const AllocateLocalInsn &insn) {
+            return '%' + ToString(&u) + ": alloca " +
+                FIR::dump(insn.variable());
+          },
+          [&](const LoadInsn &insn) {
+            return '%' + ToString(&u) + ": load " + insn.address().dump();
+          },
+          [](const StoreInsn &insn) {
+            std::string value{dumpStoreValue(insn.value())};
+            return "store " + value + " to " +
+                FIR::dump(insn.address()->address());
+          },
           [](const DisassociateInsn &) { return "NULLIFY"s; },
-          [](const CallStmt &) { return "call"s; },
+          [&](const CallStmt &) { return '%' + ToString(&u) + ": call"s; },
           [](const RuntimeStmt &) { return "runtime-call()"s; },
           [](const IORuntimeStmt &) { return "io-call()"s; },
           [](const ScopeEnterStmt &) { return "scopeenter"s; },
@@ -196,4 +217,16 @@ std::string Statement::dump() const {
       },
       u);
 }
+
+std::string Value::dump() const {
+  return std::visit(
+      common::visitors{
+          [](const Nothing &) { return "<none>"s; },
+          [](const DataObject *obj) { return "obj_" + ToString(obj); },
+          [](const Statement *s) { return "stmt_" + ToString(s); },
+          [](const BasicBlock *bb) { return "block_" + ToString(bb); },
+          [](const Procedure *p) { return "proc_" + ToString(p); },
+      },
+      u);
+}
 }
index d130a43..6078e91 100644 (file)
@@ -55,7 +55,7 @@ public:
 };
 
 // Every basic block must end in a terminator
-class TerminatorStmt_impl : public Stmt_impl {
+class TerminatorStmt_impl : virtual public Stmt_impl {
 public:
   virtual std::list<BasicBlock *> succ_blocks() const = 0;
   virtual ~TerminatorStmt_impl() = default;
@@ -68,7 +68,7 @@ public:
   static ReturnStmt Create(Statement *stmt) { return ReturnStmt{stmt}; }
   static ReturnStmt Create() { return ReturnStmt{nullptr}; }
   std::list<BasicBlock *> succ_blocks() const override { return {}; }
-  bool has_value() const { return value_ != nullptr; }
+  bool has_value() const { return value_; }
   Statement *value() const;
 
 private:
@@ -250,7 +250,7 @@ private:
   explicit UnreachableStmt() = default;
 };
 
-class ActionStmt_impl : public Stmt_impl {
+class ActionStmt_impl : virtual public Stmt_impl {
 public:
   using ActionTrait = std::true_type;
 
@@ -331,7 +331,7 @@ public:
     return DeallocateInsn{alloc};
   }
 
-  AllocateInsn *alloc() { return alloc_; }
+  Statement *alloc() const;
 
 private:
   explicit DeallocateInsn(AllocateInsn *alloc) : alloc_{alloc} {}
@@ -343,22 +343,17 @@ private:
 class AllocateLocalInsn : public Addressable_impl, public MemoryStmt_impl {
 public:
   static AllocateLocalInsn Create(
-      Type type, int alignment = 0, Expression *expr = nullptr) {
-    if (expr != nullptr) {
-      return AllocateLocalInsn{type, alignment, *expr};
-    }
-    return AllocateLocalInsn{type, alignment};
+      Type type, const Expression &expr, int alignment = 0) {
+    return AllocateLocalInsn{type, alignment, expr};
   }
 
   Type type() const { return type_; }
   int alignment() const { return alignment_; }
-  Expression variable() { return addrExpr_.value(); }
+  Expression variable() const { return addrExpr_.value(); }
 
 private:
   explicit AllocateLocalInsn(Type type, int alignment, const Expression &expr)
     : Addressable_impl{expr}, type_{type}, alignment_{alignment} {}
-  explicit AllocateLocalInsn(Type type, int alignment)
-    : type_{type}, alignment_{alignment} {}
 
   Type type_;
   int alignment_;
@@ -371,6 +366,8 @@ public:
   static LoadInsn Create(Value &&addr) { return LoadInsn{addr}; }
   static LoadInsn Create(Statement *addr) { return LoadInsn{addr}; }
 
+  Value address() const { return address_; }
+
 private:
   explicit LoadInsn(const Value &addr);
   explicit LoadInsn(Value &&addr);
@@ -381,6 +378,8 @@ private:
 // Store value(s) from an applied expression to a location
 class StoreInsn : public MemoryStmt_impl {
 public:
+  using ValueType =
+      std::variant<Value, ApplyExprStmt *, Addressable_impl *, BasicBlock *>;
   template<typename T> static StoreInsn Create(T *addr, T *value) {
     return StoreInsn{addr, value};
   }
@@ -388,6 +387,9 @@ public:
     return StoreInsn{addr, value};
   }
 
+  Addressable_impl *address() const { return address_; }
+  ValueType value() const { return value_; }
+
 private:
   explicit StoreInsn(Value addr, Value val);
   explicit StoreInsn(Value addr, BasicBlock *val);
@@ -395,7 +397,7 @@ private:
   explicit StoreInsn(Statement *addr, BasicBlock *val);
 
   Addressable_impl *address_;
-  std::variant<Value, ApplyExprStmt *, Addressable_impl *, BasicBlock *> value_;
+  ValueType value_;
 };
 
 // NULLIFY - make pointer object disassociated
@@ -576,6 +578,9 @@ inline std::list<BasicBlock *> succ_list(BasicBlock &block) {
 }
 
 inline Statement *ReturnStmt::value() const { return Statement::From(value_); }
+inline Statement *DeallocateInsn::alloc() const {
+  return Statement::From(alloc_);
+}
 
 inline ApplyExprStmt *GetApplyExpr(Statement *stmt) {
   return std::get_if<ApplyExprStmt>(&stmt->u);
@@ -586,6 +591,12 @@ inline AllocateLocalInsn *GetLocal(Statement *stmt) {
 }
 
 Addressable_impl *GetAddressable(Statement *stmt);
+
+template<typename A> std::string ToString(const A *a) {
+  std::stringstream ss;
+  ss << std::hex << reinterpret_cast<std::intptr_t>(a);
+  return ss.str();
+}
 }
 
 #endif  // FORTRAN_FIR_STATEMENTS_H_
index 371af46..be9e4e0 100644 (file)
@@ -34,11 +34,20 @@ public:
   template<typename A> Value(A *a) : SumTypeCopyMixin{a} {}
   Value(const Nothing &n) : SumTypeCopyMixin{n} {}
   Value() : SumTypeCopyMixin{NOTHING} {}
+  std::string dump() const;
 };
 
 inline bool IsNothing(Value value) {
   return std::holds_alternative<Nothing>(value.u);
 }
+
+inline bool IsStatement(Value value) {
+  return std::holds_alternative<Statement *>(value.u);
+}
+
+inline bool IsBasicBlock(Value value) {
+  return std::holds_alternative<BasicBlock *>(value.u);
+}
 }
 
 #endif  // FORTRAN_FIR_VALUE_H_