this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
- // generate (*( TYPE *)(&(ARG)))
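+      // generate (*( TYPE *)(&(ARG))), binding ARG to an SSA value first so
+      // the argument expression is emitted only once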
+ int ssa_scope = BeginScope();
+ std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "(*(";
this->PrintType(op->dtype, os);
- os << " *)(&(";
- this->PrintExpr(op->args[0], os);
- os << ")))";
+ os << " *)(&(" << rhs << ")))";
+ EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
this->PrintExpr(op->args[0], os);
} else {
CHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;
+
+      // The assignment below introduces a side effect, and the resulting value
+      // cannot be reused across multiple expressions, so a new scope is needed.
+ int vec_scope = BeginScope();
+
if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
- // The assignment below introduces side-effect, and the resulting value cannot
- // be reused across multiple expression, thus a new scope is needed
- int vec_scope = BeginScope();
-
// store elements separately
std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
- EndScope(vec_scope);
}
+ EndScope(vec_scope);
}
}
ExprVisitor::VisitExpr_(op);
}
+ void VisitStmt_(const StoreNode* op) {
+  // Known limitation: if the index expression fails to simplify to a
+  // RampNode, the width check below cannot catch it.
+ if (op->index->IsInstance<RampNode>()) {
+ if (op->index->dtype.lanes() > 1) {
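+      // Flag the store as invalid when the vector is wider than
+      // max_vector_bytes_.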
+ valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
+ max_vector_bytes_;
+ }
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
private:
int nest_level_{0};
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name='A', dtype="int32")
- B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
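+    # Reinterpret a compound expression rather than a bare load so that the
+    # codegen's SSA binding of the reinterpret argument is exercised.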
+ B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name='B')
s = te.create_schedule(B.op)
def check_c():
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fadd(a, b)
tvm.testing.assert_allclose(
- b.asnumpy(), a.asnumpy().view('float32'))
+ b.asnumpy(), (2 + a.asnumpy()).view('float32'))
check_c()
vcf_check_common(s, [A, B, C])
+def test_unrolled_vectorization():
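+    # Combine loop unrolling with vectorization and check the numeric result.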
+ if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+ print("skip because cuda is not enabled..")
+ return
+
+ dtype = 'float32'
+ target = 'cuda'
+
+ ## Compute declaration
+ N = 128
+ A = te.placeholder((N, N), name='A')
+ B = te.placeholder((N, N), name='B')
+ k = te.reduce_axis((0, N), name='k')
+ C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+
+ ## Schedule
+ s = te.create_schedule([C.op])
+ CC = s.cache_write(C, "local")
+ i, j = s[C].op.axis
+ bx, tx, ii, ji = s[C].tile(i, j, 1, 2)
+ s[C].bind(bx, te.thread_axis("blockIdx.x"))
+ s[C].bind(tx, te.thread_axis("threadIdx.x"))
+ s[C].vectorize(ji)
+ s[CC].compute_at(s[C], tx)
+ i, j = s[CC].op.axis
+ k = s[CC].op.reduce_axis[0]
+ ko, ki = s[CC].split(k, 2)
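+    # Unrolling ki while vectorizing j is the combination this test exercises.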
+ s[CC].unroll(ki)
+ s[CC].vectorize(j)
+
+ ## Check correctness
+ ctx = tvm.context(target)
+ a_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
+ b_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
+ c_tvm = tvm.nd.empty((N, N), ctx=ctx)
+ func_tvm = tvm.build(s, [A, B, C], target=target)
+ func_tvm(a_tvm, b_tvm, c_tvm)
+ c_np = c_tvm.asnumpy()
+ tvm.testing.assert_allclose(c_np, N * np.ones((N, N)))
+
if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load_permute_pad()
test_vectorized_cooperative_fetching_x()
test_vectorized_cooperative_fetching_xy()
+ test_unrolled_vectorization()
+