static Expression getApplyExpr(Statement *s) {
return GetApplyExpr(s)->expression();
}
+static Expression getLocalVariable(Statement *s) {
+ return GetLocal(s)->variable();
+}
class FortranIRLowering {
public:
context.GetDefaultKind(common::TypeCategory::Real))
.value();
}
+ Expression ConsExpr(
+ common::RelationalOperator op, Expression e1, Expression e2) {
+ evaluate::ExpressionAnalyzer context{semanticsContext_};
+ return evaluate::AsGenericExpr(evaluate::Relate(
+ context.GetContextualMessages(), op, std::move(e1), std::move(e2))
+ .value());
+ }
template<typename T>
void ProcessRoutine(const T &here, const std::string &name) {
// DO loop handlers
struct DoBoundsInfo {
Statement *doVariable;
- Statement *lowerBound;
- Statement *upperBound;
+ Statement *counter;
Statement *stepExpr;
Statement *condition;
};
void PushDoContext(const parser::NonLabelDoStmt *doStmt, Statement *doVar,
- Statement *lowBound, Statement *upBound, Statement *stepExp) {
- doMap_.emplace(doStmt, DoBoundsInfo{doVar, lowBound, upBound, stepExp});
+ Statement *counter, Statement *stepExp) {
+ doMap_.emplace(doStmt, DoBoundsInfo{doVar, counter, stepExp});
}
void PopDoContext(const parser::NonLabelDoStmt *doStmt) {
doMap_.erase(doStmt);
return nullptr;
}
- // do_var = do_var + e3
+ // evaluate: do_var = do_var + e3; counter--
void handleLinearDoIncrement(const flat::DoIncrementOp &inc) {
auto *info{GetBoundsInfo(inc)};
- auto *var{builder_->CreateLoad(info->doVariable)};
- builder_->CreateIncrement(var, info->stepExpr);
+ auto *incremented{builder_->CreateExpr(std::move(
+ ConsExpr<evaluate::Add>(GetAddressable(info->doVariable)->address(),
+ GetApplyExpr(info->stepExpr)->expression())))};
+ builder_->CreateStore(info->doVariable, incremented);
+ auto *decremented{builder_->CreateExpr(ConsExpr<evaluate::Subtract>(
+ GetAddressable(info->counter)->address(), CreateConstant(1)))};
+ builder_->CreateStore(info->counter, decremented);
}
- // (e3 > 0 && do_var <= e2) || (e3 < 0 && do_var >= e2)
+ // is (counter > 0)?
void handleLinearDoCompare(const flat::DoCompareOp &cmp) {
auto *info{GetBoundsInfo(cmp)};
- auto *var{builder_->CreateLoad(info->doVariable)};
- auto *cond{
- builder_->CreateDoCondition(info->stepExpr, var, info->upperBound)};
+ Expression compare{ConsExpr(common::RelationalOperator::GT,
+ getLocalVariable(info->counter), CreateConstant(0))};
+ auto *cond{builder_->CreateExpr(&compare)};
info->condition = cond;
}
getApplyExpr(e3))};
auto *totalTrips{builder_->CreateExpr(&tripExpr)};
builder_->CreateStore(tripCounter, totalTrips);
- PushDoContext(stmt, name, nullptr, tripCounter, e3);
+ PushDoContext(stmt, name, tripCounter, e3);
},
[&](const parser::ScalarLogicalExpr &whileExpr) {
// FIXME
Statement *CreateDealloc(AllocateInsn *alloc) {
return Insert(DeallocateInsn::Create(alloc));
}
- Statement *CreateDoCondition(Statement *dir, Statement *v1, Statement *v2) {
- return Insert(DoConditionStmt::Create(dir, v1, v2));
- }
Statement *CreateExpr(const Expression *e) {
return Insert(ApplyExprStmt::Create(e));
}
ApplyExprStmt *MakeAsExpr(const Expression *e) {
return GetApplyExpr(CreateExpr(e));
}
- Statement *CreateIncrement(Statement *v1, Statement *v2) {
- return Insert(IncrementStmt::Create(v1, v2));
- }
Statement *CreateIndirectBr(Variable *v, const std::vector<BasicBlock *> &p) {
return InsertTerminator(IndirectBranchStmt::Create(v, p));
}
CHECK(val);
}
-IncrementStmt::IncrementStmt(Value v1, Value v2) : value_{v1, v2} {}
-
-DoConditionStmt::DoConditionStmt(Value dir, Value v1, Value v2)
- : value_{dir, v1, v2} {}
-
std::string Statement::dump() const {
return std::visit(
common::visitors{
},
[](const IndirectBranchStmt &) { return "ibranch"s; },
[](const UnreachableStmt &) { return "unreachable"s; },
- [](const IncrementStmt &) { return "increment"s; },
- [](const DoConditionStmt &) { return "compare"s; },
[](const ApplyExprStmt &e) { return FIR::dump(e.expression()); },
[](const LocateExprStmt &e) {
return "&" + FIR::dump(e.expression());
class SwitchRankStmt;
class IndirectBranchStmt;
class UnreachableStmt;
-class IncrementStmt;
-class DoConditionStmt;
class ApplyExprStmt;
class LocateExprStmt;
class AllocateInsn;
std::optional<evaluate::DynamicType> type;
};
-class IncrementStmt : public ActionStmt_impl {
-public:
- static IncrementStmt Create(Value v1, Value v2) {
- return IncrementStmt(v1, v2);
- }
- Value leftValue() const { return value_[0]; }
- Value rightValue() const { return value_[1]; }
-
-private:
- explicit IncrementStmt(Value v1, Value v2);
- Value value_[2];
-};
-
-class DoConditionStmt : public ActionStmt_impl {
-public:
- static DoConditionStmt Create(Value dir, Value left, Value right) {
- return DoConditionStmt(dir, left, right);
- }
- Value direction() const { return value_[0]; }
- Value leftValue() const { return value_[1]; }
- Value rightValue() const { return value_[2]; }
-
-private:
- explicit DoConditionStmt(Value dir, Value left, Value right);
- Value value_[3];
-};
-
// Compute the value of an expression
class ApplyExprStmt : public ActionStmt_impl {
public:
// Base class of all addressable statements
class Addressable_impl : public ActionStmt_impl {
+public:
+ Expression address() const { return addrExpr_.value(); }
+
protected:
Addressable_impl() : addrExpr_{std::nullopt} {}
explicit Addressable_impl(const Expression &ae) : addrExpr_{ae} {}
}
static LocateExprStmt Create(Expression &&e) { return LocateExprStmt(e); }
- const Expression &expression() const { return *addrExpr_; }
+ const Expression &expression() const { return addrExpr_.value(); }
private:
explicit LocateExprStmt(const Expression &e) : Addressable_impl{e} {}
Type type() const { return type_; }
int alignment() const { return alignment_; }
+ Expression variable() { return addrExpr_.value(); }
private:
explicit AllocateLocalInsn(Type type, int alignment, const Expression &expr)
SwitchRankStmt, //
IndirectBranchStmt, //
UnreachableStmt, //
- IncrementStmt, //
- DoConditionStmt, //
ApplyExprStmt, //
LocateExprStmt, //
AllocateInsn, //
return std::get_if<ApplyExprStmt>(&stmt->u);
}
+inline AllocateLocalInsn *GetLocal(Statement *stmt) {
+ return std::get_if<AllocateLocalInsn>(&stmt->u);
+}
+
Addressable_impl *GetAddressable(Statement *stmt);
}