[CUDA graphs] Error if attempting to capture uncapturable nccl (#64440)
authorMichael Carilli <mcarilli@gmail.com>
Fri, 3 Sep 2021 20:21:23 +0000 (13:21 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 3 Sep 2021 20:23:07 +0000 (13:23 -0700)
Summary:
NCCL < 2.9.6 is not capturable. Attempting to capture it can cause nasty behavior (for example, ive seen capture succeed, but replay silently hang). Pytorch should preempt this with a friendlier error.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64440

Reviewed By: mruberry

Differential Revision: D30733884

Pulled By: ngimel

fbshipit-source-id: 5f2df3cf5cc0e5e68f49bf22a80d9f58064dc7ec

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

index 911963b..9773b35 100644 (file)
@@ -11,6 +11,7 @@
 #include <THC/THC.h>
 
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/irange.h>
 #include <c10/util/Logging.h>
@@ -189,6 +190,17 @@ std::string getExceptionMsgFromExceptionPtr(
   }
 }
 
+inline void errorIfCapturingNonCapturableNCCL() {
+  auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx();
+  // parentheses avoid some compiler warnings
+  static const uint64_t min_version = (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6);
+  static const uint64_t cur_version = torch::cuda::nccl::version();
+  if (cur_version < min_version) {
+    TORCH_CHECK(status == c10::cuda::CaptureStatus::None,
+                "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6");
+  }
+}
+
 } // namespace
 
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
@@ -1079,6 +1091,8 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     OpType opType,
     const char* profilingTitle) {
 
+  errorIfCapturingNonCapturableNCCL();
+
   // Bump collective counter
   if (sequenceNum_) {
     sequenceNum_->increment();