using arith::Intersect;
using arith::IntSet;
-using PartitionKey = std::pair<const Object*, bool>;
+using PartitionKey = std::pair<PrimExpr, bool>;
// Hash functor for PartitionKey, a (condition, bool) pair.
// It hashes both components independently and combines them with XOR.
struct PartitionKeyHash {
std::size_t operator()(PartitionKey const& k) const noexcept {
// Hash of the condition part of the key. The '-' line hashed the raw
// `const Object*` with std::hash; the '+' line hashes the PrimExpr key
// through ObjectPtrHash instead.
- std::size_t h1 = std::hash<const Object*>{}(k.first);
+ std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces)
// Hash of the boolean part of the key.
std::size_t h2 = std::hash<bool>{}(k.second);
// Combine the two component hashes.
return h1 ^ h2;
}
};
// Equality functor for PartitionKey: two keys compare equal when their
// boolean components match and ObjectPtrEqual reports their expression
// components as equal. The cheap boolean comparison is evaluated first.
+struct PartitionKeyEqual {
+ bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
+ // NOLINTNEXTLINE(whitespace/braces)
+ return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first);
+ }
+};
+
// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
-using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
+using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;
+
+using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
- candidates.insert(op);
+ candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var);
} else {
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
- candidates.insert(op);
+ candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var.get());
return;
}
}
- std::unordered_set<const Object*> candidates;
+ std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
private:
bool in_likely_{false};
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
- partitions[{cond.get(), true}] = interval;
+ partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
- partitions[{cond.get(), false}] = interval;
+ partitions[{cond, false}] = interval;
}
}
}
// Replace the set of conditions given by ps with cond_value (true or false).
// Expressions found in the set are swapped for the matching boolean constant;
// all other expressions go through the default StmtExprMutator traversal.
class ConditionEliminator : public StmtExprMutator {
public:
// ps: the conditions to replace; cond_value: the constant they are replaced
// with (defaults to true). The '-' line took a set of raw `const Object*`;
// the '+' line takes an ExpressionSet.
- explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
+ explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
: ps_(ps), cond_value_(cond_value) {}
PrimExpr VisitExpr(const PrimExpr& e) final {
// Membership test: formerly on the raw node pointer, now on the expression itself.
- if (ps_.find(e.get()) != ps_.end()) {
+ if (ps_.find(e) != ps_.end()) {
// e is one of the listed conditions: substitute the constant, which is
// itself passed through VisitExpr before being returned.
return VisitExpr(cond_value_ ? const_true() : const_false());
}
// Not a listed condition: defer to the base-class visitor.
return StmtExprMutator::VisitExpr(e);
}
private:
// Conditions to eliminate. Held by value, so this object keeps its own copy of the set.
- std::unordered_set<const Object*> ps_;
+ ExpressionSet ps_;
// Constant value substituted for every condition in ps_.
bool cond_value_;
};
// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public StmtMutator {
public:
// ps: set of condition expressions; cond: condition stored for later use.
// innermost_thread_scope_ starts out false. The '-' line took a set of raw
// `const Object*`; the '+' line takes an ExpressionSet.
- explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps, PrimExpr cond)
+ explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
: ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
// NOTE(review): the body of this override is not shown in this excerpt.
Stmt VisitStmt_(const AttrStmtNode* op) final {
}
private:
// Stored as a reference, not a copy: the set passed to the constructor must
// outlive this object.
- const std::unordered_set<const Object*>& ps_;
+ const ExpressionSet& ps_;
// Condition supplied at construction.
PrimExpr cond_;
// Initialised to false by the constructor; presumably tracks whether the
// innermost thread scope has been reached — confirm against VisitStmt_.
bool innermost_thread_scope_;
};
}
Stmt VisitStmt_(const ForNode* op) final {
- if (selector.candidates.count(op)) {
- Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var, op->min, op->min + op->extent - 1,
- op->body, false);
+ auto fs = GetRef<Stmt>(op);
+ if (selector.candidates.count(fs)) {
+ Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
}
const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
- if (selector.candidates.count(op)) {
- Stmt s = TryPartition(op, GetRef<Stmt>(op), var, 0, op->value - 1, op->body, true);
+ auto as = GetRef<Stmt>(op);
+ if (selector.candidates.count(as)) {
+ Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}
}
private:
- Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max,
- Stmt body, bool partition_thread_scope);
+ Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
+ bool partition_thread_scope);
- std::pair<IntSet, std::unordered_set<const Object*>> GetIntervalAndCondset(
- const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value);
+ std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
+ const arith::IntervalSet& for_interval,
+ bool cond_value);
inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have value given by cond_value
-std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetIntervalAndCondset(
+std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
Array<IntSet> sets;
- std::unordered_set<const Object*> cond_set;
+ ExpressionSet cond_set;
for (const auto& kv : partitions) {
if (kv.first.second == cond_value) {
* which will eventually be simplified to empty code. And because only one loop was generated
* from loop 2 we stop recursing.
*/
-Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min,
- PrimExpr max, Stmt body, bool partition_thread_scope) {
+Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
+ bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});
arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
- std::unordered_set<const Object*> cond_set;
+ ExpressionSet cond_set;
// find an interval in which all conditions on var are true
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, true);
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
- pre_stmt = MakeFor(node, body_begin - min, pre_body);
+ pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
}
}
} else {
}
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
- post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+ post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
}
}
} else {
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
- mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);
+ mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);
// Recurse for each non-empty subrange only if there are at least
// two non-empty subranges