[TIR] Enhance VerifyGPUCode (#6194)
authorLianmin Zheng <lianminzheng@gmail.com>
Mon, 3 Aug 2020 20:23:07 +0000 (13:23 -0700)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 20:23:07 +0000 (13:23 -0700)
src/tir/analysis/verify_gpu_code.cc
tests/python/unittest/test_tir_analysis_verify_gpu_code.py

index d221dde..cce0823 100644 (file)
@@ -37,13 +37,14 @@ class GPUCodeVerifier : public StmtExprVisitor {
  public:
   bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
               int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
-              int64_t max_thread_z, int64_t max_vector_bytes) {
+              int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
     max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
     max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
     max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
     max_thread_x_ = static_cast<size_t>(max_thread_x);
     max_thread_y_ = static_cast<size_t>(max_thread_y);
     max_thread_z_ = static_cast<size_t>(max_thread_z);
+    max_vthread_ = static_cast<size_t>(max_vthread);
     max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
 
     Reset_();
@@ -78,7 +79,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
         visited_shared_buffers_.insert(op->node.as<VarNode>());
       }
       StmtVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::thread_extent) {
+    } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
       if (nest_level_ == 0) {
         // enter a new kernel, reset statistics
         Reset_();
@@ -88,9 +89,10 @@ class GPUCodeVerifier : public StmtExprVisitor {
       const auto* extent = op->value.as<IntImmNode>();
       CHECK(extent);
 
-      // record the number of threads in a block
       std::string name = var.get()->name_hint;
-      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
+      // record the number of threads in a block
+      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
+          name == "vthread") {
         size_t length = static_cast<size_t>(extent->value);
         if (!visited_threads_.count(name)) {
           visited_threads_.insert(name);
@@ -105,6 +107,8 @@ class GPUCodeVerifier : public StmtExprVisitor {
           } else if (name == "threadIdx.z") {
             valid_ &= length <= max_thread_z_;
             thread_z_extent_ = length;
+          } else if (name == "vthread") {
+            valid_ &= length <= max_vthread_;
           }
         } else {
           // the thread should be bound to axes with the same length
@@ -134,25 +138,28 @@ class GPUCodeVerifier : public StmtExprVisitor {
     }
   }
 
+  void VisitStmt_(const ForNode* op) {
+    if (op->loop_var->name_hint == "vthread.s") {
+      const auto* extent = op->extent.as<IntImmNode>();
+      CHECK(extent);
+
+      valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
+    }
+
+    StmtVisitor::VisitStmt_(op);
+  }
+
   void VisitExpr_(const LoadNode* op) {
-    // Currently not able to check out: If the index expression failed
-    // to be simplified to a RampNode
-    if (op->index->IsInstance<RampNode>()) {
-      if (op->dtype.lanes() > 1) {
-        valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
-      }
+    if (op->dtype.lanes() > 1) {
+      valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
     }
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const StoreNode* op) {
-    // Currently not able to check out: If the index expression failed
-    // to be simplified to a RampNode
-    if (op->index->IsInstance<RampNode>()) {
-      if (op->index->dtype.lanes() > 1) {
-        valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
-                  max_vector_bytes_;
-      }
+    if (op->index->dtype.lanes() > 1) {
+      valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
+                max_vector_bytes_;
     }
     StmtVisitor::VisitStmt_(op);
   }
@@ -173,7 +180,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   size_t max_local_memory_per_block_;
   size_t max_shared_memory_per_block_;
   size_t max_threads_per_block_;
-  size_t max_thread_x_, max_thread_y_, max_thread_z_;
+  size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
   size_t max_vector_bytes_;
 
   bool valid_{true};
@@ -198,6 +205,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
   int64_t max_thread_x = INT64_MAX;
   int64_t max_thread_y = INT64_MAX;
   int64_t max_thread_z = INT64_MAX;
+  int64_t max_vthread = INT64_MAX;
   int64_t max_vector_bytes = INT64_MAX;
 
   for (auto iter : constraints) {
@@ -214,6 +222,8 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
       max_thread_y = val->value;
     } else if (iter.first == "max_thread_z") {
       max_thread_z = val->value;
+    } else if (iter.first == "max_vthread") {
+      max_vthread = val->value;
     } else if (iter.first == "max_vector_bytes") {
       max_vector_bytes = val->value;
     } else {
@@ -223,7 +233,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
 
   return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
                          max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
-                         max_vector_bytes);
+                         max_vthread, max_vector_bytes);
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
index ece8402..2e37de4 100644 (file)
@@ -233,6 +233,35 @@ def test_vectorize():
             tvm.lower(s, [A, B])
         assert not valid[0]
 
+def test_vthread():
+    N = 1024
+
+    A = te.placeholder((N, 16), name='A')
+    B = te.compute((N, 16), lambda i, j: A[i, j])
+
+    s = te.create_schedule([B.op])
+
+    s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
+    s[B].bind(s[B].op.axis[1], te.thread_axis("vthread"))
+
+    for target in ['opencl', 'cuda']:
+        if not tvm.context(target).exist:
+            continue
+
+        valid = [None]
+
+        for phase in [1, 2]:
+            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+                (phase, get_verify_pass(valid, max_vthread=16))]}):
+                tvm.build(s, [A, B], target)
+            assert valid[0]
+
+            with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+                (phase, get_verify_pass(valid, max_vthread=15))]}):
+                tvm.build(s, [A, B], target)
+            assert not valid[0]
+
+
 if __name__ == "__main__":
     test_local_memory()
     test_shared_memory()
@@ -240,3 +269,4 @@ if __name__ == "__main__":
     test_multiple_kernels()
     test_wrong_bind()
     test_vectorize()
+    test_vthread()