From f8455ed754196cd554cbaaaaa440dd178383154d Mon Sep 17 00:00:00 2001
From: Jane Wang
Date: Tue, 11 Dec 2018 21:03:13 -0800
Subject: [PATCH] add gloo support for gather on GPU (#14916)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14916
as titled

Reviewed By: pietern

Differential Revision: D13267832

fbshipit-source-id: 3b89d08af93f74941f17ff892c33fc2a4a023c19
---
 test/test_c10d.py                   |  57 +++++++++++++-
 torch/lib/c10d/ProcessGroupGloo.cpp | 118 +++++++++++++++++++++++++++-
 2 files changed, 168 insertions(+), 7 deletions(-)

diff --git a/test/test_c10d.py b/test/test_c10d.py
index bf28d2d843..329a7be8db 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -929,13 +929,13 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             opts.rootRank = (self.rank + 1) % self.world_size
             pg.gather([[t1] * self.world_size], [t1], opts)
 
-    def test_gather_basics(self):
+    def _test_gather_basics(self, fn):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
 
         # Preallocate tensors for input/output
-        input = [torch.Tensor([self.rank])]
-        outputs = [torch.Tensor([-1]) for _ in range(self.world_size)]
+        input = [fn(torch.Tensor([self.rank]))]
+        outputs = [fn(torch.Tensor([-1])) for _ in range(self.world_size)]
 
         # Take turns being the gather root and accumulate work items
         work = []
@@ -954,6 +954,57 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             if i == self.rank:
                 self.assertEqual(expected, outputs)
 
+    def test_gather_basics(self):
+        self._test_gather_basics(lambda t: t.clone())
+
+    @skip_if_not_multigpu
+    def test_gather_basics_cuda(self):
+        self._test_gather_basics(lambda t: t.clone().cuda())
+
+    def _test_gather_stress(self, inputs, fn):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8))
+        work_handles = []
+        outputs = [
+            [
+                [fn(torch.Tensor([-1])) for _ in range(self.world_size)]
+            ] for _ in range(len(inputs))
+        ]
+        expected_outputs = [
+            [
+                [torch.Tensor([i + j]) for j in range(self.world_size)]
+            ] for i in range(len(inputs))
+        ]
+        for i in range(len(inputs)):
+            for root in range(self.world_size):
+                opts = c10d.GatherOptions()
+                opts.rootRank = root
+                if root == self.rank:
+                    work = pg.gather(outputs[i], [fn(inputs[i])], opts)
+                else:
+                    work = pg.gather([], [fn(inputs[i])], opts)
+                work_handles.append(work)
+
+        for i, work_handle in enumerate(work_handles):
+            work_handle.wait()
+            iter = i // self.world_size
+            root = i % self.world_size
+            if root == self.rank:
+                self.assertEqual(
+                    expected_outputs[iter],
+                    outputs[iter],
+                    "Mismatch in iteration %d for root %d" % (iter, root)
+                )
+
+    def test_gather_stress(self):
+        inputs = [torch.Tensor([i + self.rank]) for i in range(1000)]
+        self._test_gather_stress(inputs, lambda t: t.clone())
+
+    @skip_if_not_multigpu
+    def test_gather_stress_cuda(self):
+        inputs = [torch.Tensor([i + self.rank]).cuda() for i in range(1000)]
+        self._test_gather_stress(inputs, lambda t: t.clone().cuda())
+
     def test_allgather_checks(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 43ed7fbe25..2f1493d2b1 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -1008,7 +1008,9 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
   const int root;
   const uint32_t tag;
 
-  void run() override {
+  void gather(
+      std::vector<std::vector<at::Tensor>>& outputs,
+      std::vector<at::Tensor>& inputs) {
     const auto scalarType = inputs[0].type().scalarType();
     gloo::GatherOptions opts(context);
     opts.setRoot(root);
@@ -1033,8 +1035,95 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
       }
     }
   }
+
+  void run() override {
+    gather(outputs, inputs);
+  }
 };
 
+#ifdef USE_CUDA
+
+// Note: the current CUDA implementation makes the following assumptions:
+//     - inputs.size() is 1
+//     - outputs.size() is 1
+//     - the size of the nested output tensor list is the world size,
+//       i.e., outputs[0].size() is the world size
+class AsyncGatherCUDAWork : public AsyncGatherWork {
+ public:
+  AsyncGatherCUDAWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<std::vector<at::Tensor>>& outputs,
+      std::vector<at::Tensor>& inputs,
+      int root,
+      uint32_t tag)
+      : AsyncGatherWork(context, outputs, inputs, root, tag) {
+    initializeStreamsEvents(inputs, inputStreams, inputEvents);
+    initializeStreamsEvents(outputs, outputStreams, outputEvents);
+
+    // Kick off copy from CUDA tensors to pinned CPU tensors.
+    tmpInputs.reserve(inputs.size());
+    at::cuda::OptionalCUDAStreamGuard guard;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      guard.reset_stream(inputStreams[i]);
+      tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
+    }
+
+    tmpOutputs.resize(outputs.size());
+    for (size_t i = 0; i < outputs.size(); i++) {
+      tmpOutputs[i].reserve(outputs[i].size());
+      for (size_t j = 0; j < outputs[i].size(); j++) {
+        tmpOutputs[i].push_back(pinnedLike(outputs[i][j]));
+      }
+    }
+  }
+
+  void run() override {
+    // Synchronize with copy operations.
+    at::cuda::OptionalCUDAGuard device_guard;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      device_guard.set_index(inputs[i].get_device());
+      AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i]));
+    }
+
+    for (size_t i = 0; i < outputs.size(); i++) {
+      device_guard.set_index(outputs[i][0].get_device());
+      AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i]));
+    }
+
+    // Run gather on host side tensors.
+    gather(tmpOutputs, tmpInputs);
+
+    // Kick off copy back to the CUDA tensors.
+    at::cuda::OptionalCUDAStreamGuard stream_guard;
+    for (size_t i = 0; i < outputs.size(); i++) {
+      stream_guard.reset_stream(outputStreams[i]);
+      for (size_t j = 0; j < outputs[i].size(); j++) {
+        outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true);
+      }
+      outputEvents[i].record(outputStreams[i]);
+    }
+  }
+
+  void synchronize() override {
+    // Synchronize with the copy back to CUDA tensors.
+    at::cuda::OptionalCUDAGuard guard;
+    for (size_t i = 0; i < outputs.size(); i++) {
+      guard.set_index(static_cast<at::DeviceIndex>(outputs[i][0].get_device()));
+      outputEvents[i].block(at::cuda::getCurrentCUDAStream());
+    }
+  }
+
+  std::vector<at::Tensor> tmpInputs;
+  std::vector<at::cuda::CUDAStream> inputStreams;
+  std::vector<at::cuda::CUDAEvent> inputEvents;
+
+  std::vector<std::vector<at::Tensor>> tmpOutputs;
+  std::vector<at::cuda::CUDAStream> outputStreams;
+  std::vector<at::cuda::CUDAEvent> outputEvents;
+};
+
+#endif
+
 } // namespace
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::gather(
@@ -1048,7 +1137,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::gather(
   assertRootRank(invalidArgument, opts.rootRank, size_);
   assertSingleElementInput(invalidArgument, inputs);
   assertDense(invalidArgument, inputs);
-  assertCPU(invalidArgument, inputs);
 
   if (getRank() == opts.rootRank) {
     if (outputs.size() != 1 ||
@@ -1067,8 +1155,30 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::gather(
       }
     }
   }
-  auto work = std::make_shared<AsyncGatherWork>(
-      contexts_[0], outputs, inputs, opts.rootRank, nextTag());
+  const auto& device = inputs[0].device();
+  switch (device.type()) {
+    case at::kCPU:
+#ifdef USE_CUDA
+    case at::kCUDA:
+#endif
+      break;
+    default:
+      invalidArgument("unsupported device type");
+  }
+
+  std::shared_ptr<AsyncGatherWork> work;
+  auto& context = contexts_[0];
+  if (device.type() == at::kCPU) {
+    work = std::make_shared<AsyncGatherWork>(
+        context, outputs, inputs, opts.rootRank, nextTag());
+#ifdef USE_CUDA
+  } else if (device.type() == at::kCUDA) {
+    work = std::make_shared<AsyncGatherCUDAWork>(
+        context, outputs, inputs, opts.rootRank, nextTag());
+#endif
+  } else {
+    throw std::runtime_error("Invalid backend");
+  }
   enqueue(work);
   return work;
 }
-- 
2.34.1
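
A rough usage sketch follows, showing how the new CUDA gather path might be exercised through the public torch.distributed API rather than the c10d test harness. It is a sketch under assumptions: environment-variable rendezvous (RANK, WORLD_SIZE) with illustrative MASTER_ADDR/MASTER_PORT defaults, and at least one visible GPU per rank; none of these names come from the patch itself.

    # Hypothetical sketch: launch one copy per rank with RANK and WORLD_SIZE set.
    import os

    import torch
    import torch.distributed as dist


    def main():
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # illustrative default
        os.environ.setdefault("MASTER_PORT", "29500")      # illustrative default
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Each rank contributes one CUDA tensor; before this patch the Gloo
        # backend rejected CUDA inputs for gather (assertCPU).
        device = torch.device("cuda", rank % torch.cuda.device_count())
        tensor = torch.tensor([float(rank)], device=device)

        if rank == 0:
            # The root preallocates world_size output tensors on its own GPU.
            gather_list = [torch.zeros(1, device=device) for _ in range(world_size)]
            dist.gather(tensor, gather_list=gather_list, dst=0)
            print([t.item() for t in gather_list])  # expected: [0.0, 1.0, ...]
        else:
            dist.gather(tensor, gather_list=None, dst=0)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

As the diff shows, the CUDA path reuses the existing host-side algorithm: each CUDA input is copied into a pinned CPU tensor on a per-device stream, AsyncGatherWork::gather() runs on those host copies, and the results are copied back to the original CUDA output tensors, with recorded events that synchronize() later blocks on.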