public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
- int64_t max_thread_z, int64_t max_vector_bytes) {
+ int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
+ max_vthread_ = static_cast<size_t>(max_vthread);
max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
Reset_();
visited_shared_buffers_.insert(op->node.as<VarNode>());
}
StmtVisitor::VisitStmt_(op);
- } else if (op->attr_key == attr::thread_extent) {
+ } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
const auto* extent = op->value.as<IntImmNode>();
CHECK(extent);
- // record the number of threads in a block
std::string name = var.get()->name_hint;
- if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
+ // record the number of threads in a block
+ if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
+ name == "vthread") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
thread_z_extent_ = length;
+ } else if (name == "vthread") {
+ valid_ &= length <= max_vthread_;
}
} else {
// the thread should be bound to axes with the same length
}
}
+ // Check virtual-thread loops against the max_vthread_ limit. After
+ // lowering, virtual threads appear as serial For loops whose loop var
+ // is named "vthread.s" (presumably produced by the virtual-thread
+ // lowering pass — confirm against LowerVirtualThread).
+ void VisitStmt_(const ForNode* op) {
+ if (op->loop_var->name_hint == "vthread.s") {
+ const auto* extent = op->extent.as<IntImmNode>();
+ // The loop extent must be a compile-time constant to be verifiable.
+ CHECK(extent);
+
+ valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
+ }
+
+ StmtVisitor::VisitStmt_(op);
+ }
+
void VisitExpr_(const LoadNode* op) {
- // Currently not able to check out: If the index expression failed
- // to be simplified to a RampNode
- if (op->index->IsInstance<RampNode>()) {
- if (op->dtype.lanes() > 1) {
- valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
- }
+ // Vectorized load: total width (lanes * element bytes) must not exceed
+ // max_vector_bytes_. The former RampNode guard is dropped by this diff
+ // so every multi-lane load is checked, not only simplified indices.
+ if (op->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
}
ExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const StoreNode* op) {
- // Currently not able to check out: If the index expression failed
- // to be simplified to a RampNode
- if (op->index->IsInstance<RampNode>()) {
- if (op->index->dtype.lanes() > 1) {
- valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
- max_vector_bytes_;
- }
+ // Vectorized store: same vector-width check as loads, here derived
+ // from the index dtype's lane count and element size.
+ if (op->index->dtype.lanes() > 1) {
+ valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
+ max_vector_bytes_;
}
StmtVisitor::VisitStmt_(op);
}
size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
- size_t max_thread_x_, max_thread_y_, max_thread_z_;
+ // Per-dimension thread limits plus the maximum virtual-thread extent
+ // (max_vthread_ is new in this diff, set from the "max_vthread" constraint).
+ size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;
bool valid_{true};
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
+ int64_t max_vthread = INT64_MAX;
int64_t max_vector_bytes = INT64_MAX;
for (auto iter : constraints) {
max_thread_y = val->value;
} else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
+ } else if (iter.first == "max_vthread") {
+ max_vthread = val->value;
} else if (iter.first == "max_vector_bytes") {
max_vector_bytes = val->value;
} else {
return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
- max_vector_bytes);
+ max_vthread, max_vector_bytes);
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
tvm.lower(s, [A, B])
assert not valid[0]
+def test_vthread():
+ # Check that get_verify_pass enforces the max_vthread constraint:
+ # the schedule binds a 16-wide axis to "vthread", so verification
+ # should pass with max_vthread=16 and fail with max_vthread=15.
+ N = 1024
+
+ A = te.placeholder((N, 16), name='A')
+ B = te.compute((N, 16), lambda i, j: A[i, j])
+
+ s = te.create_schedule([B.op])
+
+ s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
+ s[B].bind(s[B].op.axis[1], te.thread_axis("vthread"))
+
+ for target in ['opencl', 'cuda']:
+ # Skip GPU targets that are not available on this machine.
+ if not tvm.context(target).exist:
+ continue
+
+ valid = [None]
+
+ # Run the verifier both before (phase 1) and after (phase 2) lowering.
+ for phase in [1, 2]:
+ with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+ (phase, get_verify_pass(valid, max_vthread=16))]}):
+ tvm.build(s, [A, B], target)
+ assert valid[0]
+
+ with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+ (phase, get_verify_pass(valid, max_vthread=15))]}):
+ tvm.build(s, [A, B], target)
+ assert not valid[0]
+
+
if __name__ == "__main__":
test_local_memory()
test_shared_memory()
test_multiple_kernels()
test_wrong_bind()
test_vectorize()
+ test_vthread()