Cache PrimExpr instead of raw pointers in bound analyzer (#5533)
authorKrzysztof Parzyszek <kparzysz@quicinc.com>
Thu, 7 May 2020 18:18:26 +0000 (13:18 -0500)
committerGitHub <noreply@github.com>
Thu, 7 May 2020 18:18:26 +0000 (11:18 -0700)
The objects that the raw pointers point to can be deallocated and new
objects can be allocated at the same address, all while these pointers
are still in the cache. This can lead to unexpected behavior, for
example to calculated bound conflicts with previously cached values.

Caching PrimExpr will prevent the objects from being deallocated while
the cache is active.

include/tvm/arith/analyzer.h
src/arith/const_int_bound.cc
src/tir/transforms/narrow_datatype.cc

index c08c0d6..340da7f 100644 (file)
@@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef {
  */
 class ConstIntBoundAnalyzer {
  public:
+  using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
   /*!
    * \brief analyze the expr
    * \param expr The expression of interest.
@@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer {
    * \param bound The lookup table to store the intermediate results
    * \return the result of the analysis.
    */
-  TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
-                                   std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
+  TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
 
   /*!
    * \brief Update constant int bound information of var.
index bb7c3dd..4437225 100644 (file)
@@ -147,17 +147,16 @@ class ConstIntBoundAnalyzer::Impl :
       }
     }
     if (bound_) {
-      const PrimExprNode* op = expr.as<PrimExprNode>();
-      auto val = bound_->find(op);
+      auto val = bound_->find(expr);
       if (val != bound_->end()) {
-        auto everything = Everything(op->dtype);
+        auto everything = Everything(expr->dtype);
         CHECK(
             (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
             (val->second->min_value == everything.min_value &&
              val->second->max_value == everything.max_value))
             << "Detected bound for " << expr << "conflicts with memorization";
       }
-      (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
+      (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value);
     }
     return res;
   }
@@ -369,7 +368,7 @@ class ConstIntBoundAnalyzer::Impl :
   // additional bound info
   std::vector<BoundInfo> additional_info_;
   // look up table for memorization
-  std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
+  BoundMapType* bound_{nullptr};
   // constants: the limit value means umlimited
   // NOTE: kNegInf/kPosInf are used to represent infinity.
   static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
@@ -563,7 +562,7 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
 }
 
 ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
-  std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
+                                                BoundMapType* bound) {
   impl_->bound_ = bound;
   Entry ret = impl_->VisitExpr(expr);
   impl_->bound_ = nullptr;
index 4cf5ccd..796e39b 100644 (file)
@@ -76,11 +76,10 @@ class DataTypeVisitor final : public StmtExprVisitor {
   void VisitExpr(const PrimExpr& e) {
     if (e.dtype().is_int()) {
       int bits = max_bits_;
-      const PrimExprNode* op = e.as<PrimExprNode>();
-      if (bound_.find(op) == bound_.end()) {
+      if (bound_.find(e) == bound_.end()) {
         analyzer_.const_int_bound(e, &bound_);
       }
-      ConstIntBound bound = bound_[op];
+      ConstIntBound bound = bound_[e];
       int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
       int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
       if (e.dtype().bits() <= target_bits_ ||
@@ -187,7 +186,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
   // the extent of vars to be rewritten
   std::unordered_map<const VarNode*, DataType> vextent_;
   // the memorized bound generated by ConstIntBoundAnalyzer
-  std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
+  arith::ConstIntBoundAnalyzer::BoundMapType bound_;
 };
 
 class DataTypeRewriter : public StmtExprMutator {