opts.rootRank = (self.rank + 1) % self.world_size
pg.gather([[t1] * self.world_size], [t1], opts)
- def test_gather_basics(self):
+ def _test_gather_basics(self, fn):
store = c10d.FileStore(self.file.name, self.world_size)
pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
# Preallocate tensors for input/output
- input = [torch.Tensor([self.rank])]
- outputs = [torch.Tensor([-1]) for _ in range(self.world_size)]
+ input = [fn(torch.Tensor([self.rank]))]
+ outputs = [fn(torch.Tensor([-1])) for _ in range(self.world_size)]
# Take turns being the gather root and accumulate work items
work = []
if i == self.rank:
self.assertEqual(expected, outputs)
+ def test_gather_basics(self):
+ self._test_gather_basics(lambda t: t.clone())
+
+ @skip_if_not_multigpu
+ def test_gather_basics_cuda(self):
+ self._test_gather_basics(lambda t: t.clone().cuda())
+
+ def _test_gather_stress(self, inputs, fn):
+ store = c10d.FileStore(self.file.name, self.world_size)
+ pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8))
+ work_handles = []
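+ # One output buffer per input; only the rank that is root for a given iteration fills it in.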
+ outputs = [
+ [
+ [fn(torch.Tensor([-1])) for _ in range(self.world_size)]
+ ] for _ in range(len(inputs))
+ ]
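+ # Rank r contributes i + r in iteration i, so the root gathers [i, i+1, ..., i+world_size-1].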
+ expected_outputs = [
+ [
+ [torch.Tensor([i + j]) for j in range(self.world_size)]
+ ] for i in range(len(inputs))
+ ]
+ for i in range(len(inputs)):
+ for root in range(self.world_size):
+ opts = c10d.GatherOptions()
+ opts.rootRank = root
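+ # Only the root passes output buffers; non-root ranks pass an empty list.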
+ if root == self.rank:
+ work = pg.gather(outputs[i], [fn(inputs[i])], opts)
+ else:
+ work = pg.gather([], [fn(inputs[i])], opts)
+ work_handles.append(work)
+
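+ # Handles were appended once per root within each iteration, so index i decomposes into (iteration, root).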
+ for i, work_handle in enumerate(work_handles):
+ work_handle.wait()
+ iteration = i // self.world_size
+ root = i % self.world_size
+ if root == self.rank:
+ self.assertEqual(
+ expected_outputs[iteration],
+ outputs[iteration],
+ "Mismatch in iteration %d for root %d" % (iteration, root)
+ )
+
+ def test_gather_stress(self):
+ inputs = [torch.Tensor([i + self.rank]) for i in range(1000)]
+ self._test_gather_stress(inputs, lambda t: t.clone())
+
+ @skip_if_not_multigpu
+ def test_gather_stress_cuda(self):
+ inputs = [torch.Tensor([i + self.rank]).cuda() for i in range(1000)]
+ self._test_gather_stress(inputs, lambda t: t.clone().cuda())
+
def test_allgather_checks(self):
store = c10d.FileStore(self.file.name, self.world_size)
pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
const int root;
const uint32_t tag;
- void run() override {
+ void gather(
+ std::vector<std::vector<at::Tensor>>& outputs,
+ std::vector<at::Tensor>& inputs) {
const auto scalarType = inputs[0].type().scalarType();
gloo::GatherOptions opts(context);
opts.setRoot(root);
}
}
}
+
+ void run() override {
+ gather(outputs, inputs);
+ }
};
+#ifdef USE_CUDA
+
+// Note: the current CUDA implementation makes the following assumptions:
+// - inputs.size() is 1
+// - outputs.size() is 1
+// - the nested output tensor list has world-size elements, i.e.
+//   outputs[0].size() equals the world size
+class AsyncGatherCUDAWork : public AsyncGatherWork {
+ public:
+ AsyncGatherCUDAWork(
+ const std::shared_ptr<gloo::Context>& context,
+ std::vector<std::vector<at::Tensor>>& outputs,
+ std::vector<at::Tensor>& inputs,
+ int root,
+ uint32_t tag)
+ : AsyncGatherWork(context, outputs, inputs, root, tag) {
+ initializeStreamsEvents(inputs, inputStreams, inputEvents);
+ initializeStreamsEvents(outputs, outputStreams, outputEvents);
+
+ // Kick off copy from CUDA tensors to pinned CPU tensors.
+ tmpInputs.reserve(inputs.size());
+ at::cuda::OptionalCUDAStreamGuard guard;
+ for (size_t i = 0; i < inputs.size(); i++) {
+ guard.reset_stream(inputStreams[i]);
+ tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
+ }
+
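+ // Allocate pinned CPU staging buffers for the gathered outputs.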
+ tmpOutputs.resize(outputs.size());
+ for (size_t i = 0; i < outputs.size(); i++) {
+ tmpOutputs[i].reserve(outputs[i].size());
+ for (size_t j = 0; j < outputs[i].size(); j++) {
+ tmpOutputs[i].push_back(pinnedLike(outputs[i][j]));
+ }
+ }
+ }
+
+ void run() override {
+ // Synchronize with copy operations.
+ at::cuda::OptionalCUDAGuard device_guard;
+ for (size_t i = 0; i < inputs.size(); i++) {
+ device_guard.set_index(inputs[i].get_device());
+ AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i]));
+ }
+
+ for (size_t i = 0; i < outputs.size(); i++) {
+ device_guard.set_index(outputs[i][0].get_device());
+ AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i]));
+ }
+
+ // Run gather on host side tensors.
+ gather(tmpOutputs, tmpInputs);
+
+ // Kick off copy back to the CUDA tensors.
+ at::cuda::OptionalCUDAStreamGuard stream_guard;
+ for (size_t i = 0; i < outputs.size(); i++) {
+ stream_guard.reset_stream(outputStreams[i]);
+ for (size_t j = 0; j < outputs[i].size(); j++) {
+ outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true);
+ }
+ outputEvents[i].record(outputStreams[i]);
+ }
+ }
+
+ void synchronize() override {
+ // Synchronize with the copy back to CUDA tensors.
+ at::cuda::OptionalCUDAGuard guard;
+ for (size_t i = 0; i < outputs.size(); i++) {
+ guard.set_index(static_cast<at::DeviceIndex>(outputs[i][0].get_device()));
+ outputEvents[i].block(at::cuda::getCurrentCUDAStream());
+ }
+ }
+
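+ // Pinned host staging tensors and the streams/events used for the asynchronous device copies.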
+ std::vector<at::Tensor> tmpInputs;
+ std::vector<at::cuda::CUDAStream> inputStreams;
+ std::vector<at::cuda::CUDAEvent> inputEvents;
+
+ std::vector<std::vector<at::Tensor>> tmpOutputs;
+ std::vector<at::cuda::CUDAStream> outputStreams;
+ std::vector<at::cuda::CUDAEvent> outputEvents;
+};
+
+#endif
+
} // namespace
std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::gather(
assertRootRank(invalidArgument, opts.rootRank, size_);
assertSingleElementInput(invalidArgument, inputs);
assertDense(invalidArgument, inputs);
- assertCPU(invalidArgument, inputs);
if (getRank() == opts.rootRank) {
if (outputs.size() != 1 ||
}
}
- auto work = std::make_shared<AsyncGatherWork>(
- contexts_[0], outputs, inputs, opts.rootRank, nextTag());
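+ // There is exactly one input tensor (asserted above); dispatch on its device type.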
+ const auto& device = inputs[0].device();
+ switch (device.type()) {
+ case at::kCPU:
+#ifdef USE_CUDA
+ case at::kCUDA:
+#endif
+ break;
+ default:
+ invalidArgument("unsupported device type");
+ }
+
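+ // Construct the CPU or CUDA flavor of the gather work and queue it for asynchronous execution.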
+ std::shared_ptr<AsyncGatherWork> work;
+ auto& context = contexts_[0];
+ if (device.type() == at::kCPU) {
+ work = std::make_shared<AsyncGatherWork>(
+ context, outputs, inputs, opts.rootRank, nextTag());
+#ifdef USE_CUDA
+ } else if (device.type() == at::kCUDA) {
+ work = std::make_shared<AsyncGatherCUDAWork>(
+ context, outputs, inputs, opts.rootRank, nextTag());
+#endif
+ } else {
+ throw std::runtime_error("Invalid backend");
+ }
enqueue(work);
return work;
}