From 37e5754862d438dbaec63fdcd3367993a1070981 Mon Sep 17 00:00:00 2001 From: yongfeng-nv <49211903+yongfeng-nv@users.noreply.github.com> Date: Sun, 26 Apr 2020 20:00:52 -0400 Subject: [PATCH] Improve IntervalSet's floormod (#5367) --- include/tvm/arith/analyzer.h | 25 +++++++-- include/tvm/arith/int_set.h | 12 +++-- src/arith/analyzer.cc | 29 +++++++---- src/arith/const_int_bound.cc | 18 ++++--- src/arith/int_set.cc | 10 ++++ src/te/operation/compute_op.cc | 14 ++--- src/te/schedule/bound.cc | 15 ++++-- src/te/schedule/message_passing.cc | 12 +++-- tests/python/unittest/test_arith_intset.py | 14 +++++ .../test_te_schedule_bound_inference_tiling.py | 60 ++++++++++++++++++++++ 10 files changed, 168 insertions(+), 41 deletions(-) create mode 100644 tests/python/unittest/test_te_schedule_bound_inference_tiling.py diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6ca3ba9..c08c0d6 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's range. */ - TVM_DLL void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); private: friend class Analyzer; @@ -411,8 +412,9 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param expr The expression we bind to. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const PrimExpr& expr); + void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -421,14 +423,16 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's range.
*/ - void Bind(const Var& var, const Range& range); + void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. + * \param override Whether we allow overriding an existing var's range. */ - void Bind(const Map& variables); + void Bind(const Map& variables, bool override = false); /*! * \brief Whether can we prove expr >= val. @@ -443,6 +447,19 @@ class TVM_DLL Analyzer { */ bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); /*! + * \brief Whether we can prove expr < upper_bound. + * + * The proof succeeds when the analyzer finds a constant upper bound of expr + * that is strictly smaller than upper_bound. + * + * \param expr The expression. + * \param upper_bound The upper bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveLess(const PrimExpr& expr, int64_t upper_bound); /*! * \brief Whether can we prove condition. * * \param cond The expression to be proved. diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 86ef906..ab73b07 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -153,6 +153,13 @@ class IntSet : public ObjectRef { // Integer set legacy API. //------------------------------------------------ /*! + * \brief Convert std::unordered_map to Map + * + * \param dom_map The domain map to convert. + * \return The converted map. + */ +Map ConvertDomMap(const std::unordered_map& dom_map); +/*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. * @@ -160,8 +167,7 @@ class IntSet : public ObjectRef { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const Map& dom_map); +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*!
* \brief Same as EvalSet, but takes unordered_map * @@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e, */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); - /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. @@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s, */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); - /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 83dfc64..9199bac 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -36,31 +36,31 @@ Analyzer::Analyzer() int_set(this) { } -void Analyzer::Bind(const Var& var, const PrimExpr& expr) { +void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) { PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); - this->const_int_bound.Update(var, this->const_int_bound(new_expr)); - this->modular_set.Update(var, this->modular_set(new_expr)); - this->rewrite_simplify.Update(var, new_expr); - this->canonical_simplify.Update(var, new_expr); + this->const_int_bound.Update(var, this->const_int_bound(new_expr), override); + this->modular_set.Update(var, this->modular_set(new_expr), override); + this->rewrite_simplify.Update(var, new_expr, override); + this->canonical_simplify.Update(var, new_expr, override); } -void Analyzer::Bind(const Var& var, const Range& range) { +void Analyzer::Bind(const Var& var, const Range& range, bool override) { CHECK(range.defined()); if (tir::is_one(range->extent)) { - this->Bind(var, range->min); + this->Bind(var, range->min, override); } else { - this->const_int_bound.Bind(var, range); + this->const_int_bound.Bind(var, range, override); } // skip modular_set // skip rewrite simplify } -void Analyzer::Bind(const Map& variables) { +void Analyzer::Bind(const Map& variables, bool override) { for (const auto& iter : variables) { 
- this->Bind(iter.first, iter.second); + this->Bind(iter.first, iter.second, override); } } @@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { return false; } +bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { + if (const auto* ptr = expr.as()) { + return ptr->value < upper_bound; + } + auto bd = this->const_int_bound(this->rewrite_simplify(expr)); + if (bd->max_value < upper_bound) return true; + return false; +} + bool Analyzer::CanProve(const PrimExpr& expr) { if (const auto* ptr = expr.as()) { return ptr->value != 0; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 57dfc15..bb7c3dd 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl : } }; - void Bind(const Var& var, const Range& range) { + void Bind(const Var& var, const Range& range, bool override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); Entry ret; ret.min_value = a.min_value; ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); - Update(var, ret, false); + Update(var, ret, override); } void Update(const Var& var, @@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl : const PrimExprNode* op = expr.as(); auto val = bound_->find(op); if (val != bound_->end()) { - CHECK(val->second->min_value == res.min_value && - val->second->max_value == res.max_value) - << "Detected bound for " << expr - << "conflicts with memorization"; + auto everything = Everything(op->dtype); + CHECK( + (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || + (val->second->min_value == everything.min_value && + val->second->max_value == everything.max_value)) + << "Detected bound for " << expr << "conflicts with memorization"; } (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); } @@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var, 
impl_->Update(var, info, override); } -void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { - impl_->Bind(var, range); +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) { + impl_->Bind(var, range, override); } std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 027259a..d2d43d6 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -311,6 +311,16 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { + if (divisor.as()) { + // a mod b = a - (a / b) * b if a_max / b == a_min / b + auto qmax = floordiv(a->max_value, divisor); + auto qmin = floordiv(a->min_value, divisor); + if (analyzer->CanProve(qmax == qmin)) { + auto tmax = a->max_value - divisor * qmin; + auto tmin = a->min_value - divisor * qmin; + return IntervalSet(tmin, tmax); + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 2d9f13b..1248547 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs( // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. 
- IntSet arg_intset = EvalSet(call->args[i], dom_map); + IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map)); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); @@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs( PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - if (arith::is_neg_inf(min_value) || - analyzer->CanProve(shape_i_min_value >= min_value)) { + // We must update the bound's ends in pairs. Here is a counterexample: shape_i is + // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is + // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0], + // awkward for further analysis. + if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || + (analyzer->CanProve(shape_i_min_value >= min_value) && + analyzer->CanProve(shape_i_max_value <= max_value))) { min_value = shape_i_min_value; - } - if (arith::is_pos_inf(max_value) || - analyzer->CanProve(shape_i_max_value <= max_value)) { max_value = shape_i_max_value; } dom.data[i].push_back(IntSet::interval(min_value, max_value)); diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 4dde945..552d7b7 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage, Array stage_attach = ctx.attach_path.at(stage->op); // The parent set.
for (const Operation& op : consumers) { - std::unordered_map relax_set; + Map relax_set; std::unordered_map up_state; bool found_attach = false; CHECK(ctx.op2stage_.count(op.get())); @@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage, << "InferBound requires every leaf iter var's min equals 0, " << "call schedule.normalize to achieve this."; if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - relax_set[iv->var.get()] = IntSet::range(vrange); + relax_set.Set(iv->var, IntSet::range(vrange)); if (ctx.bind_map.count(iv)) { - relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange); + relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange)); } } } @@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage, // Relax if needed. std::unordered_map dom_map; arith::Analyzer analyzer; + for (auto entry : *rmap) { + analyzer.Bind(entry.first->var, entry.second); + } for (auto iv : op->root_iter_vars()) { Range r; if (up_state.count(iv)) { @@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = EvalSet(r, relax_set); + dom_map[iv->var.get()] = IntSet::interval( + analyzer.int_set(r->min, relax_set).min(), + analyzer.int_set(r->min + r->extent - 1, relax_set).max()); } else { dom_map[iv->var.get()] = IntSet::range(r); } - analyzer.Bind(iv->var, r); + analyzer.Bind(iv->var, r, true); } op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); } diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 1453ed0..6ae7464 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -579,11 +579,15 @@ std::vector MakeBoundCheck( PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; - std::unordered_map iset_dmap; + Map iset_dmap; // setup domain map for set analysis for (const auto& kv : dom_map) { - iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); + iset_dmap.Set(kv.first->var, 
IntSet::range(kv.second)); + } + + for (auto entry : dom_map) { + analyzer.Bind(entry.first->var, entry.second); } for (const IterVar& iv : stage->all_iter_vars) { @@ -591,7 +595,7 @@ std::vector MakeBoundCheck( if (bound_state.at(iv)) { Range dom = dom_map.at(iv); PrimExpr value = value_map.at(iv) - dom->min; - PrimExpr vmax = EvalSet(value, iset_dmap).max(); + PrimExpr vmax = analyzer.int_set(value, iset_dmap).max(); if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } @@ -603,7 +607,7 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { PrimExpr value = value_map.at(iv) - iv->dom->min; - IntSet s = EvalSet(value, iset_dmap); + IntSet s = analyzer.int_set(value, iset_dmap); PrimExpr vmin = s.min(); PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e57dcef..9919c7b 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -90,6 +90,20 @@ def test_mod(): flm = tvm.te.floormod ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9)) + + floordiv = tvm.te.floordiv + z = te.var("z") + ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3)) + ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, + (0, 7)) + ck1 = IntSetChecker() + ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2)) + ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3)) def test_max_min(): diff 
--git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py new file mode 100644 index 0000000..3893bb6 --- /dev/null +++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +def test_bound_tile_mod(): + def compute(M_tiles, N_tiles, factor, dtype): + # Algo + M = M_tiles * factor + N = N_tiles * factor + + A = tvm.te.placeholder((N, M), name='A', dtype=dtype) + C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C') + s = tvm.te.create_schedule(C.op) + + return s, A, C + + def schedule(s, factor, padding, A, C): + C_local = s.cache_write(C, "local") + + n, m = C.op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + nio, nii = s[C].split(ni, 2) + n = s[C].fuse(nii, mi) + C_shared = s.cache_write(C, "shared") + bn, bm, ni, mi = C_shared.op.axis + s[C_shared].storage_align(ni, factor * 2, padding) + + n, m = s[C].op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + s[C].set_scope("global") + niio, niii = s[C].split(ni, 32) + s[C_shared].compute_at(s[C], niio) + + return s + + s, A, C = compute(2, 2, 128, "float16") + s = schedule(s, 128, 8, A, C) + bounds = tvm.te.schedule.InferBound(s) + check = (bounds[s.stages[2].op.axis[2]].extent == 16) + if(not check): + print(tvm.lower(s, [A, C], simple_mode=True)) + assert(check) + +if __name__ == "__main__": + test_bound_tile_mod() -- 2.7.4