file(GLOB TOPI_SRCS
topi/src/*.cc
)
-file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp)
+file(GLOB_RECURSE HALIDEIR_SRCS
+ 3rdparty/HalideIR/src/base/*.cpp
+ 3rdparty/HalideIR/src/ir/*.cpp
+ 3rdparty/HalideIR/src/tvm/*.cpp
+)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
+ * \note The returned set may be smaller than set that
+ * contains all possible values of v that satisfies the bound.
+ *
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
- * The deduce bound mush implies e for all value in relax_map
- * \return An integer set that can cover all the possible values.
+ * The deduce bound must implies e for all value in relax_map
+ * \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
- * \return An integer set that can cover all the possible values.
+ * \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
-#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
+ if (is_const(res)) return res;
res = this->canonical_simplify(res);
return res;
}
void Deduce();
void Visit(const NodeRef& e) final {
- if (!success) return;
+ if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
- success = false;
+ success_ = false;
return;
}
}
void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
- result -= left ? op->b : op->a;
+ result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}
void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
- result += op->b;
+ result_ += op->b;
} else {
- result -= op->a;
- result = - result;
- is_greater = !is_greater;
+ result_ -= op->a;
+ result_ = - result_;
+ is_greater_ = !is_greater_;
}
Visit(left ? op->a : op->b);
}
void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
+ Expr target_var = left ? op->a : op->b;
- SignType sign;
+ SignType sign_operand;
if (operand.type().is_uint()) {
- sign = kPositive;
+ sign_operand = kPositive;
} else {
- sign = expr_map_[operand].sign_type();
+ sign_operand = expr_map_[operand].sign_type();
}
- if (sign == SignType::kNegative) {
- is_greater = !is_greater;
- } else if (sign == SignType::kUnknown) {
+ if (sign_operand == SignType::kNegative) {
+ is_greater_ = !is_greater_;
+ } else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
- success = false;
+ success_ = false;
return;
}
-
// always use relax bound
- bool divided = can_prove(result % operand == 0);
- result = result / operand;
- // since system will round down when not divided
- // eg. 2/4 -> 0; -2/4 -> -1
- // no need fix for !is_greater:
- // eg. a <= 2/4 -> a <= 0
- // eg. a <= 0/4 -> a <= 0
- // so just fix for not divided and is_greater
- // eg. a >= 2/4 -> a >= 0 + 1
- // eg. a >= 0/4 -> a >= 0
- if (is_greater && !divided) {
- result += 1;
+ bool divided = analyzer_.CanProve(result_ % operand == 0);
+
+ result_ = result_ / operand;
+
+ if (!divided) {
+ // Handle non-divisible case
+ // NOTE: this accounts for truc div behavior.
+ bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
+
+ if (is_greater_) {
+ result_ += 1;
+ } else {
+ // NOTE: this is a bit sutble hack.
+ //
+ // condition:
+ // - x * operand <= result
+ // - operand > 0
+ // - x >= 0
+ //
+ // Then it is fine to deduce that x <= result / operand.
+ // - if result > 0, this division round down
+ // - if result < 0, (result / operand) rounds up and may violate the constraint
+ // however, given that x is always non-negative,
+ // it is fine to have this relaxed bound, given that the user of deduce bound
+ // will respect the bound of x
+ //
+ // TODO(tvm-team): think about a better API to incorporate constraint of x.
+ // e.g. specify an interval of x and return a bound
+ // that is in the interval and satisfies the condition.
+ if (target_is_non_neg && sign_operand == kPositive) {
+ // do nothing
+ } else {
+ result_ -= 1;
+ }
+ }
}
-
Visit(left ? op->a : op->b);
}
- Expr result;
- bool is_greater{true};
- bool success{true};
+ Expr result_;
+ bool is_greater_{true};
+ bool success_{true};
private:
void Init();
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
+ // internal analzyer
+ Analyzer analyzer_;
};
class BoundDeduceInputChecker: public IRVisitor {
void BoundDeducer::Init() {
BoundDeduceInputChecker checker;
- if (!checker.Check(this)) success = false;
+ if (!checker.Check(this)) success_ = false;
Transform();
}
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
- is_greater = true;
+ is_greater_ = true;
expr_ = op->b;
- result = op->a + 1;
+ result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
- is_greater = false;
+ is_greater_ = false;
expr_ = op->a;
- result = op->b - 1;
+ result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
- is_greater = true;
+ is_greater_ = true;
expr_ = op->b;
- result = op->a;
+ result_ = op->a;
} else {
- is_greater = false;
+ is_greater_ = false;
expr_ = op->a;
- result = op->b;
+ result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
- is_greater = false;
+ is_greater_ = false;
expr_ = op->b;
- result = op->a - 1;
+ result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
- is_greater = true;
+ is_greater_ = true;
expr_ = op->a;
- result = op->b + 1;
+ result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
- is_greater = false;
+ is_greater_ = false;
expr_ = op->b;
- result = op->a;
+ result_ = op->a;
} else {
- is_greater = true;
+ is_greater_ = true;
expr_ = op->a;
- result = op->b;
+ result_ = op->b;
}
} else {
- success = false;
+ success_ = false;
}
}
void BoundDeducer::Deduce() {
Init();
- if (!success) return;
+ if (!success_) return;
Relax();
- if (!success) return;
+ if (!success_) return;
// get the path
path_ = GetPath(target_, expr_);
if (!path_.size()) {
- success = false;
+ success_ = false;
return;
}
-
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
Visit(expr_);
void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
- IntSet b = EvalSet(result, relax_map_);
+ IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
- success = false;
+ success_ = false;
return;
}
- expr_ = is_greater ? a.min() : a.max();
- result = is_greater ? b.max() : b.min();
+ expr_ = is_greater_ ? a.min() : a.max();
+ result_ = is_greater_ ? b.max() : b.min();
}
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
- if (!d.success) return IntSet::nothing();
+ if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
- if (d.is_greater) {
- min = d.result;
+ if (d.is_greater_) {
+ min = d.result_;
} else {
- max = d.result;
+ max = d.result_;
}
return IntSet::interval(min, max);
}
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
- // due to division and mod can have different modes
- // only constant fold positive number where rule is fixed.
- if (pa && pb && pa->value >= 0 && pb->value > 0) {
+ if (pa && pb) {
+ // due to division and mod can have different modes
+ // NOTE: this will assumes truc div.
+ CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
-
TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include <tvm/arithmetic.h>
-#include "arithmetic/Simplify.h"
namespace tvm {
namespace arith {
return analyzer.canonical_simplify(expr);
}
-template<typename T>
-T Simplify_(T a, Map<Var, Range> vrange) {
- using namespace HalideIR::Internal;
- Scope<Interval> rscope;
+Expr Simplify(Expr expr, Map<Var, Range> vrange) {
+ arith::Analyzer analyzer;
for (auto kv : vrange) {
- Range r = kv.second;
- rscope.push(
- kv.first.get(),
- Interval(r->min,
- simplify(r->min + r->extent - make_const(r->min.type(), 1))));
- }
- return HalideIR::Internal::simplify(a, true, rscope);
-}
-
-
-Expr Simplify(Expr a, Map<Var, Range> vrange) {
- // Simplify top level reduce.
- if (const Reduce* r = a.as<Reduce>()) {
- Array<Expr> new_source;
- for (auto& e : r->source) {
- new_source.push_back(Simplify_(e, vrange));
- }
- Expr new_condition = Simplify_(r->condition, vrange);
- if (r->source.same_as(new_source) &&
- r->condition.same_as(new_condition)) {
- return a;
- } else {
- return Reduce::make(
- r->combiner, new_source, r->axis, new_condition, r->value_index);
- }
+ analyzer.Bind(kv.first, kv.second);
}
- return Simplify_(a, vrange);
+ expr = analyzer.Simplify(expr);
+ return expr;
}
-Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
- return Simplify_(a, vrange);
+Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
+ return arith::CanonicalStmtSimplifier().CanonicalSimplify(
+ stmt, vrange);
}
} // namespace ir
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <iterator>
+#include <stack>
#include "../arithmetic/compute_expr.h"
namespace tvm {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
for (size_t i = 0; i < init.size(); ++i) {
CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
CHECK_EQ(init[i]->dtype, update[i]->dtype);
- CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
+ CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
<< "init.shape[0] need to match scan_axis.dom.min";
CHECK(prove_equal(
state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
Stmt body,
bool partition_thread_scope) {
using namespace arith;
+ // include hint of var.
+ hint_map_.insert({var.get(), IntSet::interval(min, max)});
+
PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body);
+
+ hint_map_.erase(var.get());
if (finder.partitions.empty()) return Stmt();
arith::IntervalSet for_interval(min, max);
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
- if (!can_prove(body_begin == min)) {
+ if (!analyzer_.CanProve(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
- if (!can_prove(cond)) {
+ if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
body_begin = Max::make(body_begin, min);
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
- if (!can_prove(middle_interval.max() == max)) {
+ if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
Expr cond = (max - post_doubt_begin + 1 >= 0);
- if (!can_prove(cond)) {
+ if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
post_doubt_begin = Min::make(post_doubt_begin, max);
// Generating code for middle subrange
if (!partition_thread_scope) {
Stmt mid_stmt;
- if (!can_prove(body_begin >= post_doubt_begin)) {
+ if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
s = AppendStmts(s, post_stmt);
} else {
Expr cond = const_true();
- if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
- if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
+ if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
+ if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
}
s = ConvertSSA(s);
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
const For *for_node = static_cast<const For*>(node);
CHECK(for_node);
- if (can_prove(extent == make_const(Int(32), 1))) {
+ if (analyzer_.CanProve(extent == make_const(Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
} else {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expr base = linear_eq[1];
if (!is_zero(base)) return body;
Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
- if (!can_prove(left >= 0)) return body;
+ if (!analyzer_.CanProve(left >= 0)) return body;
// rewrite access index.
ChannelAccessIndexRewriter rw(
ch->handle_var.get(), var * coeff, read_access);
return body;
}
+ arith::Analyzer analyzer_;
std::vector<RewriteEntry> tasks_;
};
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
}
// transform to alloc bytes
auto type_bits = alloc_type.bits() * alloc_type.lanes();
- bool divided = can_prove(combo_size % type_bits == 0);
+ bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
+ // analyzer
+ arith::Analyzer analyzer_;
};
// Turn alloc into vector alloc
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
+#include <tvm/arithmetic.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
- if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) {
+ if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) {
return Ramp::make(
a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
- if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) {
+ if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) {
return Ramp::make(
b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
Expr stride = this->Mutate(op->stride);
if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
const Ramp* base_ramp = base.as<Ramp>();
- if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
+ if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
}
}
}
private:
+ // analyzer
+ arith::Analyzer analyzer_;
// variable to be replaced
Var var_;
// the lanes.
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, bool>* p_state) {
+ std::unordered_map<IterVar, bool>* p_state,
+ arith::Analyzer* analyzer) {
auto& state = *p_state;
- using HalideIR::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (const SplitNode* s = rel.as<SplitNode>()) {
if (outer || inner) {
state[s->parent] = true;
} else {
- if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
+ if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
const std::unordered_map<IterVar, Expr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
+ Analyzer analyzer;
+
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : stage->leaf_iter_vars) {
bound_state[iv] = false;
}
- PassUpBoundCheck(stage, dom_map, &bound_state);
+ PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<Expr> preds;
std::unordered_map<const Variable*, IntSet> iset_dmap;
Range dom = dom_map.at(iv);
Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
Expr vmax = EvalSet(value, iset_dmap).max();
- if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
+ if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
}
Expr vmin = s.min();
Expr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
- if (vmin.type() != value.type() || !can_prove(vmin >= 0)) {
+ if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) {
preds.emplace_back(value >= 0);
}
- if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
+ if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) {
preds.emplace_back(value < iv->dom->extent);
}
}
const Reduce* reduce = compute_op->body[idx].as<Reduce>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
- Expr predicate = likely(simplify(arith::ComputeReduce<ir::And>(predicates, Expr())));
+ Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));
std::unordered_map<const Variable*, Expr> vsub;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <gtest/gtest.h>
#include <tvm/ir_pass.h>
#include <tvm/tvm.h>
-#include <arithmetic/Simplify.h>
-
-TEST(IRSIMPLIFY, Basic) {
- using namespace HalideIR::Internal;
- simplify_test();
-}
TEST(IRSIMPLIFY, MinMax) {
auto x = tvm::var("x");
# under the License.
import tvm
+
+def assert_expr_equal(a, b):
+ res = tvm.ir_pass.Simplify(a - b)
+ equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
+ if not equal:
+ raise ValueError("{} and {} are not equal".format(a, b))
+
+
def test_deduce():
a = tvm.var('a')
b = tvm.var('b')
e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
- ans0 = ((d - c) /(b*-1))
- assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+ ans0 = ((d - c) /(b*-1) + (-1))
+ assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+ assert_expr_equal(res0.max_value, ans0)
e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
- ans0 = ((0-c)/d + 1)
- assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+ ans0 = ((d-c)/d - 1)
+ assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+ assert_expr_equal(res0.max_value, ans0)
+
e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
- ans1 = (((c - b) + -1)/4)
- assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+ ans1 = (((c - b) + -1)/4 -1)
+ assert_expr_equal(res1.max_value, ans1)
+
# expression containing variable a is on rhs
e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+ assert_expr_equal(res1.max_value, ans1)
+
e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf"
-
e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = 2/c+1
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
+
def test_check():
a = tvm.var('a')
b = tvm.var('b')
assert(tvm.ir_pass.Equal(e2, x * 8))
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
- let = tvm.make.Let(x, 1, x + 3)
- e4 = tvm.ir_pass.Simplify(let)
- assert(tvm.ir_pass.Equal(e4, 4))
def test_verify_ssa():