Cache object refs in loop partitioner instead of object pointers (#6004)
authorKrzysztof Parzyszek <kparzysz@quicinc.com>
Wed, 8 Jul 2020 00:34:10 +0000 (19:34 -0500)
committerGitHub <noreply@github.com>
Wed, 8 Jul 2020 00:34:10 +0000 (17:34 -0700)
* Cache object refs in loop partitioner instead of object pointers

Loop partitioner modifies the IR, which can cause TIR objects to
become dead and be destroyed. To avoid working on junk data cache
object references instead of object pointers.

* Fix format/lint errors

src/tir/transforms/loop_partition.cc

index d8d784b..23f41e1 100644 (file)
@@ -58,18 +58,27 @@ using arith::DeduceBound;
 using arith::Intersect;
 using arith::IntSet;
 
-using PartitionKey = std::pair<const Object*, bool>;
+using PartitionKey = std::pair<PrimExpr, bool>;
 struct PartitionKeyHash {
   std::size_t operator()(PartitionKey const& k) const noexcept {
-    std::size_t h1 = std::hash<const Object*>{}(k.first);
+    std::size_t h1 = ObjectPtrHash{}(k.first);  // NOLINT(whitespace/braces)
     std::size_t h2 = std::hash<bool>{}(k.second);
     return h1 ^ h2;
   }
 };
 
+struct PartitionKeyEqual {
+  bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
+    // NOLINTNEXTLINE(whitespace/braces)
+    return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first);
+  }
+};
+
 // Each mapping (cond, cond_value) -> interval represents the fact that
 // condition cond is proven to have value cond_value (true or false) in interval.
-using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
+using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;
+
+using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
 
 bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
   bool success = false;
