From: Shen Li Date: Thu, 28 Mar 2019 22:05:53 +0000 (-0700) Subject: Fix NCCL/Gloo process groups and DDP stream sync bug (#18465) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~577 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aea8ee1f6831a7704044d85df7c24e041e2363f5;p=platform%2Fupstream%2Fpytorch.git Fix NCCL/Gloo process groups and DDP stream sync bug (#18465) Summary: DDP with NCCL backend uses a [worker stream](https://github.com/pytorch/pytorch/blob/d3eb941ed96774efb8d89a0b20c9e49807ea85a7/torch/csrc/distributed/c10d/ddp.cpp#L142) to flatten grand batch tensors, and passes the flattened tensor to [another stream](https://github.com/pytorch/pytorch/blob/d3eb941ed96774efb8d89a0b20c9e49807ea85a7/torch/lib/c10d/ProcessGroupNCCL.cpp#L379) to conduct ncclAllReduce. The flattened tensor has to record the ncclAllReduce stream, otherwise multiple streams might access the same memory space. cc ppwwyyxx Pull Request resolved: https://github.com/pytorch/pytorch/pull/18465 Differential Revision: D14613449 Pulled By: mrshenli fbshipit-source-id: b62773732552d12cc87b7adeb6897e9e11753ea9 --- diff --git a/torch/csrc/distributed/c10d/ddp.cpp b/torch/csrc/distributed/c10d/ddp.cpp index 6d33686..7673431 100644 --- a/torch/csrc/distributed/c10d/ddp.cpp +++ b/torch/csrc/distributed/c10d/ddp.cpp @@ -145,8 +145,18 @@ std::tuple, at::Tensor> queueReduction( events[devIdx].record(); workerStreams.push_back( at::cuda::getStreamFromPool(false, devices[devIdx])); - // Let the worker stream to wait for the default stream + // Let worker streams to wait for default streams to make sure worker + // streams do not touch `gradsBatch` until all pending ops to create + // `gradBatch` finish. events[devIdx].block(workerStreams.back()); + + // Input `gradsBatch` are created on current streams and used in worker + // streams. Hence, they must record worker streams to prevent being + // freed before their worker stream ops finish. + for (at::Tensor& grad : gradsBatch[devIdx]) { + c10::cuda::CUDACachingAllocator::recordStream( + grad.storage().data(), workerStreams.back()); + } } // Stream guards, now the current stream is the worker stream @@ -179,6 +189,15 @@ void syncReduction( // and intra-node reduce to be operated on this worker stream to // improve performance at::cuda::CUDAStream workerStream = at::cuda::getStreamFromPool(); + + // Input `gradsBatch` are created on the current stream and used on the worker + // stream. Hence, they must record worker streams to prevent being freed + // before their worker stream ops finish. + for (at::Tensor& grad : gradsBatch) { + c10::cuda::CUDACachingAllocator::recordStream( + grad.storage().data(), workerStream); + } + at::cuda::CUDAStreamGuard cudaGuard(workerStream); // Let the worker stream wait on the reduction stream diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 140c85e..d5c674c 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -10,10 +10,11 @@ #ifdef USE_CUDA #include -#include -#include #include #include +#include +#include +#include #endif #include @@ -150,6 +151,11 @@ void initializeStreamsEvents( /* isHighPriority */ true, tensors[i].device().index())); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); + + // `tensors` are created on a different stream. Hence, they must record + // new streams in this Work to prevent being freed before the Work finishes. + c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data(), streams[i]); } } @@ -187,6 +193,14 @@ void initializeStreamsEvents( /* isHighPriority */ true, tensors[i][0].device().index())); // Ensure the new stream is synchronized with the current stream. events[i].block(streams[i]); + + for (at::Tensor& tensor : tensors[i]) { + // `tensors` are created on a different stream. Hence, they must record + // new streams in this Work to prevent being freed before the Work + // finishes. + c10::cuda::CUDACachingAllocator::recordStream( + tensor.storage().data(), streams[i]); + } } } diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 5eedb89..486d2ce 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -66,7 +66,19 @@ std::vector getDeviceList(const std::vector& tensors) { return res; } -// Helper that lets the input ncclStreams to wait for the current stream +// [Sync Streams] Helper that lets the input ncclStreams to wait for the current +// stream. NCCL communications run on ncclStreams, but input tensors are +// allocated on different streams (i.e., current streams). Communications on +// ncclStreams cannot start before pending input tensor ops on current streams +// finish. Otherwise, ops on two streams might read/write same tensors +// concurrently. +// +// The synchronization above alone is not enough. We also need to make sure +// input tensors are not freed before their usages on ncclStreams finish. This +// can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, +// which remembers the usage stream (ncclStream), creates an event on the usage +// stream when GC attempts to free the input tensor, and delays GC until that +// event is done. void syncStreams( const std::vector& devices, std::vector& ncclEvents, @@ -361,7 +373,7 @@ std::shared_ptr ProcessGroupNCCL::allreduce( auto key = getKeyFromDevices(devices); auto& ncclComms = getNCCLComm(key, devices); - // First let NCCL streams wait for THC stream + // First let NCCL streams wait for input tensors allocation streams syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); // Work itself will create the CUDA events on all GPUs of tensors @@ -378,6 +390,12 @@ std::shared_ptr ProcessGroupNCCL::allreduce( gpuGuard.set_index(devices[i].index()); at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + // Input `tensors` are created on a worker stream and used in different + // ncclStream. Hence, `tensors` must record the ncclStream to prevent being + // freed before ncclAllReduce finishes. See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data(), ncclStream); + C10D_NCCL_CHECK(ncclAllReduce( tensors[i].data_ptr(), tensors[i].data_ptr(), @@ -427,6 +445,12 @@ std::shared_ptr ProcessGroupNCCL::broadcast( // root rank of the the GPU int root = opts.rootRank * tensors.size() + opts.rootTensor; + // Input `tensors` are created on worker streams and used in different + // ncclStreams. Hence, `tensors` must record ncclStreams to prevent being + // freed before ncclBcast finishes. See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data(), ncclStream); + C10D_NCCL_CHECK(ncclBcast( tensors[i].data_ptr(), tensors[i].numel(), @@ -475,6 +499,12 @@ std::shared_ptr ProcessGroupNCCL::reduce( // root rank of the the GPU int root = opts.rootRank * tensors.size() + opts.rootTensor; + // Input `tensors` are created on worker streams and used in different + // ncclStreams. Hence, `tensors` must record ncclStreams to prevent being + // freed before ncclReduce finishes. See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + tensors[i].storage().data(), ncclStream); + C10D_NCCL_CHECK(ncclReduce( tensors[i].data_ptr(), tensors[i].data_ptr(), @@ -543,6 +573,16 @@ std::shared_ptr ProcessGroupNCCL::allgather( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; + // Input `inputTensors` and `flattenOutputTensors` are created on worker + // streams and used in different ncclStreams. Hence, `tensors` must record + // ncclStreams to prevent beingfreed before ncclReduce finishes. + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors[i].storage().data(), ncclStream); + + c10::cuda::CUDACachingAllocator::recordStream( + flattenOutputTensors[i].storage().data(), ncclStream); + C10D_NCCL_CHECK(ncclAllGather( inputTensors[i].data_ptr(), flattenOutputTensors[i].data_ptr(), @@ -559,6 +599,10 @@ std::shared_ptr ProcessGroupNCCL::allgather( at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; at::cuda::CUDAStreamGuard guard(ncclStream); for (size_t j = 0; j < outputTensors[0].size(); ++j) { + // See [Sync Streams]. + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors[i][i].storage().data(), ncclStream); + outputTensors[i][j].copy_(flattenOutputTensors[i][j], true); } }