    @requires_nccl()
    @skip_if_lt_x_gpu(2)
+    def test_nccl_propagate_error_reason(self):
+        # Need to use NCCL_BLOCKING_WAIT and not ASYNC_ERROR_HANDLING,
+        # otherwise the process will be taken down and we can't check for errors.
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        timeout = timedelta(seconds=2)
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        pg.barrier().wait()
+        # Simulate rank 0 being stuck.
+        if self.rank == 0:
+            pg_gloo.barrier().wait()
+        inp = torch.ones(1).cuda(self.rank)
+
+        if self.rank != 0:
+            # Time out due to rank 0 not calling into allreduce.
+            with self.assertRaises(RuntimeError):
+                pg.allreduce([inp]).wait()
+
+            # Now when a nonzero rank attempts to use the communicator, the
+            # original failure reason should be logged.
+            try:
+                pg.allreduce([torch.ones(2).cuda(self.rank)]).wait()
+            except RuntimeError as e:
+                self.assertTrue("timed out in call to wait()" in str(e))
+                self.assertTrue("TensorShape=[1]" in str(e))
+            else:
+                self.fail("Expected error to be raised!")
+
+            # Unblock rank 0
+            pg_gloo.barrier().wait()
+
+        # TODO: We can also test that if rank 0 attempts to use the communicator,
+        # then we should error out with the info that it was aborted due to
+        # timeout on another rank. Although this would only be the case after
+        # the watchdog has run on the rank, and there is no reliable way
+        # to confirm it has run.
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_multi_device_ids_not_allowed(self):
        int_devices = list(range(torch.cuda.device_count()))
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
#include <nccl.h>
#include <c10/util/Exception.h>
+#include <c10/util/Optional.h>
namespace {
// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
-const inline char* getNcclErrorDetailStr(ncclResult_t error) {
+const inline char* getNcclErrorDetailStr(ncclResult_t error, const c10::optional<std::string>& processGroupFailureReason = c10::nullopt) {
+ // Prioritize failure reason provided by PG NCCL first, as it can abort
+ // communicators when it encounters collective timeouts, etc.
+ if (processGroupFailureReason != c10::nullopt) {
+ return (*processGroupFailureReason).c_str();
+ }
switch (error) {
case ncclUnhandledCudaError:
return "ncclUnhandledCudaError: Call to CUDA function failed.";
#endif
// Macro to throw on a non-successful NCCL return value.
-#define C10D_NCCL_CHECK(cmd) \
+#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
ncclResult_t result = cmd; \
if (result != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
- "\n" + getNcclErrorDetailStr(result); \
- TORCH_CHECK(false, err); \
+ "\n" + getNcclErrorDetailStr(result, failureReason); \
+ TORCH_CHECK(false, err); \
} \
} while (0)
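+// Callers that have no ProcessGroupNCCL-provided failure reason pass
+// c10::nullopt for failureReason, e.g.:
+//   C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);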
class NCCLComm {
public:
explicit NCCLComm(ncclComm_t ncclComm)
- : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}
+ : ncclComm_(ncclComm),
+ aborted_(false),
+ ncclAsyncErr_(ncclSuccess),
+ commFailureReason_(c10::nullopt) {}
NCCLComm() : NCCLComm(nullptr) {}
ncclUniqueId commId) {
auto comm = std::make_shared<NCCLComm>();
C10D_NCCL_CHECK(
- ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
+ ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
return comm;
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
}
- ncclComm_t getNcclComm() {
+ ncclComm_t getNcclComm();
+
+ c10::optional<std::string> getNcclCommFailureReason() const {
std::unique_lock<std::mutex> lock(mutex_);
- if (aborted_) {
- TORCH_CHECK(false,
- "NCCL communicator was aborted on rank " + std::to_string(rank_) +
- ".");
- }
- return ncclComm_;
+ return commFailureReason_;
}
- void ncclCommAbort() {
+ void ncclCommAbort(
+ c10::optional<std::string> commFailureReason = c10::nullopt) {
std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_) {
return;
}
- C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
+ // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
+ // timeout)
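+ // The stored reason is later surfaced through getNcclCommFailureReason() in
+ // ProcessGroupNCCL::checkForNCCLErrorsInternal(), so users of this
+ // communicator see why it was aborted instead of a generic abort error.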
+ commFailureReason_ = commFailureReason;
+
+ C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
aborted_ = true;
ncclComm_ = nullptr;
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
- C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
+ C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
return ncclAsyncErr_;
#else
// Always return success, if error checks are disabled.
mutable std::mutex mutex_;
// Rank that this communicator corresponds to.
int rank_;
+ // Optional reason for communicator failure, provided by ProcessGroupNCCL for
+ // better error messaging.
+ c10::optional<std::string> commFailureReason_;
};
} // namespace c10d
#include <c10d/ProcessGroupNCCL.hpp>
+#include <sstream>
#ifdef USE_C10D_NCCL
#include <torch/csrc/cuda/nccl.h>
#include <c10d/Utils.hpp>
+
namespace c10d {
constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
AutoNcclGroup() {
(c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
- C10D_NCCL_CHECK(ncclGroupStart());
+ C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
#endif
}
~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
- C10D_NCCL_CHECK(ncclGroupEnd());
+ C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#endif
(c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
}
// If we throw a timed-out exception without aborting NCCL communicators
// here, it was observed that the CUDA GPU will have 100% utilization and
// cannot run new events successfully.
+
+ std::stringstream ss;
+ ss << *this;
+ auto timeoutErrorMsg = c10::str("Work ", ss.str(), " timed out in call to wait().");
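+ // Ranks that later use this aborted communicator will see this message
+ // (test_nccl_propagate_error_reason asserts on "timed out in call to wait()"
+ // and the work's tensor shapes).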
for (const auto& ncclComm : ncclComms_) {
- ncclComm->ncclCommAbort();
+ ncclComm->ncclCommAbort(timeoutErrorMsg);
const auto& storeKey = getNcclAbortedCommStoreKey(
buildNcclUniqueIdStr(ncclComm->getNcclId()));
auto rankStr = std::to_string(rank_);
auto& ncclComms = it.second;
for (const auto& ncclComm : ncclComms) {
- ncclComm->ncclCommAbort();
+ std::string abortReason =
+ c10::str("Process Group destroyed on rank ", rank_);
+ ncclComm->ncclCommAbort(abortReason);
}
}
}
std::make_exception_ptr(std::runtime_error(exceptionMsg));
work.setException(exception_ptr);
for (const auto& ncclComm : work.ncclComms_) {
- ncclComm->ncclCommAbort();
+ ncclComm->ncclCommAbort(exceptionMsg);
abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
}
}
std::exception_ptr ncclErrorException = checkForNCCLErrors(ncclComms);
if (ncclErrorException) {
+ auto exceptionMsg = getExceptionMsgFromExceptionPtr(ncclErrorException);
LOG(INFO)
<< "[Rank " << rank_
<< "] Received NCCL errors for communicators in the cache: \n"
// collectives and throw exceptions if an exception has been set on
// any of the work objects from this thread.
for (const auto& ncclComm : ncclComms) {
- ncclComm->ncclCommAbort();
+ // We are aborting remaining communicators due to an error in
+ // at least one of these communicators, so propagate that reason
+ // for better debugability.
+ ncclComm->ncclCommAbort(exceptionMsg);
// Note that we don't remove the aborted communicators from the
// cache. The reason is that if we do remove the communicator
// from the cache, it is possible that a new collective operation
std::chrono::milliseconds(kWaitForAbortCommStoreKey));
auto val = store_->get(storeKey);
std::string rank(reinterpret_cast<char*>(val.data()), val.size());
- LOG(INFO) << "[Rank " << rank_
+ std::stringstream ss;
+ ss << "[Rank " << rank_
<< "] Found key in store: " << storeKey
<< ", from rank: " << rank
- << ", aborting appropriate communicators";
+ << ". This means that rank has aborted its NCCL communicators previously and is not in a healthy state."
+ << ". Aborting appropriate communicators";
+ std::string abortReason = ss.str();
+ LOG(WARNING) << abortReason;
// Now abort the appropriate communicators.
std::lock_guard<std::mutex> lock(mutex_);
auto it = ncclIdToCommMap_.find(commId);
TORCH_INTERNAL_ASSERT(it != ncclIdToCommMap_.end());
for (const auto& ncclComm : it->second) {
- ncclComm->ncclCommAbort();
+ // The reason we are aborting is because some other ranks have
+ // aborted their communicators originally, so propagate that
+ // reason.
+ ncclComm->ncclCommAbort(abortReason);
}
abortedComms_.emplace(commId);
LOG(INFO) << "[Rank " << rank_
std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
for (const auto& ncclComm : ncclComms) {
+ // Prioritize commFailureReason over checkForNcclError() result if
+ // commFailureReason is set.
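+ // For the collective-timeout case, this yields an error along the lines of:
+ // "NCCL communicator encountered error set by ProcessGroupNCCL: Work ...
+ // timed out in call to wait()."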
+ auto commFailureReason = ncclComm->getNcclCommFailureReason();
+ if (commFailureReason != c10::nullopt) {
+ return std::make_exception_ptr(
+ std::runtime_error(
+ c10::str(
+ "NCCL communicator encountered error set by ProcessGroupNCCL: ",
+ *commFailureReason
+ )
+ )
+ );
+ }
ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
if (ncclAsyncErr != ncclSuccess) {
return std::make_exception_ptr(std::runtime_error(
// For point-to-point communication, lower rank of the two will get unique id.
if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) {
- C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID));
+ C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt);
}
// For point-to-point communication on the same process, don't need broadcast.
// created before encountering any communication calls. This is why we need
// the following for loop.
for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
- C10D_NCCL_CHECK(ncclGroupEnd());
+ C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
}
// [Note 1] Create the NCCL communicators for each GPU
- C10D_NCCL_CHECK(ncclGroupStart());
+ C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
for (const auto i : c10::irange(devices.size())) {
// GPU world size and GPU rank
}
// [Note 2 ]
- C10D_NCCL_CHECK(ncclGroupEnd());
+ C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
// See [Group Start/End Note]
for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
- C10D_NCCL_CHECK(ncclGroupStart());
+ C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
}
ncclStreams_.emplace(devicesKey, std::move(streamVal));
gpuGuard.set_index(devices[i].index());
at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
C10D_NCCL_CHECK(
- fn(inputs[i], outputs[i], ncclComms[i]->getNcclComm(), ncclStream));
+ fn(inputs[i], outputs[i], ncclComms[i]->getNcclComm(), ncclStream), ncclComms[i]->getNcclCommFailureReason());
}
}
// be 0 or 1.
int p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
C10D_NCCL_CHECK(fn(
- tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank));
+ tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank), ncclComms[i]->getNcclCommFailureReason());
}
}
void ProcessGroupNCCL::groupStart() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
- C10D_NCCL_CHECK(ncclGroupStart());
+ C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
#endif
++ncclActiveGroupCounter_;
}
void ProcessGroupNCCL::groupEnd() {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
- C10D_NCCL_CHECK(ncclGroupEnd());
+ C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#endif
--ncclActiveGroupCounter_;
}