bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
- Expr cond = (body_begin - min >= 0);
- if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond
- << ", when generating the pre doubt loop";
- body_begin = Max::make(body_begin, min);
- // stop recursing on this interval if we can't prove it has non-negative length
- pre_stmt_recurse = false;
- }
- if (!partition_thread_scope) {
- Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
- pre_stmt = MakeFor(node, body_begin - min, pre_body);
+ if (!analyzer_.CanProve(body_begin == min)) {
+ Expr cond = (body_begin - min >= 0);
+ if (!analyzer_.CanProve(cond)) {
+ LOG(WARNING) << "Cannot prove: " << cond
+ << ", when generating the pre doubt loop";
+ body_begin = Max::make(body_begin, min);
+ // stop recursing on this interval if we can't prove it has non-negative length
+ pre_stmt_recurse = false;
+ }
+ if (!partition_thread_scope) {
+ Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+ pre_stmt = MakeFor(node, body_begin - min, pre_body);
+ }
}
} else {
body_begin = min;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
- // require the extent to be non-negative
- Expr cond = (max - post_doubt_begin + 1 >= 0);
- if (!analyzer_.CanProve(cond)) {
- LOG(WARNING) << "Cannot prove: " << cond
- << ", when generating the post doubt loop";
- post_doubt_begin = Min::make(post_doubt_begin, max+1);
- // stop recursing on this interval if we can't prove it has non-negative length
- post_stmt_recurse = false;
- }
- if (!partition_thread_scope) {
- Stmt post_body =
- Substitute(body, {{Var{var}, var + post_doubt_begin}});
- post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+ if (!analyzer_.CanProve(middle_interval.max() == max)) {
+ // require the extent to be non-negative
+ Expr cond = (max - post_doubt_begin + 1 >= 0);
+ if (!analyzer_.CanProve(cond)) {
+ LOG(WARNING) << "Cannot prove: " << cond
+ << ", when generating the post doubt loop";
+ post_doubt_begin = Min::make(post_doubt_begin, max+1);
+ // stop recursing on this interval if we can't prove it has non-negative length
+ post_stmt_recurse = false;
+ }
+ if (!partition_thread_scope) {
+ Stmt post_body =
+ Substitute(body, {{Var{var}, var + post_doubt_begin}});
+ post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+ }
}
} else {
post_doubt_begin = max + 1;
stmt = tvm.ir_pass.Simplify(stmt)
assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+
+def test_multilevel_splitting_with_indivisble_factors():
+ import topi
+ A = tvm.placeholder((130,), dtype="float32")
+ B = topi.nn.relu(A)
+ s = tvm.create_schedule(B.op)
+ (y,) = s[B].op.axis
+ (yo, yi) = s[B].split(y, factor=8)
+ (yoo, yoi) = s[B].split(yo, factor=16)
+ s[B].reorder(yoo, yoi, yi)
+ s[B].unroll(yi)
+
+ ## But this does the right thing.
+ with tvm.build_config(partition_const_loop=True):
+ lowered_body = tvm.lower(s, [A, B]).body
+ def visit_stmt(op):
+ return(isinstance(op, tvm.expr.Max))
+ num_max = collect_visit(lowered_body, visit_stmt)
+ assert num_max.count(True) == 10
+
+
def test_double_splitting_with_indivisible_factors():
m = 48
dtype="float32"
test_cce_loop_3()
test_conv_tiling()
test_double_splitting_with_indivisible_factors()
+ test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()