@@ -101,7 +110,7 @@ class CandidateSelector final : public StmtExprVisitor {
       record_.insert({var, false});
       StmtExprVisitor::VisitStmt_(op);
       if (record_.at(var) && !no_split_) {
-        candidates.insert(op);
+        candidates.insert(GetRef<Stmt>(op));
       }
       record_.erase(var);
     } else {
@@ -119,7 +128,7 @@ class CandidateSelector final : public StmtExprVisitor {
         record_.insert({var.get(), false});
         StmtExprVisitor::VisitStmt_(op);
         if (record_.at(var.get()) && !no_split_) {
-          candidates.insert(op);
+          candidates.insert(GetRef<Stmt>(op));
         }
         record_.erase(var.get());
         return;
@@ -160,7 +169,7 @@ class CandidateSelector final : public StmtExprVisitor {
     }
   }
 
-  std::unordered_set<const Object*> candidates;
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
 
  private:
   bool in_likely_{false};
@@ -224,14 +233,14 @@ class PartitionFinder : public StmtExprVisitor {
         IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
         if (!interval.IsNothing()) {
           // cond is true within interval
-          partitions[{cond.get(), true}] = interval;
+          partitions[{cond, true}] = interval;
         }
         PrimExpr inverse_cond = InverseCond(cond);
         if (inverse_cond.defined()) {
           IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
           if (!interval.IsNothing()) {
             // cond is false within interval
-            partitions[{cond.get(), false}] = interval;
+            partitions[{cond, false}] = interval;
           }
         }
       }
@@ -276,25 +285,25 @@ class PartitionFinder : public StmtExprVisitor {
 // Replace the set of conditions given by ps with cond_value (true or false)
 class ConditionEliminator : public StmtExprMutator {
  public:
-  explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
+  explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
       : ps_(ps), cond_value_(cond_value) {}
 
   PrimExpr VisitExpr(const PrimExpr& e) final {
-    if (ps_.find(e.get()) != ps_.end()) {
+    if (ps_.find(e) != ps_.end()) {
       return VisitExpr(cond_value_ ? const_true() : const_false());
     }
     return StmtExprMutator::VisitExpr(e);
   }
 
  private:
-  std::unordered_set<const Object*> ps_;
+  ExpressionSet ps_;
   bool cond_value_;
 };
 
 // Insert the partition branch at the innermost thread scope
 class ThreadPartitionInserter : public StmtMutator {
  public:
-  explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps, PrimExpr cond)
+  explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
       : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -316,7 +325,7 @@ class ThreadPartitionInserter : public StmtMutator {
   }
 
  private:
-  const std::unordered_set<const Object*>& ps_;
+  const ExpressionSet& ps_;
   PrimExpr cond_;
   bool innermost_thread_scope_;
 };
@@ -334,9 +343,9 @@ class LoopPartitioner : public StmtMutator {
   }
 
   Stmt VisitStmt_(const ForNode* op) final {
-    if (selector.candidates.count(op)) {
-      Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var, op->min, op->min + op->extent - 1,
-                            op->body, false);
+    auto fs = GetRef<Stmt>(op);
+    if (selector.candidates.count(fs)) {
+      Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
       if (s.defined()) return s;
     }
 
@@ -356,8 +365,9 @@ class LoopPartitioner : public StmtMutator {
     const IterVarNode* iv = op->node.as<IterVarNode>();
     CHECK(iv);
     Var var = iv->var;
-    if (selector.candidates.count(op)) {
-      Stmt s = TryPartition(op, GetRef<Stmt>(op), var, 0, op->value - 1, op->body, true);
+    auto as = GetRef<Stmt>(op);
+    if (selector.candidates.count(as)) {
+      Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
       if (s.defined()) return s;
     }
 
@@ -378,11 +388,12 @@ class LoopPartitioner : public StmtMutator {
   }
 
  private:
-  Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max,
-                    Stmt body, bool partition_thread_scope);
+  Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
+                    bool partition_thread_scope);
 
-  std::pair<IntSet, std::unordered_set<const Object*>> GetIntervalAndCondset(
-      const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value);
+  std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
+                                                         const arith::IntervalSet& for_interval,
+                                                         bool cond_value);
 
   inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
 
@@ -395,10 +406,10 @@ class LoopPartitioner : public StmtMutator {
 
 // Returns an interval (in the first component) in which all the conditions
 // given in the second component provably have value given by cond_value
-std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetIntervalAndCondset(
+std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
     const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
   Array<IntSet> sets;
-  std::unordered_set<const Object*> cond_set;
+  ExpressionSet cond_set;
 
   for (const auto& kv : partitions) {
     if (kv.first.second == cond_value) {
@@ -460,8 +471,8 @@ std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetInterva
  * which will eventually be simplified to empty code. And because only one loop was generated
  * from loop 2 we stop recursing.
  */
-Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min,
-                                   PrimExpr max, Stmt body, bool partition_thread_scope) {
+Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
+                                   bool partition_thread_scope) {
   using namespace arith;
   // include hint of var.
   hint_map_.insert({var.get(), IntSet::Interval(min, max)});
@@ -475,7 +486,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
   arith::IntervalSet for_interval(min, max);
   bool cond_value;
   IntSet middle_interval;
-  std::unordered_set<const Object*> cond_set;
+  ExpressionSet cond_set;
   // find an interval in which all conditions on var are true
   std::tie(middle_interval, cond_set) =
       GetIntervalAndCondset(finder.partitions, for_interval, true);
@@ -516,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
       }
       if (!partition_thread_scope) {
         Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
-        pre_stmt = MakeFor(node, body_begin - min, pre_body);
+        pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
       }
     }
   } else {
@@ -541,7 +552,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
       }
       if (!partition_thread_scope) {
         Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
-        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+        post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
       }
     }
   } else {
@@ -557,7 +568,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
       // [body_begin, post_doubt_begin)
       Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
       Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
-      mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);
+      mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);
 
       // Recurse for each non-empty subrange only if there are at least
       // two non-empty subranges