[flang] simplify switch builders
authorEric Schweitz <eschweitz@nvidia.com>
Wed, 20 Mar 2019 22:58:06 +0000 (15:58 -0700)
committerEric <eschweitz@nvidia.com>
Sat, 23 Mar 2019 18:14:20 +0000 (11:14 -0700)
Original-commit: flang-compiler/f18@787ee1fecd275dd2b34171cd165cc38e7c46f8ba
Reviewed-on: https://github.com/flang-compiler/f18/pull/354
Tree-same-pre-rewrite: false

flang/lib/FIR/afforestation.cc
flang/lib/FIR/builder.h
flang/lib/FIR/flattened.cc
flang/lib/FIR/statements.cc
flang/lib/FIR/statements.h

index 6745f69..7522fb6 100644 (file)
@@ -25,9 +25,9 @@
 
 namespace Fortran::FIR {
 namespace {
-Expression *ExprRef(const parser::Expr &a) { return a.typedExpr.get()->v; }
+Expression *ExprRef(const parser::Expr &a) { return &a.typedExpr.get()->v; }
 Expression *ExprRef(const common::Indirection<parser::Expr> &a) {
-  return a.value().typedExpr.get()->v;
+  return &a.value().typedExpr.get()->v;
 }
 
 template<typename STMTTYPE, typename CT>
@@ -47,7 +47,6 @@ void DumpSwitchWithSelector(
 
 template<typename T> struct SwitchArgs {
   Value exp;
-  flat::LabelRef defLab;
   std::vector<T> values;
   std::vector<flat::LabelRef> labels;
 };
@@ -60,20 +59,19 @@ template<typename T> bool IsDefault(const typename T::ValueType &valueType) {
   return std::holds_alternative<typename T::Default>(valueType);
 }
 
+// move the default case to be first
 template<typename T>
-void cleanupSwitchPairs(flat::LabelRef &defLab,
-    std::vector<typename T::ValueType> &values,
+void cleanupSwitchPairs(std::vector<typename T::ValueType> &values,
     std::vector<flat::LabelRef> &labels) {
   CHECK(values.size() == labels.size());
-  for (std::size_t i{0}, len{values.size()}; i < len; ++i) {
+  for (std::size_t i{1}, len{values.size()}; i < len; ++i) {
     if (IsDefault<T>(values[i])) {
-      defLab = labels[i];
-      for (std::size_t j{i}; j < len - 1; ++j) {
-        values[j] = values[j + 1];
-        labels[j] = labels[j + 1];
-      }
-      values.pop_back();
-      labels.pop_back();
+      auto v{values[0]};
+      values[0] = values[i];
+      values[i] = v;
+      auto w{labels[0]};
+      labels[0] = labels[i];
+      labels[i] = w;
       break;
     }
   }
@@ -178,11 +176,6 @@ static std::vector<SwitchTypeStmt::ValueType> populateSwitchValues(
   return result;
 }
 
-static void buildMultiwayDefaultNext(SwitchArguments &result) {
-  result.defLab = result.labels.back();
-  result.labels.pop_back();
-}
-
 template<typename T>
 const T *FindReadWriteSpecifier(
     const std::list<parser::IoControlSpec> &specifiers) {
@@ -228,23 +221,20 @@ static Expression CreateConstant(int64_t value) {
 }
 
 static void CreateSwitchHelper(FIRBuilder *builder, Value condition,
-    BasicBlock *defaultCase, const SwitchStmt::ValueSuccPairListType &rest) {
-  builder->CreateSwitch(condition, defaultCase, rest);
+    const SwitchStmt::ValueSuccPairListType &rest) {
+  builder->CreateSwitch(condition, rest);
 }
 static void CreateSwitchCaseHelper(FIRBuilder *builder, Value condition,
-    BasicBlock *defaultCase,
     const SwitchCaseStmt::ValueSuccPairListType &rest) {
-  builder->CreateSwitchCase(condition, defaultCase, rest);
+  builder->CreateSwitchCase(condition, rest);
 }
 static void CreateSwitchRankHelper(FIRBuilder *builder, Value condition,
-    BasicBlock *defaultCase,
     const SwitchRankStmt::ValueSuccPairListType &rest) {
-  builder->CreateSwitchRank(condition, defaultCase, rest);
+  builder->CreateSwitchRank(condition, rest);
 }
 static void CreateSwitchTypeHelper(FIRBuilder *builder, Value condition,
-    BasicBlock *defaultCase,
     const SwitchTypeStmt::ValueSuccPairListType &rest) {
-  builder->CreateSwitchType(condition, defaultCase, rest);
+  builder->CreateSwitchType(condition, rest);
 }
 
 class FortranIRLowering {
@@ -491,59 +481,62 @@ public:
     return GetSwitchSelector<parser::SelectTypeStmt>(selectTypeConstruct);
   }
   Statement *GetSwitchCaseSelector(const parser::CaseConstruct *construct) {
+    using A = parser::Statement<parser::SelectCaseStmt>;
     const auto &x{std::get<parser::Scalar<parser::Expr>>(
-        std::get<parser::Statement<parser::SelectCaseStmt>>(construct->t)
-            .statement.t)};
+        std::get<A>(construct->t).statement.t)};
     return builder_->CreateExpr(ExprRef(x.thing));
   }
+
+  SwitchArguments ComposeIOSwitchArgs(const flat::SwitchIOOp &IOp) {
+    return {};  // FIXME
+  }
   SwitchArguments ComposeSwitchArgs(const flat::SwitchOp &op) {
-    SwitchArguments result{NOTHING, flat::unspecifiedLabel, {}, op.refs};
-    std::visit(
+    return std::visit(
         common::visitors{
             [&](const parser::ComputedGotoStmt *c) {
               const auto &e{std::get<parser::ScalarIntExpr>(c->t)};
-              result.exp = builder_->CreateExpr(ExprRef(e.thing.thing));
-              buildMultiwayDefaultNext(result);
+              auto *exp{builder_->CreateExpr(ExprRef(e.thing.thing))};
+              return SwitchArguments{exp, {}, op.refs};
             },
             [&](const parser::ArithmeticIfStmt *c) {
-              result.exp =
-                  builder_->CreateExpr(ExprRef(std::get<parser::Expr>(c->t)));
+              const auto &e{std::get<parser::Expr>(c->t)};
+              auto *exp{builder_->CreateExpr(ExprRef(e))};
+              return SwitchArguments{exp, {}, op.refs};
             },
             [&](const parser::CallStmt *c) {
-              result.exp = NOTHING;  // fixme - result of call
-              buildMultiwayDefaultNext(result);
+              auto exp{NOTHING};  // fixme - result of call
+              return SwitchArguments{exp, {}, op.refs};
+            },
+            [](const auto *) {
+              WRONG_PATH();
+              return SwitchArguments{};
             },
-            [](const auto *) { WRONG_PATH(); },
         },
         op.u);
-    return result;
   }
+
   SwitchCaseArguments ComposeSwitchCaseArguments(
       const parser::CaseConstruct *caseConstruct,
       const std::vector<flat::LabelRef> &refs) {
-    auto &cases{
-        std::get<std::list<parser::CaseConstruct::Case>>(caseConstruct->t)};
+    using A = std::list<parser::CaseConstruct::Case>;
+    auto &cases{std::get<A>(caseConstruct->t)};
     SwitchCaseArguments result{GetSwitchCaseSelector(caseConstruct),
-        flat::unspecifiedLabel, populateSwitchValues(builder_, cases),
-        std::move(refs)};
-    cleanupSwitchPairs<SwitchCaseStmt>(
-        result.defLab, result.values, result.labels);
+        populateSwitchValues(builder_, cases), std::move(refs)};
+    cleanupSwitchPairs<SwitchCaseStmt>(result.values, result.labels);
     return result;
   }
   SwitchRankArguments ComposeSwitchRankArguments(
-      const parser::SelectRankConstruct *selectRankConstruct,
+      const parser::SelectRankConstruct *crct,
       const std::vector<flat::LabelRef> &refs) {
-    auto &ranks{std::get<std::list<parser::SelectRankConstruct::RankCase>>(
-        selectRankConstruct->t)};
-    SwitchRankArguments result{GetSwitchRankSelector(selectRankConstruct),
-        flat::unspecifiedLabel, populateSwitchValues(ranks), std::move(refs)};
-    if (auto &name{GetSwitchAssociateName<parser::SelectRankStmt>(
-            selectRankConstruct)}) {
+    auto &ranks{
+        std::get<std::list<parser::SelectRankConstruct::RankCase>>(crct->t)};
+    SwitchRankArguments result{GetSwitchRankSelector(crct),
+        populateSwitchValues(ranks), std::move(refs)};
+    if (auto &name{GetSwitchAssociateName<parser::SelectRankStmt>(crct)}) {
       (void)name;  // get rid of warning
       // TODO: handle associate-name -> Add an assignment stmt?
     }
-    cleanupSwitchPairs<SwitchRankStmt>(
-        result.defLab, result.values, result.labels);
+    cleanupSwitchPairs<SwitchRankStmt>(result.values, result.labels);
     return result;
   }
   SwitchTypeArguments ComposeSwitchTypeArguments(
@@ -552,14 +545,13 @@ public:
     auto &types{std::get<std::list<parser::SelectTypeConstruct::TypeCase>>(
         selectTypeConstruct->t)};
     SwitchTypeArguments result{GetSwitchTypeSelector(selectTypeConstruct),
-        flat::unspecifiedLabel, populateSwitchValues(types), std::move(refs)};
+        populateSwitchValues(types), std::move(refs)};
     if (auto &name{GetSwitchAssociateName<parser::SelectTypeStmt>(
             selectTypeConstruct)}) {
       (void)name;  // get rid of warning
       // TODO: handle associate-name -> Add an assignment stmt?
     }
-    cleanupSwitchPairs<SwitchTypeStmt>(
-        result.defLab, result.values, result.labels);
+    cleanupSwitchPairs<SwitchTypeStmt>(result.values, result.labels);
     return result;
   }
 
@@ -932,8 +924,12 @@ public:
                 builder_->CreateStore(var, e1);
                 PushDoContext(stmt, var, e1, e2, e3);
               },
-              [&](const parser::ScalarLogicalExpr &whileExpr) {},
-              [&](const parser::LoopControl::Concurrent &cc) {},
+              [&](const parser::ScalarLogicalExpr &whileExpr) {
+                // FIXME
+              },
+              [&](const parser::LoopControl::Concurrent &cc) {
+                // FIXME
+              },
           },
           ctrl->u);
     } else {
@@ -945,16 +941,10 @@ public:
   void FinishConstruct(const parser::NonLabelDoStmt *stmt) {
     auto &ctrl{std::get<std::optional<parser::LoopControl>>(stmt->t)};
     if (ctrl.has_value()) {
-      std::visit(
-          common::visitors{
-              [&](const parser::LoopBounds<parser::ScalarIntExpr> &) {
-                PopDoContext(stmt);
-              },
-              [&](auto &) {
-                // do nothing
-              },
-          },
-          ctrl->u);
+      using A = parser::LoopBounds<parser::ScalarIntExpr>;
+      if (std::holds_alternative<A>(ctrl->u)) {
+        PopDoContext(stmt);
+      }
     }
   }
 
@@ -981,6 +971,45 @@ public:
     return builder_->CreateExpr(AlwaysTrueExpression());
   }
 
+  template<typename SWITCHTYPE, typename F>
+  void AddOrQueueSwitch(Value condition,
+      const std::vector<typename SWITCHTYPE::ValueType> &values,
+      const std::vector<flat::LabelRef> &labels, F function) {
+    auto defer{false};
+    typename SWITCHTYPE::ValueSuccPairListType cases;
+    CHECK(values.size() == labels.size());
+    auto valiter{values.begin()};
+    for (auto lab : labels) {
+      auto labIter{blockMap_.find(lab)};
+      if (labIter == blockMap_.end()) {
+        defer = true;
+        break;
+      } else {
+        cases.emplace_back(*valiter++, labIter->second);
+      }
+    }
+    if (defer) {
+      using namespace std::placeholders;
+      controlFlowEdgesToAdd_.emplace_back(std::bind(
+          [](FIRBuilder *builder, BasicBlock *block, Value expr,
+              const std::vector<typename SWITCHTYPE::ValueType> &values,
+              const std::vector<flat::LabelRef> &labels, F function,
+              const LabelMapType &map) {
+            builder->SetInsertionPoint(block);
+            typename SWITCHTYPE::ValueSuccPairListType cases;
+            auto valiter{values.begin()};
+            for (auto &lab : labels) {
+              cases.emplace_back(*valiter++, map.find(lab)->second);
+            }
+            function(builder, expr, cases);
+          },
+          builder_, builder_->GetInsertionPoint(), condition, values, labels,
+          function, _1));
+    } else {
+      function(builder_, condition, cases);
+    }
+  }
+
   void ConstructFIR(AnalysisData &ad) {
     for (auto iter{linearOperations_.begin()}, iend{linearOperations_.end()};
          iter != iend; ++iter) {
@@ -1076,8 +1105,9 @@ public:
               },
               [&](const flat::SwitchIOOp &IOp) {
                 CheckInsertionPoint();
+                auto args{ComposeIOSwitchArgs(IOp)};
                 AddOrQueueSwitch<SwitchStmt>(
-                    NOTHING, IOp.next, {}, {}, CreateSwitchHelper);
+                    args.exp, args.values, args.labels, CreateSwitchHelper);
                 builder_->ClearInsertionPoint();
               },
               [&](const flat::SwitchOp &sop) {
@@ -1086,31 +1116,23 @@ public:
                     common::visitors{
                         [&](auto) {
                           auto args{ComposeSwitchArgs(sop)};
-                          AddOrQueueSwitch<SwitchStmt>(args.exp, args.defLab,
-                              args.values, args.labels, CreateSwitchHelper);
+                          AddOrQueueSwitch<SwitchStmt>(args.exp, args.values,
+                              args.labels, CreateSwitchHelper);
                         },
-                        [&](const parser::CaseConstruct *caseConstruct) {
-                          auto args{ComposeSwitchCaseArguments(
-                              caseConstruct, sop.refs)};
+                        [&](const parser::CaseConstruct *crct) {
+                          auto args{ComposeSwitchCaseArguments(crct, sop.refs)};
                           AddOrQueueSwitch<SwitchCaseStmt>(args.exp,
-                              args.defLab, args.values, args.labels,
-                              CreateSwitchCaseHelper);
+                              args.values, args.labels, CreateSwitchCaseHelper);
                         },
-                        [&](const parser::SelectRankConstruct
-                                *selectRankConstruct) {
-                          auto args{ComposeSwitchRankArguments(
-                              selectRankConstruct, sop.refs)};
+                        [&](const parser::SelectRankConstruct *crct) {
+                          auto args{ComposeSwitchRankArguments(crct, sop.refs)};
                           AddOrQueueSwitch<SwitchRankStmt>(args.exp,
-                              args.defLab, args.values, args.labels,
-                              CreateSwitchRankHelper);
+                              args.values, args.labels, CreateSwitchRankHelper);
                         },
-                        [&](const parser::SelectTypeConstruct
-                                *selectTypeConstruct) {
-                          auto args{ComposeSwitchTypeArguments(
-                              selectTypeConstruct, sop.refs)};
+                        [&](const parser::SelectTypeConstruct *crct) {
+                          auto args{ComposeSwitchTypeArguments(crct, sop.refs)};
                           AddOrQueueSwitch<SwitchTypeStmt>(args.exp,
-                              args.defLab, args.values, args.labels,
-                              CreateSwitchTypeHelper);
+                              args.values, args.labels, CreateSwitchTypeHelper);
                         },
                     },
                     sop.u);
@@ -1132,69 +1154,55 @@ public:
                 std::visit(
                     common::visitors{
                         [&](const parser::AssociateConstruct *crct) {
-                          const auto &statement{std::get<
-                              parser::Statement<parser::AssociateStmt>>(
-                              crct->t)};
+                          using A = parser::Statement<parser::AssociateStmt>;
+                          const auto &statement{std::get<A>(crct->t)};
                           const auto &position{statement.source};
                           EnterRegion(position);
                           InitiateConstruct(&statement.statement);
                         },
                         [&](const parser::BlockConstruct *crct) {
-                          EnterRegion(
-                              std::get<parser::Statement<parser::BlockStmt>>(
-                                  crct->t)
-                                  .source);
+                          using A = parser::Statement<parser::BlockStmt>;
+                          EnterRegion(std::get<A>(crct->t).source);
                         },
                         [&](const parser::CaseConstruct *crct) {
-                          InitiateConstruct(
-                              &std::get<
-                                  parser::Statement<parser::SelectCaseStmt>>(
-                                  crct->t)
-                                   .statement);
+                          using A = parser::Statement<parser::SelectCaseStmt>;
+                          InitiateConstruct(&std::get<A>(crct->t).statement);
                         },
                         [&](const parser::ChangeTeamConstruct *crct) {
-                          const auto &statement{std::get<
-                              parser::Statement<parser::ChangeTeamStmt>>(
-                              crct->t)};
+                          using A = parser::Statement<parser::ChangeTeamStmt>;
+                          const auto &statement{std::get<A>(crct->t)};
                           EnterRegion(statement.source);
                           InitiateConstruct(&statement.statement);
                         },
                         [&](const parser::DoConstruct *crct) {
-                          const auto &statement{std::get<
-                              parser::Statement<parser::NonLabelDoStmt>>(
-                              crct->t)};
+                          using A = parser::Statement<parser::NonLabelDoStmt>;
+                          const auto &statement{std::get<A>(crct->t)};
                           EnterRegion(statement.source);
                           InitiateConstruct(&statement.statement);
                         },
                         [&](const parser::IfConstruct *crct) {
-                          InitiateConstruct(
-                              &std::get<parser::Statement<parser::IfThenStmt>>(
-                                  crct->t)
-                                   .statement);
+                          using A = parser::Statement<parser::IfThenStmt>;
+                          InitiateConstruct(&std::get<A>(crct->t).statement);
                         },
                         [&](const parser::SelectRankConstruct *crct) {
-                          const auto &statement{std::get<
-                              parser::Statement<parser::SelectRankStmt>>(
-                              crct->t)};
+                          using A = parser::Statement<parser::SelectRankStmt>;
+                          const auto &statement{std::get<A>(crct->t)};
                           EnterRegion(statement.source);
                         },
                         [&](const parser::SelectTypeConstruct *crct) {
-                          const auto &statement{std::get<
-                              parser::Statement<parser::SelectTypeStmt>>(
-                              crct->t)};
+                          using A = parser::Statement<parser::SelectTypeStmt>;
+                          const auto &statement{std::get<A>(crct->t)};
                           EnterRegion(statement.source);
                         },
                         [&](const parser::WhereConstruct *crct) {
-                          InitiateConstruct(
-                              &std::get<parser::Statement<
-                                   parser::WhereConstructStmt>>(crct->t)
-                                   .statement);
+                          using A =
+                              parser::Statement<parser::WhereConstructStmt>;
+                          InitiateConstruct(&std::get<A>(crct->t).statement);
                         },
                         [&](const parser::ForallConstruct *crct) {
-                          InitiateConstruct(
-                              &std::get<parser::Statement<
-                                   parser::ForallConstructStmt>>(crct->t)
-                                   .statement);
+                          using A =
+                              parser::Statement<parser::ForallConstructStmt>;
+                          InitiateConstruct(&std::get<A>(crct->t).statement);
                         },
                         [](const parser::CriticalConstruct *) { /*fixme*/ },
                         [](const parser::CompilerDirective *) { /*fixme*/ },
@@ -1205,16 +1213,10 @@ public:
                     con.u);
                 auto next{iter};
                 const auto &nextOp{*(++next)};
-                std::visit(
-                    common::visitors{
-                        [](const auto &) {},
-                        [&](const flat::LabelOp &op) {
-                          blockMap_.insert(
-                              {op.get(), builder_->GetInsertionPoint()});
-                          ++iter;
-                        },
-                    },
-                    nextOp.u);
+                if (auto *op{std::get_if<flat::LabelOp>(&nextOp.u)}) {
+                  blockMap_.insert({op->get(), builder_->GetInsertionPoint()});
+                  ++iter;
+                }
               },
               [&](const flat::EndOp &con) {
                 std::visit(
@@ -1306,51 +1308,6 @@ public:
     }
   }
 
-  template<typename SWITCHTYPE, typename F>
-  void AddOrQueueSwitch(Value condition, flat::LabelRef defaultLabel,
-      const std::vector<typename SWITCHTYPE::ValueType> &values,
-      const std::vector<flat::LabelRef> &labels, F function) {
-    auto defer{false};
-    auto defaultIter{blockMap_.find(defaultLabel)};
-    typename SWITCHTYPE::ValueSuccPairListType cases;
-    if (defaultIter == blockMap_.end()) {
-      defer = true;
-    } else {
-      CHECK(values.size() == labels.size());
-      auto valiter{values.begin()};
-      for (auto lab : labels) {
-        auto labIter{blockMap_.find(lab)};
-        if (labIter == blockMap_.end()) {
-          defer = true;
-          break;
-        } else {
-          cases.emplace_back(*valiter++, labIter->second);
-        }
-      }
-    }
-    if (defer) {
-      using namespace std::placeholders;
-      controlFlowEdgesToAdd_.emplace_back(std::bind(
-          [](FIRBuilder *builder, BasicBlock *block, Value expr,
-              flat::LabelRef defaultDest,
-              const std::vector<typename SWITCHTYPE::ValueType> &values,
-              const std::vector<flat::LabelRef> &labels, F function,
-              const LabelMapType &map) {
-            builder->SetInsertionPoint(block);
-            typename SWITCHTYPE::ValueSuccPairListType cases;
-            auto valiter{values.begin()};
-            for (auto &lab : labels) {
-              cases.emplace_back(*valiter++, map.find(lab)->second);
-            }
-            function(builder, expr, map.find(defaultDest)->second, cases);
-          },
-          builder_, builder_->GetInsertionPoint(), condition, defaultLabel,
-          values, labels, function, _1));
-    } else {
-      function(builder_, condition, defaultIter->second, cases);
-    }
-  }
-
   Variable *ConvertToVariable(const semantics::Symbol *symbol) {
     // FIXME: how to convert semantics::Symbol to evaluate::Variable?
     return new Variable(symbol);
index 8530db2..15db40b 100644 (file)
@@ -107,23 +107,21 @@ struct FIRBuilder {
       RuntimeCallType call, RuntimeCallArguments &&arguments) {
     return Insert(RuntimeStmt::Create(call, std::move(arguments)));
   }
-  Statement *CreateSwitch(Value condition, BasicBlock *defaultCase,
-      const SwitchStmt::ValueSuccPairListType &rest) {
-    return InsertTerminator(SwitchStmt::Create(condition, defaultCase, rest));
+  Statement *CreateSwitch(
+      Value cond, const SwitchStmt::ValueSuccPairListType &pairs) {
+    return InsertTerminator(SwitchStmt::Create(cond, pairs));
   }
-  Statement *CreateSwitchCase(Value condition, BasicBlock *defaultCase,
-      const SwitchCaseStmt::ValueSuccPairListType &rest) {
-    return InsertTerminator(
-        SwitchCaseStmt::Create(condition, defaultCase, rest));
+  Statement *CreateSwitchCase(
+      Value cond, const SwitchCaseStmt::ValueSuccPairListType &pairs) {
+    return InsertTerminator(SwitchCaseStmt::Create(cond, pairs));
   }
-  Statement *CreateSwitchType(Value condition, BasicBlock *defaultCase,
-      const SwitchTypeStmt::ValueSuccPairListType &rest) {
-    return InsertTerminator(
-        SwitchTypeStmt::Create(condition, defaultCase, rest));
+  Statement *CreateSwitchType(
+      Value cond, const SwitchTypeStmt::ValueSuccPairListType &pairs) {
+    return InsertTerminator(SwitchTypeStmt::Create(cond, pairs));
   }
   Statement *CreateSwitchRank(
-      Value c, BasicBlock *d, const SwitchRankStmt::ValueSuccPairListType &r) {
-    return InsertTerminator(SwitchRankStmt::Create(c, d, r));
+      Value cond, const SwitchRankStmt::ValueSuccPairListType &pairs) {
+    return InsertTerminator(SwitchRankStmt::Create(cond, pairs));
   }
   Statement *CreateUnreachable() {
     return InsertTerminator(UnreachableStmt::Create());
index 49ca311..6f3f54a 100644 (file)
@@ -214,6 +214,17 @@ std::vector<LabelRef> toLabelRef(AnalysisData &ad, const A &labels) {
   return result;
 }
 
+template<typename A>
+std::vector<LabelRef> toLabelRef(
+    const LabelOp &next, AnalysisData &ad, const A &labels) {
+  std::vector<LabelRef> result;
+  result.emplace_back(next);
+  auto refs{toLabelRef(ad, labels)};
+  result.insert(result.end(), refs.begin(), refs.end());
+  CHECK(result.size() == labels.size() + 1);
+  return result;
+}
+
 static bool hasAltReturns(const parser::CallStmt &callStmt) {
   const auto &args{std::get<std::list<parser::ActualArgSpec>>(callStmt.v.t)};
   for (const auto &arg : args) {
@@ -275,23 +286,15 @@ void ReturnOp::dump() const {
 }
 
 void ConditionalGotoOp::dump() const {
-  DebugChannel()
-      << "\tcbranch .T.:" << trueLabel << " .F.:" << falseLabel << " ["
-      << std::visit(
-             common::visitors{
-                 [](const parser::Statement<parser::IfThenStmt> *s) {
-                   return GetSource(s);
-                 },
-                 [](const parser::Statement<parser::ElseIfStmt> *s) {
-                   return GetSource(s);
-                 },
-                 [](const parser::IfStmt *) { return "if-stmt"s; },
-                 [](const parser::Statement<parser::NonLabelDoStmt> *s) {
-                   return GetSource(s);
-                 },
-             },
-             u)
-      << "]\n";
+  DebugChannel() << "\tcbranch .T.:" << trueLabel << " .F.:" << falseLabel
+                 << " ["
+                 << std::visit(
+                        common::visitors{
+                            [](const auto *s) { return GetSource(s); },
+                            [](const parser::IfStmt *) { return "if-stmt"s; },
+                        },
+                        u)
+                 << "]\n";
 }
 
 void SwitchIOOp::dump() const {
@@ -444,8 +447,8 @@ void Op::Build(std::list<Op> &ops,
           [&](const common::Indirection<parser::CallStmt> &s) {
             if (hasAltReturns(s.value())) {
               auto next{BuildNewLabel(ad)};
-              auto labels{toLabelRef(ad, getAltReturnLabels(s.value().v))};
-              labels.push_back(next);
+              auto alts{getAltReturnLabels(s.value().v)};
+              auto labels{toLabelRef(next, ad, alts)};
               ops.emplace_back(
                   SwitchOp{s.value(), std::move(labels), ec.source});
               ops.emplace_back(next);
@@ -519,8 +522,7 @@ void Op::Build(std::list<Op> &ops,
           [&](const common::Indirection<parser::ComputedGotoStmt> &s) {
             auto next{BuildNewLabel(ad)};
             auto labels{toLabelRef(
-                ad, std::get<std::list<parser::Label>>(s.value().t))};
-            labels.push_back(next);
+                next, ad, std::get<std::list<parser::Label>>(s.value().t))};
             ops.emplace_back(SwitchOp{s.value(), std::move(labels), ec.source});
             ops.emplace_back(next);
           },
index 809967e..39f8c5f 100644 (file)
@@ -56,45 +56,53 @@ ReturnStmt::ReturnStmt(Statement *exp) : value_{GetApplyExpr(exp)} {
   CHECK(value_);
 }
 
-SwitchStmt::SwitchStmt(const Value &cond, BasicBlock *defaultBlock,
-    const ValueSuccPairListType &args)
+SwitchStmt::SwitchStmt(const Value &cond, const ValueSuccPairListType &args)
   : condition_{cond} {
-  valueSuccPairs_.push_back({NOTHING, defaultBlock});
   valueSuccPairs_.insert(valueSuccPairs_.end(), args.begin(), args.end());
 }
 std::list<BasicBlock *> SwitchStmt::succ_blocks() const {
   return SuccBlocks<SwitchStmt>(valueSuccPairs_);
 }
+BasicBlock *SwitchStmt::defaultSucc() const {
+  CHECK(IsNothing(valueSuccPairs_[0].first));
+  return valueSuccPairs_[0].second;
+}
 
-SwitchCaseStmt::SwitchCaseStmt(
-    Value cond, BasicBlock *defaultBlock, const ValueSuccPairListType &args)
+SwitchCaseStmt::SwitchCaseStmt(Value cond, const ValueSuccPairListType &args)
   : condition_{cond} {
-  valueSuccPairs_.push_back({SwitchCaseStmt::Default{}, defaultBlock});
   valueSuccPairs_.insert(valueSuccPairs_.end(), args.begin(), args.end());
 }
 std::list<BasicBlock *> SwitchCaseStmt::succ_blocks() const {
   return SuccBlocks<SwitchCaseStmt>(valueSuccPairs_);
 }
+BasicBlock *SwitchCaseStmt::defaultSucc() const {
+  CHECK(std::holds_alternative<Default>(valueSuccPairs_[0].first));
+  return valueSuccPairs_[0].second;
+}
 
-SwitchTypeStmt::SwitchTypeStmt(
-    Value cond, BasicBlock *defaultBlock, const ValueSuccPairListType &args)
+SwitchTypeStmt::SwitchTypeStmt(Value cond, const ValueSuccPairListType &args)
   : condition_{cond} {
-  valueSuccPairs_.push_back({SwitchTypeStmt::Default{}, defaultBlock});
   valueSuccPairs_.insert(valueSuccPairs_.end(), args.begin(), args.end());
 }
 std::list<BasicBlock *> SwitchTypeStmt::succ_blocks() const {
   return SuccBlocks<SwitchTypeStmt>(valueSuccPairs_);
 }
+BasicBlock *SwitchTypeStmt::defaultSucc() const {
+  CHECK(std::holds_alternative<Default>(valueSuccPairs_[0].first));
+  return valueSuccPairs_[0].second;
+}
 
-SwitchRankStmt ::SwitchRankStmt(
-    Value cond, BasicBlock *defaultBlock, const ValueSuccPairListType &args)
+SwitchRankStmt ::SwitchRankStmt(Value cond, const ValueSuccPairListType &args)
   : condition_{cond} {
-  valueSuccPairs_.push_back({SwitchRankStmt::Default{}, defaultBlock});
   valueSuccPairs_.insert(valueSuccPairs_.end(), args.begin(), args.end());
 }
 std::list<BasicBlock *> SwitchRankStmt::succ_blocks() const {
   return SuccBlocks<SwitchRankStmt>(valueSuccPairs_);
 }
+BasicBlock *SwitchRankStmt::defaultSucc() const {
+  CHECK(std::holds_alternative<Default>(valueSuccPairs_[0].first));
+  return valueSuccPairs_[0].second;
+}
 
 // check LoadInsn constraints
 static void CheckLoadInsn(const Value &v) {
index cbc1751..e8c340b 100644 (file)
@@ -114,17 +114,17 @@ public:
   using ValueType = Value;
   using ValueSuccPairType = std::pair<ValueType, BasicBlock *>;
   using ValueSuccPairListType = std::vector<ValueSuccPairType>;
-  static SwitchStmt Create(const Value &switchEval, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args) {
-    return SwitchStmt{switchEval, defaultBlock, args};
+  static SwitchStmt Create(
+      const Value &switchEval, const ValueSuccPairListType &args) {
+    return SwitchStmt{switchEval, args};
   }
-  BasicBlock *defaultSucc() const { return valueSuccPairs_[0].second; }
+  BasicBlock *defaultSucc() const;
   std::list<BasicBlock *> succ_blocks() const override;
   Value getCond() const { return condition_; }
 
 private:
-  explicit SwitchStmt(const Value &condition, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args);
+  explicit SwitchStmt(
+      const Value &condition, const ValueSuccPairListType &args);
 
   Value condition_;
   ValueSuccPairListType valueSuccPairs_;
@@ -153,17 +153,16 @@ public:
   using ValueSuccPairType = std::pair<ValueType, BasicBlock *>;
   using ValueSuccPairListType = std::vector<ValueSuccPairType>;
 
-  static SwitchCaseStmt Create(Value switchEval, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args) {
-    return SwitchCaseStmt{switchEval, defaultBlock, args};
+  static SwitchCaseStmt Create(
+      Value switchEval, const ValueSuccPairListType &args) {
+    return SwitchCaseStmt{switchEval, args};
   }
-  BasicBlock *defaultSucc() const { return valueSuccPairs_[0].second; }
+  BasicBlock *defaultSucc() const;
   std::list<BasicBlock *> succ_blocks() const override;
   Value getCond() const { return condition_; }
 
 private:
-  explicit SwitchCaseStmt(Value condition, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args);
+  explicit SwitchCaseStmt(Value condition, const ValueSuccPairListType &args);
 
   Value condition_;
   ValueSuccPairListType valueSuccPairs_;
@@ -182,17 +181,16 @@ public:
   using ValueType = std::variant<Default, TypeSpec, DerivedTypeSpec>;
   using ValueSuccPairType = std::pair<ValueType, BasicBlock *>;
   using ValueSuccPairListType = std::vector<ValueSuccPairType>;
-  static SwitchTypeStmt Create(Value switchEval, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args) {
-    return SwitchTypeStmt{switchEval, defaultBlock, args};
+  static SwitchTypeStmt Create(
+      Value switchEval, const ValueSuccPairListType &args) {
+    return SwitchTypeStmt{switchEval, args};
   }
-  BasicBlock *defaultSucc() const { return valueSuccPairs_[0].second; }
+  BasicBlock *defaultSucc() const;
   std::list<BasicBlock *> succ_blocks() const override;
   Value getCond() const { return condition_; }
 
 private:
-  explicit SwitchTypeStmt(Value condition, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args);
+  explicit SwitchTypeStmt(Value condition, const ValueSuccPairListType &args);
   Value condition_;
   ValueSuccPairListType valueSuccPairs_;
 };
@@ -208,17 +206,16 @@ public:
   using ValueType = std::variant<Exactly, AssumedSize, Default>;
   using ValueSuccPairType = std::pair<ValueType, BasicBlock *>;
   using ValueSuccPairListType = std::vector<ValueSuccPairType>;
-  static SwitchRankStmt Create(Value switchEval, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args) {
-    return SwitchRankStmt{switchEval, defaultBlock, args};
+  static SwitchRankStmt Create(
+      Value switchEval, const ValueSuccPairListType &args) {
+    return SwitchRankStmt{switchEval, args};
   }
-  BasicBlock *defaultSucc() const { return valueSuccPairs_[0].second; }
+  BasicBlock *defaultSucc() const;
   std::list<BasicBlock *> succ_blocks() const override;
   Value getCond() const { return condition_; }
 
 private:
-  explicit SwitchRankStmt(Value condition, BasicBlock *defaultBlock,
-      const ValueSuccPairListType &args);
+  explicit SwitchRankStmt(Value condition, const ValueSuccPairListType &args);
 
   Value condition_;
   ValueSuccPairListType valueSuccPairs_;