From: Krzysztof Parzyszek Date: Thu, 7 May 2020 18:18:26 +0000 (-0500) Subject: Cache PrimExpr instead of raw pointers in bound analyzer (#5533) X-Git-Tag: upstream/0.7.0~782 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f05b9119e0612df59172301bdfbe4fabcd8605e9;p=platform%2Fupstream%2Ftvm.git Cache PrimExpr instead of raw pointers in bound analyzer (#5533) The objects that the raw pointers point to can be deallocated and new objects can be allocated at the same address, all while these pointers are still in the cache. This can lead to unexpected behavior, for example to calculated bound conflicts with previously cached values. Caching PrimExpr will prevent the objects from being deallocated while the cache is active. --- diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index c08c0d6..340da7f 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -107,6 +107,7 @@ class ConstIntBound : public ObjectRef { */ class ConstIntBoundAnalyzer { public: + using BoundMapType = std::unordered_map; /*! * \brief analyze the expr * \param expr The expression of interest. @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - TVM_DLL ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound); /*! * \brief Update constant int bound information of var. diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index bb7c3dd..4437225 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -147,17 +147,16 @@ class ConstIntBoundAnalyzer::Impl : } } if (bound_) { - const PrimExprNode* op = expr.as(); - auto val = bound_->find(op); + auto val = bound_->find(expr); if (val != bound_->end()) { - auto everything = Everything(op->dtype); + auto everything = Everything(expr->dtype); CHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && val->second->max_value == everything.max_value)) << "Detected bound for " << expr << "conflicts with memorization"; } - (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); } return res; } @@ -369,7 +368,7 @@ class ConstIntBoundAnalyzer::Impl : // additional bound info std::vector additional_info_; // look up table for memorization - std::unordered_map* bound_{nullptr}; + BoundMapType* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -563,7 +562,7 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { } ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, - std::unordered_map* bound) { + BoundMapType* bound) { impl_->bound_ = bound; Entry ret = impl_->VisitExpr(expr); impl_->bound_ = nullptr; diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 4cf5ccd..796e39b 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -76,11 +76,10 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = max_bits_; - const PrimExprNode* op = e.as(); - if (bound_.find(op) == bound_.end()) { + if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_); } - ConstIntBound bound = bound_[op]; + ConstIntBound bound = bound_[e]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || @@ -187,7 +186,7 @@ class DataTypeVisitor final : public StmtExprVisitor { // the extent of vars to be rewritten std::unordered_map vextent_; // the memorized bound generated by ConstIntBoundAnalyzer - std::unordered_map bound_; + arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; class DataTypeRewriter : public StmtExprMutator {