CUDA device API & VerifyGPUCode pass update (#5898)
authorChenfan <jcf94@outlook.com>
Thu, 25 Jun 2020 05:44:39 +0000 (13:44 +0800)
committerGitHub <noreply@github.com>
Thu, 25 Jun 2020 05:44:39 +0000 (22:44 -0700)
* Add kMaxRegistersPerBlock device api for cuda

* Add vectorize check to verify_gpu_code

* Lint fix

* Cast fix

include/tvm/runtime/device_api.h
src/runtime/cuda/cuda_device_api.cc
src/runtime/metal/metal_device_api.mm
src/runtime/opencl/opencl_device_api.cc
src/runtime/rocm/rocm_device_api.cc
src/runtime/vulkan/vulkan.cc
src/tir/analysis/verify_gpu_code.cc
tests/python/unittest/test_tir_analysis_verify_gpu_code.py

index 421811a..3cf5566 100644 (file)
@@ -44,7 +44,8 @@ enum DeviceAttrKind : int {
   kMaxClockRate = 6,
   kMultiProcessorCount = 7,
   kMaxThreadDimensions = 8,
-  kGcnArch = 9
+  kMaxRegistersPerBlock = 9,
+  kGcnArch = 10
 };
 
 /*! \brief Number of bytes each allocation must align to */
index a6d4a54..ccd8e91 100644 (file)
@@ -92,6 +92,10 @@ class CUDADeviceAPI final : public DeviceAPI {
         *rv = ss.str();
         return;
       }
+      case kMaxRegistersPerBlock: {
+        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id));
+        break;
+      }
       case kGcnArch:
         return;
     }
index 3bad2c3..a64f35c 100644 (file)
@@ -64,7 +64,9 @@ void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* r
     case kMaxThreadDimensions:
       return;
     case kExist:
-      break;
+      return;
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
index 6d9835e..72d03fb 100644 (file)
@@ -107,6 +107,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       *rv = ss.str();
       break;
     }
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
index 475c4fb..e3dbef5 100644 (file)
@@ -102,6 +102,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
         *rv = ss.str();
         return;
       }
+      case kMaxRegistersPerBlock:
+        return;
       case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
index 4481011..ade4ddc 100644 (file)
@@ -413,6 +413,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       *rv = ss.str();
       break;
     }
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
index 1fbae0f..9477e04 100644 (file)
 namespace tvm {
 namespace tir {
 
-class GPUCodeVerifier : public StmtVisitor {
+class GPUCodeVerifier : public StmtExprVisitor {
  public:
   bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
               int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
-              int64_t max_thread_z) {
+              int64_t max_thread_z, int64_t max_vector_bytes) {
     max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
     max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
     max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
     max_thread_x_ = static_cast<size_t>(max_thread_x);
     max_thread_y_ = static_cast<size_t>(max_thread_y);
     max_thread_z_ = static_cast<size_t>(max_thread_z);
+    max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
 
     Reset_();
 
+    // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
     this->VisitStmt(stmt);
 
     return valid_;
@@ -62,6 +64,9 @@ class GPUCodeVerifier : public StmtVisitor {
       size_t size = static_cast<size_t>(op->constant_allocation_size());
       shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
     }
+    if (op->dtype.lanes() > 1) {
+      valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+    }
   }
 
   void VisitStmt_(const AttrStmtNode* op) final {
@@ -129,6 +134,17 @@ class GPUCodeVerifier : public StmtVisitor {
     }
   }
 
+  void VisitExpr_(const LoadNode* op) {
+    // Currently not able to check out: If the index expression failed
+    // to be simplified to a RampNode
+    if (op->index->IsInstance<RampNode>()) {
+      if (op->dtype.lanes() > 1) {
+        valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
  private:
   int nest_level_{0};
 
@@ -146,6 +162,7 @@ class GPUCodeVerifier : public StmtVisitor {
   size_t max_shared_memory_per_block_;
   size_t max_threads_per_block_;
   size_t max_thread_x_, max_thread_y_, max_thread_z_;
+  size_t max_vector_bytes_;
 
   bool valid_{true};
 
@@ -169,27 +186,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
   int64_t max_thread_x = INT64_MAX;
   int64_t max_thread_y = INT64_MAX;
   int64_t max_thread_z = INT64_MAX;
+  int64_t max_vector_bytes = INT64_MAX;
 
   for (auto iter : constraints) {
     const IntImmNode* val = iter.second.as<IntImmNode>();
-    if (iter.first == "max_local_memory_per_block")
+    if (iter.first == "max_local_memory_per_block") {
       max_local_memory_per_block = val->value;
-    else if (iter.first == "max_shared_memory_per_block")
+    } else if (iter.first == "max_shared_memory_per_block") {
       max_shared_memory_per_block = val->value;
-    else if (iter.first == "max_threads_per_block")
+    } else if (iter.first == "max_threads_per_block") {
       max_threads_per_block = val->value;
-    else if (iter.first == "max_thread_x")
+    } else if (iter.first == "max_thread_x") {
       max_thread_x = val->value;
-    else if (iter.first == "max_thread_y")
+    } else if (iter.first == "max_thread_y") {
       max_thread_y = val->value;
-    else if (iter.first == "max_thread_z")
+    } else if (iter.first == "max_thread_z") {
       max_thread_z = val->value;
-    else
+    } else if (iter.first == "max_vector_bytes") {
+      max_vector_bytes = val->value;
+    } else {
       LOG(FATAL) << "Invalid check item: " << iter.first;
+    }
   }
 
   return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
-                         max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
+                         max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
+                         max_vector_bytes);
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
index 11960ca..ece8402 100644 (file)
@@ -208,6 +208,30 @@ def test_wrong_bind():
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
+def test_vectorize():
+    N = 1024
+
+    A = te.placeholder((N, N), name='A')
+    B = te.compute((N, N), lambda i, j: A[i, j])
+
+    s = te.create_schedule([B.op])
+
+    i, j = s[B].op.axis
+
+    s[B].bind(i, te.thread_axis("blockIdx.x"))
+    jo, ji = s[B].split(j, factor=64)
+    s[B].bind(jo, te.thread_axis("threadIdx.x"))
+    s[B].vectorize(ji)
+
+    for target in ['opencl', 'cuda']:
+        if not tvm.context(target).exist:
+            continue
+
+        valid = [None]
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+                (2, get_verify_pass(valid, max_vector_bytes=16))]}):
+            tvm.lower(s, [A, B])
+        assert not valid[0]
 
 if __name__ == "__main__":
     test_local_memory()
@@ -215,3 +239,4 @@ if __name__ == "__main__":
     test_num_thread()
     test_multiple_kernels()
     test_wrong_bind()
+    test_vectorize()