*/
class ConstIntBoundAnalyzer {
public:
+ using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \param bound The lookup table to store the intermediate results
* \return the result of the analysis.
*/
- TVM_DLL ConstIntBound operator()(const PrimExpr& expr,
- std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
+ TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
/*!
* \brief Update constant int bound information of var.
}
}
if (bound_) {
- const PrimExprNode* op = expr.as<PrimExprNode>();
- auto val = bound_->find(op);
+ auto val = bound_->find(expr);
if (val != bound_->end()) {
- auto everything = Everything(op->dtype);
+ auto everything = Everything(expr->dtype);
CHECK(
(val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
(val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
}
- (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
+ (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}
// additional bound info
std::vector<BoundInfo> additional_info_;
// look up table for memorization
- std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
+ BoundMapType* bound_{nullptr};
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
}
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
- std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
+ BoundMapType* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
- const PrimExprNode* op = e.as<PrimExprNode>();
- if (bound_.find(op) == bound_.end()) {
+ if (bound_.find(e) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
- ConstIntBound bound = bound_[op];
+ ConstIntBound bound = bound_[e];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memorized bound generated by ConstIntBoundAnalyzer
- std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
+ arith::ConstIntBoundAnalyzer::BoundMapType bound_;
};
class DataTypeRewriter : public StmtExprMutator {