*
* \param var The variable.
* \param range The range we bind to.
+ * \param override Whether we allow overriding an existing var's range.
*/
- TVM_DLL void Bind(const Var& var, const Range& range);
+ TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
private:
friend class Analyzer;
*
* \param var The variable.
* \param expr The expression we bind to.
+ * \param override Whether we allow overriding an existing var's expression.
*/
- void Bind(const Var& var, const PrimExpr& expr);
+ void Bind(const Var& var, const PrimExpr& expr, bool override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* \param var The variable.
* \param range The range we bind to.
+ * \param override Whether we allow overriding an existing var's expression.
*/
- void Bind(const Var& var, const Range& range);
+ void Bind(const Var& var, const Range& range, bool override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
+ * \param override Whether we allow overriding an existing var's expression.
*/
- void Bind(const Map<Var, Range>& variables);
+ void Bind(const Map<Var, Range>& variables, bool override = false);
/*!
* \brief Whether can we prove expr >= val.
*/
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
+ * \brief Whether can we prove expr < val.
+
+ * Non-negative proof is very useful in integer analysis
+ * to lower divisions and mods given difference in trunc and ceil mode.
+ *
+ * \param expr The expression.
+ * \param upper_bound The upper bound.
+ * \return Whether we can prove it.
+ *
+ * \note Analyzer will call into sub-analyzers to get the result.
+ */
+ bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
+ /*!
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
// Integer set legacy API.
//------------------------------------------------
/*!
+ * \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
+ *
+ * \param dom_map The domain map to convert.
+ * \return The converted map.
+ */
+Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
+/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(PrimExpr e,
- const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
*/
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
-
/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
*/
IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
-
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
int_set(this) {
}
-void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
- this->const_int_bound.Update(var, this->const_int_bound(new_expr));
- this->modular_set.Update(var, this->modular_set(new_expr));
- this->rewrite_simplify.Update(var, new_expr);
- this->canonical_simplify.Update(var, new_expr);
+ this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
+ this->modular_set.Update(var, this->modular_set(new_expr), override);
+ this->rewrite_simplify.Update(var, new_expr, override);
+ this->canonical_simplify.Update(var, new_expr, override);
}
-void Analyzer::Bind(const Var& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range, bool override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
- this->Bind(var, range->min);
+ this->Bind(var, range->min, override);
} else {
- this->const_int_bound.Bind(var, range);
+ this->const_int_bound.Bind(var, range, override);
}
// skip modular_set
// skip rewrite simplify
}
-void Analyzer::Bind(const Map<Var, Range>& variables) {
+void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
for (const auto& iter : variables) {
- this->Bind(iter.first, iter.second);
+ this->Bind(iter.first, iter.second, override);
}
}
return false;
}
+bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
+ if (const auto* ptr = expr.as<tir::IntImmNode>()) {
+ return ptr->value < upper_bound;
+ }
+ auto bd = this->const_int_bound(this->rewrite_simplify(expr));
+ if (bd->max_value < upper_bound) return true;
+ return false;
+}
+
bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
}
};
- void Bind(const Var& var, const Range& range) {
+ void Bind(const Var& var, const Range& range, bool override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
- Update(var, ret, false);
+ Update(var, ret, override);
}
void Update(const Var& var,
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
- CHECK(val->second->min_value == res.min_value &&
- val->second->max_value == res.max_value)
- << "Detected bound for " << expr
- << "conflicts with memorization";
+ auto everything = Everything(op->dtype);
+ CHECK(
+ (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
+ (val->second->min_value == everything.min_value &&
+ val->second->max_value == everything.max_value))
+ << "Detected bound for " << expr << "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
impl_->Update(var, info, override);
}
-void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
- impl_->Bind(var, range);
+void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
+ impl_->Bind(var, range, override);
}
std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
+ if (divisor.as<tir::IntImmNode>()) {
+ // a mod b = a - (a / b) * b if a_max / b == a_min / b
+ auto qmax = floordiv(a->max_value, divisor);
+ auto qmin = floordiv(a->min_value, divisor);
+ if (analyzer->CanProve(qmax == qmin)) {
+ auto tmax = a->max_value - divisor * qmin;
+ auto tmin = a->min_value - divisor * qmin;
+ return IntervalSet(tmin, tmax);
+ }
+ }
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
PrimExpr bound = abs(divisor) - 1;
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
- IntSet arg_intset = EvalSet(call->args[i], dom_map);
+ IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
- if (arith::is_neg_inf(min_value) ||
- analyzer->CanProve(shape_i_min_value >= min_value)) {
+ // We must update bound's ends in pairs. Here is an counter example: shape_i is
+ // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
+ // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
+ // awkward for further analysis.
+ if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
+ (analyzer->CanProve(shape_i_min_value >= min_value) &&
+ analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value;
- }
- if (arith::is_pos_inf(max_value) ||
- analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set.
for (const Operation& op : consumers) {
- std::unordered_map<const VarNode*, IntSet> relax_set;
+ Map<Var, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get()));
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
- relax_set[iv->var.get()] = IntSet::range(vrange);
+ relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) {
- relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
+ relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
}
}
}
// Relax if needed.
std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer;
+ for (auto entry : *rmap) {
+ analyzer.Bind(entry.first->var, entry.second);
+ }
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
r = iv->dom;
}
if (relax_set.size() != 0) {
- dom_map[iv->var.get()] = EvalSet(r, relax_set);
+ dom_map[iv->var.get()] = IntSet::interval(
+ analyzer.int_set(r->min, relax_set).min(),
+ analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
- analyzer.Bind(iv->var, r);
+ analyzer.Bind(iv->var, r, true);
}
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<PrimExpr> preds;
- std::unordered_map<const VarNode*, IntSet> iset_dmap;
+ Map<Var, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
- iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+ iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
+ }
+
+ for (auto entry : dom_map) {
+ analyzer.Bind(entry.first->var, entry.second);
}
for (const IterVar& iv : stage->all_iter_vars) {
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min;
- PrimExpr vmax = EvalSet(value, iset_dmap).max();
+ PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min;
- IntSet s = EvalSet(value, iset_dmap);
+ IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min();
PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
flm = tvm.te.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
+ ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
+
+ floordiv = tvm.te.floordiv
+ z = te.var("z")
+ ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
+ ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
+ (0, 7))
+ ck1 = IntSetChecker()
+ ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
+ ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
def test_max_min():
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+
+def test_bound_tile_mod():
+ def compute(M_tiles, N_tiles, factor, dtype):
+ # Algo
+ M = M_tiles * factor
+ N = N_tiles * factor
+
+ A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
+ C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
+ s = tvm.te.create_schedule(C.op)
+
+ return s, A, C
+
+ def schedule(s, factor, padding, A, C):
+ C_local = s.cache_write(C, "local")
+
+ n, m = C.op.axis
+ bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+ nio, nii = s[C].split(ni, 2)
+ n = s[C].fuse(nii, mi)
+ C_shared = s.cache_write(C, "shared")
+ bn, bm, ni, mi = C_shared.op.axis
+ s[C_shared].storage_align(ni, factor * 2, padding)
+
+ n, m = s[C].op.axis
+ bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
+ s[C].set_scope("global")
+ niio, niii = s[C].split(ni, 32)
+ s[C_shared].compute_at(s[C], niio)
+
+ return s
+
+ s, A, C = compute(2, 2, 128, "float16")
+ s = schedule(s, 128, 8, A, C)
+ bounds = tvm.te.schedule.InferBound(s)
+ check = (bounds[s.stages[2].op.axis[2]].extent == 16)
+ if(not check):
+ print(tvm.lower(s, [A, C], simple_mode=True))
+ assert(check)
+
+if __name__ == "__main__":
+ test_bound_tile_mod()