[c10d] Provide failure reason from ProcessGroup when aborting NCCL comm (#64241)
author    Rohan Varma <rvarm1@fb.com>
Wed, 8 Sep 2021 16:17:49 +0000 (09:17 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 16:19:24 +0000 (09:19 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64241

When things go wrong, PG NCCL aborts NCCL communicators via `ncclCommAbort`, but one issue is that the error is often set to `ncclSystemError` (see https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/NCCLUtils.hpp#L176) even when that is not the true cause; the actual issue may be that some prior work timed out, the communicator was aborted on another rank, etc.

This causes a lot of confusion when debugging jobs with a large number of processes, since the current message for `ncclSystemError` is not very informative: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/NCCLUtils.hpp#L22

The fix is to pass a string exception message from PG NCCL down to `NCCLUtils`, which raises it as the actual error instead of the misleading `ncclSystemError` message.
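To make the mechanism concrete, here is a minimal, self-contained C++ sketch of the propagation path. It is illustrative only, not the PyTorch sources: FakeNcclComm, errorDetail, and the example reason string are stand-ins. The real change, shown in the diff below, threads a c10::optional<std::string> through NCCLComm::ncclCommAbort, commFailureReason_, and C10D_NCCL_CHECK / getNcclErrorDetailStr.

// Illustrative sketch only -- simplified stand-ins for ncclResult_t, NCCLComm,
// and getNcclErrorDetailStr; the real patch uses c10::optional and TORCH_CHECK.
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>

enum FakeNcclResult { kNcclSuccess, kNcclSystemError };

// When the process group supplies a failure reason, it takes priority over the
// generic per-error-code detail string (mirrors getNcclErrorDetailStr below).
std::string errorDetail(
    FakeNcclResult error,
    const std::optional<std::string>& processGroupFailureReason = std::nullopt) {
  if (processGroupFailureReason.has_value()) {
    return *processGroupFailureReason;
  }
  return error == kNcclSystemError
      ? "ncclSystemError: generic system-call failure (often misleading)."
      : "unknown NCCL error";
}

// Simplified communicator that remembers why it was aborted, as NCCLComm now does.
class FakeNcclComm {
 public:
  void commAbort(std::optional<std::string> reason = std::nullopt) {
    commFailureReason_ = std::move(reason);
    aborted_ = true;
  }
  void use() const {
    if (aborted_) {
      // Later users of the aborted communicator see the original cause.
      throw std::runtime_error(
          "NCCL communicator was aborted. " +
          errorDetail(kNcclSystemError, commFailureReason_));
    }
  }

 private:
  bool aborted_ = false;
  std::optional<std::string> commFailureReason_;
};

int main() {
  FakeNcclComm comm;
  // The process group aborts after a collective timeout and passes the true
  // cause along instead of leaving only the generic ncclSystemError text.
  comm.commAbort("prior allreduce timed out in call to wait()");
  try {
    comm.use();
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}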

Test Plan: CI

Reviewed By: pallab-zz, cbalioglu

Differential Revision: D30658855

fbshipit-source-id: 17661dbe0a1bb8cc5b87b637c47634b1f52f54e1

test/distributed/test_c10d_nccl.py
torch/csrc/distributed/c10d/NCCLUtils.cpp
torch/csrc/distributed/c10d/NCCLUtils.hpp
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

test/distributed/test_c10d_nccl.py
index 1378aa0..7cc8db5 100644
@@ -638,6 +638,46 @@ class DistributedDataParallelTest(
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
+    def test_nccl_propagate_error_reason(self):
+        # Need to use NCCL_BLOCKING_WAIT and not ASYNC_ERROR_HANDLING,
+        # otherwise process will be taken down and we can't check for errors.
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+        os.environ["NCCL_BLOCKING_WAIT"] = "1"
+        timeout = timedelta(seconds=2)
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        pg.barrier().wait()
+        # Simulate rank 0 getting stuck.
+        if self.rank == 0:
+            pg_gloo.barrier().wait()
+        inp = torch.ones(1).cuda(self.rank)
+
+        if self.rank != 0:
+            # Time out due to rank 0 not calling into allreduce.
+            with self.assertRaises(RuntimeError):
+                pg.allreduce([inp]).wait()
+
+            # Now when a nonzero rank attempts to use the communicator, the original failure reason should be logged.
+            try:
+                pg.allreduce([torch.ones(2).cuda(self.rank)]).wait()
+            except RuntimeError as e:
+                self.assertTrue("timed out in call to wait()" in str(e))
+                self.assertTrue("TensorShape=[1]" in str(e))
+            else:
+                self.fail("Expected error to be raised!")
+
+            # Unblock rank 0
+            pg_gloo.barrier().wait()
+
+        # TODO: We can also test that if rank 0 attempts to use the communicator,
+        # then we should error out with the info that it was aborted due to
+        # timeout on another rank. Although this would only be the case after
+        # the watchdog has run on the rank, and there is no reliable way
+        # to confirm it has run.
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
     def test_nccl_backend_multi_device_ids_not_allowed(self):
         int_devices = list(range(torch.cuda.device_count()))
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
torch/csrc/distributed/c10d/NCCLUtils.cpp
index 9e0566a..0c1dd97 100644
@@ -6,6 +6,24 @@
 
 namespace c10d {
 
+
+ncclComm_t NCCLComm::getNcclComm() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (aborted_) {
+    auto commFailureMsg = commFailureReason_ != c10::nullopt
+        ? c10::str(" Original reason for failure was: ", *commFailureReason_)
+        : "";
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "NCCL communicator was aborted on rank ",
+            rank_,
+            ". ",
+            commFailureMsg));
+  }
+  return ncclComm_;
+}
+
 std::string getNcclVersion() {
   static std::once_flag ncclGetVersionFlag;
   static std::string versionString;
torch/csrc/distributed/c10d/NCCLUtils.hpp
index bd50bba..c505017 100644
 
 #include <nccl.h>
 #include <c10/util/Exception.h>
+#include <c10/util/Optional.h>
 
 namespace {
 // Provides additional detail into NCCL error codes based on when these are
 // thrown in the NCCL codebase.
-const inline char* getNcclErrorDetailStr(ncclResult_t error) {
+const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::string> processGroupFailureReason = c10::nullopt) {
+  // Prioritize failure reason provided by PG NCCL first, as it can abort
+  // communicators when it encounters collective timeouts, etc.
+  if (processGroupFailureReason != c10::nullopt) {
+    return (*processGroupFailureReason).c_str();
+  }
   switch (error) {
     case ncclUnhandledCudaError:
       return "ncclUnhandledCudaError: Call to CUDA function failed.";
@@ -60,14 +66,14 @@ const inline char* getNcclErrorDetailStr(ncclResult_t error) {
 #endif
 
 // Macro to throw on a non-successful NCCL return value.
-#define C10D_NCCL_CHECK(cmd)                                                  \
+#define C10D_NCCL_CHECK(cmd, failureReason)                                                  \
   do {                                                                        \
     ncclResult_t result = cmd;                                                \
     if (result != ncclSuccess) {                                              \
       std::string err = "NCCL error in: " + std::string(__FILE__) + ":" +     \
           std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
-          "\n" + getNcclErrorDetailStr(result);                               \
-      TORCH_CHECK(false, err);                                          \
+          "\n" + getNcclErrorDetailStr(result, failureReason);                \
+      TORCH_CHECK(false, err);                                                \
     }                                                                         \
   } while (0)
 
@@ -96,7 +102,10 @@ std::string ncclGetErrorWithVersion(ncclResult_t error);
 class NCCLComm {
  public:
   explicit NCCLComm(ncclComm_t ncclComm)
-      : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess) {}
+      : ncclComm_(ncclComm),
+        aborted_(false),
+        ncclAsyncErr_(ncclSuccess),
+        commFailureReason_(c10::nullopt) {}
 
   NCCLComm() : NCCLComm(nullptr) {}
 
@@ -122,7 +131,7 @@ class NCCLComm {
       ncclUniqueId commId) {
     auto comm = std::make_shared<NCCLComm>();
     C10D_NCCL_CHECK(
-        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank));
+        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
@@ -149,17 +158,15 @@ class NCCLComm {
     std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
   }
 
-  ncclComm_t getNcclComm() {
+  ncclComm_t getNcclComm();
+
+  c10::optional<std::string> getNcclCommFailureReason() const {
     std::unique_lock<std::mutex> lock(mutex_);
-    if (aborted_) {
-      TORCH_CHECK(false,
-          "NCCL communicator was aborted on rank " + std::to_string(rank_) +
-          ".");
-    }
-    return ncclComm_;
+    return commFailureReason_;
   }
 
-  void ncclCommAbort() {
+  void ncclCommAbort(
+      c10::optional<std::string> commFailureReason = c10::nullopt) {
     std::unique_lock<std::mutex> lock(mutex_);
 #ifdef ENABLE_NCCL_ERROR_CHECKING
     if (aborted_) {
@@ -167,7 +174,11 @@ class NCCLComm {
       return;
     }
 
-    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_));
+    // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
+    // timeout)
+    commFailureReason_ = commFailureReason;
+
+    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
     aborted_ = true;
     ncclComm_ = nullptr;
 
@@ -192,7 +203,7 @@ class NCCLComm {
     if (ncclAsyncErr_ != ncclSuccess) {
       return ncclAsyncErr_;
     }
-    C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_));
+    C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
     return ncclAsyncErr_;
 #else
     // Always return success, if error checks are disabled.
@@ -209,6 +220,9 @@ class NCCLComm {
   mutable std::mutex mutex_;
   // Rank that this communicator corresponds to.
   int rank_;
+  // Optional reason for communicator failure, provided by ProcessGroupNCCL for
+  // better error messaging.
+  c10::optional<std::string> commFailureReason_;
 };
 
 } // namespace c10d
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 9773b35..c3b8cfe 100644
@@ -1,4 +1,5 @@
 #include <c10d/ProcessGroupNCCL.hpp>
+#include <sstream>
 
 #ifdef USE_C10D_NCCL
 
@@ -20,6 +21,7 @@
 #include <torch/csrc/cuda/nccl.h>
 
 #include <c10d/Utils.hpp>
+
 namespace c10d {
 
 constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
@@ -35,12 +37,12 @@ struct AutoNcclGroup {
   AutoNcclGroup() {
     (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-    C10D_NCCL_CHECK(ncclGroupStart());
+    C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
 #endif
   }
   ~AutoNcclGroup() noexcept(false) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-    C10D_NCCL_CHECK(ncclGroupEnd());
+    C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
 #endif
     (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
   }
@@ -379,8 +381,12 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
         // if throwing timed out exception without aborting nccl communicators
         // here, it was observed that CUDA GPU will have 100% utilization and
         // can not run new events successfully.
+
+        std::stringstream ss;
+        ss << *this;
+        auto timeoutErrorMsg = c10::str("Work ", ss.str(), " timed out in call to wait().");
         for (const auto& ncclComm : ncclComms_) {
-          ncclComm->ncclCommAbort();
+          ncclComm->ncclCommAbort(timeoutErrorMsg);
           const auto& storeKey = getNcclAbortedCommStoreKey(
               buildNcclUniqueIdStr(ncclComm->getNcclId()));
           auto rankStr = std::to_string(rank_);
@@ -545,7 +551,9 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
       auto& ncclComms = it.second;
 
       for (const auto& ncclComm : ncclComms) {
-        ncclComm->ncclCommAbort();
+        std::string abortReason =
+            c10::str("Process Group destroyed on rank ", rank_);
+        ncclComm->ncclCommAbort(abortReason);
       }
     }
   }
@@ -581,7 +589,7 @@ void ProcessGroupNCCL::abortTimedOutCollectives(
           std::make_exception_ptr(std::runtime_error(exceptionMsg));
       work.setException(exception_ptr);
       for (const auto& ncclComm : work.ncclComms_) {
-        ncclComm->ncclCommAbort();
+        ncclComm->ncclCommAbort(exceptionMsg);
         abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
       }
     }
@@ -620,6 +628,8 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
         }
         std::exception_ptr ncclErrorException = checkForNCCLErrors(ncclComms);
         if (ncclErrorException) {
+          auto exceptionMsg
+            = getExceptionMsgFromExceptionPtr(ncclErrorException);
           LOG(INFO)
               << "[Rank " << rank_
               << "] Received NCCL errors for communicators in the cache: \n"
@@ -635,7 +645,10 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
             // collectives and throw exceptions if an exception has been set on
             // any of the work objects from this thread.
             for (const auto& ncclComm : ncclComms) {
-              ncclComm->ncclCommAbort();
+              // We are aborting remaining communicators due to an error in
+              // at least one of these communicators, so propagate that reason
+              // for better debugability.
+              ncclComm->ncclCommAbort(exceptionMsg);
               // Note that we don't remove the aborted communicators from the
               // cache. The reason is that if we do remove the communicator
               // from the cache, it is possible that a new collective operation
@@ -692,17 +705,24 @@ void ProcessGroupNCCL::ncclCommWatchdogInternal() {
                 std::chrono::milliseconds(kWaitForAbortCommStoreKey));
             auto val = store_->get(storeKey);
             std::string rank(reinterpret_cast<char*>(val.data()), val.size());
-            LOG(INFO) << "[Rank " << rank_
+            std::stringstream ss;
+            ss << "[Rank " << rank_
                       << "] Found key in store: " << storeKey
                       << ", from rank: " << rank
-                      << ", aborting appropriate communicators";
+                      << ". This means that rank has aborted its NCCL communicators previously and is not in a healthy state."
+                      << " Aborting appropriate communicators";
+            std::string abortReason = ss.str();
+            LOG(WARNING) << abortReason;
 
             // Now abort the appropriate communicators.
             std::lock_guard<std::mutex> lock(mutex_);
             auto it = ncclIdToCommMap_.find(commId);
             TORCH_INTERNAL_ASSERT(it != ncclIdToCommMap_.end());
             for (const auto& ncclComm : it->second) {
-              ncclComm->ncclCommAbort();
+              // The reason we are aborting is because some other ranks have
+              // aborted their communicators originally, so propagate that
+              // reason.
+              ncclComm->ncclCommAbort(abortReason);
             }
             abortedComms_.emplace(commId);
             LOG(INFO) << "[Rank " << rank_
@@ -773,6 +793,19 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
 std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
     const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
   for (const auto& ncclComm : ncclComms) {
+    // Prioritize commFailureReason over checkForNcclError() result if
+    // commFailureReason is set.
+    auto commFailureReason = ncclComm->getNcclCommFailureReason();
+    if (commFailureReason != c10::nullopt) {
+        return std::make_exception_ptr(
+          std::runtime_error(
+            c10::str(
+              "NCCL communicator encountered error set by ProcessGroupNCCL: ",
+               *commFailureReason
+            )
+          )
+        );
+    }
     ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
     if (ncclAsyncErr != ncclSuccess) {
       return std::make_exception_ptr(std::runtime_error(
@@ -855,7 +888,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
 
   // For point-to-point communication, lower rank of the two will get unique id.
   if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) {
-    C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID));
+    C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt);
   }
 
   // For point-to-point communication on the same process, don't need broadcast.
@@ -886,11 +919,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   // created before encountering any communication calls. This is why we need
   // the following for loop.
   for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
-    C10D_NCCL_CHECK(ncclGroupEnd());
+    C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
   }
 
   // [Note 1] Create the NCCL communicators for each GPU
-  C10D_NCCL_CHECK(ncclGroupStart());
+  C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
 
   for (const auto i : c10::irange(devices.size())) {
     // GPU world size and GPU rank
@@ -921,11 +954,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   }
 
   // [Note 2 ]
-  C10D_NCCL_CHECK(ncclGroupEnd());
+  C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
 
   // See [Group Start/End Note]
   for (const auto i : c10::irange(ncclActiveGroupCounter_)) {
-    C10D_NCCL_CHECK(ncclGroupStart());
+    C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
   }
 
   ncclStreams_.emplace(devicesKey, std::move(streamVal));
@@ -1143,7 +1176,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
       gpuGuard.set_index(devices[i].index());
       at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
       C10D_NCCL_CHECK(
-          fn(inputs[i], outputs[i], ncclComms[i]->getNcclComm(), ncclStream));
+          fn(inputs[i], outputs[i], ncclComms[i]->getNcclComm(), ncclStream), ncclComms[i]->getNcclCommFailureReason());
     }
   }
 
@@ -1244,7 +1277,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::pointToPoint(
       // be 0 or 1.
       int p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
       C10D_NCCL_CHECK(fn(
-          tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank));
+          tensors[i], ncclComms[i]->getNcclComm(), ncclStream, p2pTargetRank), ncclComms[i]->getNcclCommFailureReason());
     }
   }
 
@@ -1865,14 +1898,14 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::recv(
 
 void ProcessGroupNCCL::groupStart() {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-  C10D_NCCL_CHECK(ncclGroupStart());
+  C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt);
 #endif
   ++ncclActiveGroupCounter_;
 }
 
 void ProcessGroupNCCL::groupEnd() {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-  C10D_NCCL_CHECK(ncclGroupEnd());
+  C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
 #endif
   --ncclActiveGroupCounter_;
 }