#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>
#include <c10/util/Logging.h>
}
}
+// Fails a TORCH_CHECK when the value returned by
+// c10::cuda::currentStreamCaptureStatusMayInitCtx() is anything other than
+// CaptureStatus::None while torch::cuda::nccl::version() reports a version
+// older than 2.9.6. Does nothing beyond the status query otherwise.
+inline void errorIfCapturingNonCapturableNCCL() {
+  // Queried unconditionally and first on every call; as the callee's name
+  // suggests, this may initialize the CUDA context.
+  const auto captureStatus = c10::cuda::currentStreamCaptureStatusMayInitCtx();
+
+  // 2.9.6 packed as (major << 32) | (minor << 16) | patch.
+  // parentheses avoid some compiler warnings
+  static const uint64_t kMinCapturableVersion =
+      (static_cast<uint64_t>(2) << 32) | (static_cast<uint64_t>(9) << 16) |
+      (static_cast<uint64_t>(6));
+  // Function-local static: the version is looked up once, on the first call.
+  static const uint64_t kCurrentVersion = torch::cuda::nccl::version();
+
+  // New enough NCCL: no restriction to enforce.
+  if (kCurrentVersion >= kMinCapturableVersion) {
+    return;
+  }
+
+  TORCH_CHECK(
+      captureStatus == c10::cuda::CaptureStatus::None,
+      "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6");
+}
+
} // namespace
// Out-of-line definition of ProcessGroupNCCL::kWatchdogThreadSleepMillis.
// NOTE(review): per the name, presumably the watchdog thread's sleep interval
// in milliseconds (10000 ms = 10 s) -- confirm against where it is consumed.
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
OpType opType,
const char* profilingTitle) {
+ errorIfCapturingNonCapturableNCCL();
+
// Bump collective counter
if (sequenceNum_) {
sequenceNum_->increment();