From 2875e4cddc4cfff53a19a7e69aa331119c8db7ef Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 7 Jul 2020 19:34:10 -0500 Subject: [PATCH] Cache object refs in loop partitioner instead of object pointers (#6004) * Cache object refs in loop partitioner instead of object pointers Loop partitioner modifies the IR, which can cause TIR objects to become dead and be destroyed. To avoid working on junk data cache object references instead of object pointers. * Fix format/lint errors --- src/tir/transforms/loop_partition.cc | 71 +++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index d8d784b..23f41e1 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -58,18 +58,27 @@ using arith::DeduceBound; using arith::Intersect; using arith::IntSet; -using PartitionKey = std::pair; +using PartitionKey = std::pair; struct PartitionKeyHash { std::size_t operator()(PartitionKey const& k) const noexcept { - std::size_t h1 = std::hash{}(k.first); + std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces) std::size_t h2 = std::hash{}(k.second); return h1 ^ h2; } }; +struct PartitionKeyEqual { + bool operator()(const PartitionKey& k1, const PartitionKey& k2) const { + // NOLINTNEXTLINE(whitespace/braces) + return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first); + } +}; + // Each mapping (cond, cond_value) -> interval represents the fact that // condition cond is proven to have value cond_value (true or false) in interval. -using Partition = std::unordered_map; +using Partition = std::unordered_map; + +using ExpressionSet = std::unordered_set; bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { bool success = false; @@ -101,7 +110,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { - candidates.insert(op); + candidates.insert(GetRef(op)); } record_.erase(var); } else { @@ -119,7 +128,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { - candidates.insert(op); + candidates.insert(GetRef(op)); } record_.erase(var.get()); return; @@ -160,7 +169,7 @@ class CandidateSelector final : public StmtExprVisitor { } } - std::unordered_set candidates; + std::unordered_set candidates; private: bool in_likely_{false}; @@ -224,14 +233,14 @@ class PartitionFinder : public StmtExprVisitor { IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.IsNothing()) { // cond is true within interval - partitions[{cond.get(), true}] = interval; + partitions[{cond, true}] = interval; } PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.IsNothing()) { // cond is false within interval - partitions[{cond.get(), false}] = interval; + partitions[{cond, false}] = interval; } } } @@ -276,25 +285,25 @@ class PartitionFinder : public StmtExprVisitor { // Replace the set of conditions given by ps with cond_value (true or false) class ConditionEliminator : public StmtExprMutator { public: - explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) + explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} PrimExpr VisitExpr(const PrimExpr& e) final { - if (ps_.find(e.get()) != ps_.end()) { + if (ps_.find(e) != ps_.end()) { return VisitExpr(cond_value_ ? const_true() : const_false()); } return StmtExprMutator::VisitExpr(e); } private: - std::unordered_set ps_; + ExpressionSet ps_; bool cond_value_; }; // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public StmtMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, PrimExpr cond) + explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -316,7 +325,7 @@ class ThreadPartitionInserter : public StmtMutator { } private: - const std::unordered_set& ps_; + const ExpressionSet& ps_; PrimExpr cond_; bool innermost_thread_scope_; }; @@ -334,9 +343,9 @@ class LoopPartitioner : public StmtMutator { } Stmt VisitStmt_(const ForNode* op) final { - if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, - op->body, false); + auto fs = GetRef(op); + if (selector.candidates.count(fs)) { + Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; } @@ -356,8 +365,9 @@ class LoopPartitioner : public StmtMutator { const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; - if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), var, 0, op->value - 1, op->body, true); + auto as = GetRef(op); + if (selector.candidates.count(as)) { + Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; } @@ -378,11 +388,12 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, - Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, + bool partition_thread_scope); - std::pair> GetIntervalAndCondset( - const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value); + std::pair GetIntervalAndCondset(const Partition& partitions, + const arith::IntervalSet& for_interval, + bool cond_value); inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); @@ -395,10 +406,10 @@ class LoopPartitioner : public StmtMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> LoopPartitioner::GetIntervalAndCondset( +std::pair LoopPartitioner::GetIntervalAndCondset( const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) { Array sets; - std::unordered_set cond_set; + ExpressionSet cond_set; for (const auto& kv : partitions) { if (kv.first.second == cond_value) { @@ -460,8 +471,8 @@ std::pair> LoopPartitioner::GetInterva * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min, - PrimExpr max, Stmt body, bool partition_thread_scope) { +Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, + bool partition_thread_scope) { using namespace arith; // include hint of var. hint_map_.insert({var.get(), IntSet::Interval(min, max)}); @@ -475,7 +486,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; - std::unordered_set cond_set; + ExpressionSet cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, true); @@ -516,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var } if (!partition_thread_scope) { Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); + pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body); } } } else { @@ -541,7 +552,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var } if (!partition_thread_scope) { Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body); } } } else { @@ -557,7 +568,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); - mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body); + mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body); // Recurse for each non-empty subrange only if there are at least // two non-empty subranges -- 2.7.4