Expr base = e;
Array<Expr> coeff;
- if (0 == vars.size()) {
- coeff.push_back(make_const(Int(32), 1));
- } else {
- for (Var v : vars) {
- LinearEqEntry ret;
- if (!LinearEqDetector(v).Detect(base, &ret)) {
- return Array<Expr>();
- }
- coeff.push_back(ret.coeff);
- base = std::move(ret.base);
+ for (Var v : vars) {
+ LinearEqEntry ret;
+ if (!LinearEqDetector(v).Detect(base, &ret)) {
+ return Array<Expr>();
}
+ coeff.push_back(ret.coeff);
+ base = std::move(ret.base);
+ }
- std::unordered_set<const Variable*> vset;
- for (size_t i = vars.size(); i != 1; --i) {
- vset.insert(vars[i - 1].get());
- // The previous coeff contains the variable
- if (ExprUseVar(coeff[i - 2], vset)) {
- return Array<Expr>();
- }
+ std::unordered_set<const Variable*> vset;
+ for (size_t i = vars.size(); i > 1; --i) {
+ vset.insert(vars[i - 1].get());
+ // The previous coeff contains the variable
+ if (ExprUseVar(coeff[i - 2], vset)) {
+ return Array<Expr>();
}
}
coeff.push_back(base);
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
using namespace arith;
Stmt body = stmt;
- bool is_single_point_copy = false;
// strip the loops
std::vector<const For*> loops;
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();
if (0 == loops.size()) {
- is_single_point_copy = true;
CHECK(!has_cond);
}
// for now only support true condition matching
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
- auto loop_var_size = loop_vars.size();
- if (is_single_point_copy) {
- loop_var_size = 1;
+ const size_t loop_var_size = loop_vars.size();
+ if (loop_var_size == 0) {
dst_shape.push_back(make_const(Int(32), 1));
} else {
for (const For* op : loops) {
CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
+ if (loop_var_size == 0) {
+ src_strides.push_back(make_const(Int(32), 1));
+ dst_strides.push_back(make_const(Int(32), 1));
+ }
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
store->value.type(),
m = tvm.arith.DetectLinearEquation(b * 7, [a])
assert m[0].value == 0
+ m = tvm.arith.DetectLinearEquation(b * 7, [])
+ assert len(m) == 1
+ assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
+
def test_multivariate():
v = [tvm.var("v%d" % i) for i in range(4)]
b = tvm.var("b")
assert(m[0].value == 0)
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
+ m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
+ assert(len(m) == 1)
+ assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
+
if __name__ == "__main__":
test_basic()
test_multivariate()