namespace tvm {
namespace tir {
-class GPUCodeVerifier : public StmtVisitor {
+class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
- int64_t max_thread_z) {
+ int64_t max_thread_z, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
+ max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
Reset_();
+ // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);
return valid_;
size_t size = static_cast<size_t>(op->constant_allocation_size());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
+ if (op->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+ }
}
void VisitStmt_(const AttrStmtNode* op) final {
}
}
+ void VisitExpr_(const LoadNode* op) {
+ // Currently not able to check out: If the index expression failed
+ // to be simplified to a RampNode
+ if (op->index->IsInstance<RampNode>()) {
+ if (op->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+ }
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
private:
int nest_level_{0};
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;
+ size_t max_vector_bytes_;
bool valid_{true};
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
+ int64_t max_vector_bytes = INT64_MAX;
for (auto iter : constraints) {
const IntImmNode* val = iter.second.as<IntImmNode>();
- if (iter.first == "max_local_memory_per_block")
+ if (iter.first == "max_local_memory_per_block") {
max_local_memory_per_block = val->value;
- else if (iter.first == "max_shared_memory_per_block")
+ } else if (iter.first == "max_shared_memory_per_block") {
max_shared_memory_per_block = val->value;
- else if (iter.first == "max_threads_per_block")
+ } else if (iter.first == "max_threads_per_block") {
max_threads_per_block = val->value;
- else if (iter.first == "max_thread_x")
+ } else if (iter.first == "max_thread_x") {
max_thread_x = val->value;
- else if (iter.first == "max_thread_y")
+ } else if (iter.first == "max_thread_y") {
max_thread_y = val->value;
- else if (iter.first == "max_thread_z")
+ } else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
- else
+ } else if (iter.first == "max_vector_bytes") {
+ max_vector_bytes = val->value;
+ } else {
LOG(FATAL) << "Invalid check item: " << iter.first;
+ }
}
return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
- max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
+ max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
+ max_vector_bytes);
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
tvm.build(s, [A, B], target)
assert not valid[0]
+def test_vectorize():
+ N = 1024
+
+ A = te.placeholder((N, N), name='A')
+ B = te.compute((N, N), lambda i, j: A[i, j])
+
+ s = te.create_schedule([B.op])
+
+ i, j = s[B].op.axis
+
+ s[B].bind(i, te.thread_axis("blockIdx.x"))
+ jo, ji = s[B].split(j, factor=64)
+ s[B].bind(jo, te.thread_axis("threadIdx.x"))
+ s[B].vectorize(ji)
+
+ for target in ['opencl', 'cuda']:
+ if not tvm.context(target).exist:
+ continue
+
+ valid = [None]
+ with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+ (2, get_verify_pass(valid, max_vector_bytes=16))]}):
+ tvm.lower(s, [A, B])
+ assert not valid[0]
if __name__ == "__main__":
test_local_memory()
test_num_thread()
test_multiple_kernels()
test_wrong_bind()
+ test_vectorize()