From e4ff14ad5955f7c4d052aa44069c77654e8b5f2e Mon Sep 17 00:00:00 2001 From: Michael Carilli Date: Fri, 3 Sep 2021 13:21:23 -0700 Subject: [PATCH] [CUDA graphs] Error if attempting to capture uncapturable nccl (#64440) Summary: NCCL < 2.9.6 is not capturable. Attempting to capture it can cause nasty behavior (for example, ive seen capture succeed, but replay silently hang). Pytorch should preempt this with a friendlier error. cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23 Pull Request resolved: https://github.com/pytorch/pytorch/pull/64440 Reviewed By: mruberry Differential Revision: D30733884 Pulled By: ngimel fbshipit-source-id: 5f2df3cf5cc0e5e68f49bf22a80d9f58064dc7ec --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 911963b..9773b35 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -189,6 +190,17 @@ std::string getExceptionMsgFromExceptionPtr( } } +inline void errorIfCapturingNonCapturableNCCL() { + auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); + // parentheses avoid some compiler warnings + static const uint64_t min_version = (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t cur_version = torch::cuda::nccl::version(); + if (cur_version < min_version) { + TORCH_CHECK(status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + } +} + } // namespace const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000; @@ -1079,6 +1091,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle) { + errorIfCapturingNonCapturableNCCL(); + // Bump collective counter if (sequenceNum_) { sequenceNum_->increment(); -- 2.7.4