add gloo allgather support on GPU (#14576)
authorJane Wang <janewang@fb.com>
Mon, 10 Dec 2018 22:29:41 +0000 (14:29 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 10 Dec 2018 22:32:54 +0000 (14:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14576

as titled

Reviewed By: pietern

Differential Revision: D13266063

fbshipit-source-id: e262f77d63724a7504a7112907bbfba49612fe75

test/test_c10d.py
torch/lib/c10d/ProcessGroupGloo.cpp

index 9f477c9..84a7ec2 100644 (file)
@@ -938,18 +938,18 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
         with self.assertRaisesRegex(ValueError, "invalid tensor size"):
             pg.allgather([([t1, t3] * (self.world_size))[:self.world_size]], [t1])
 
-    def test_allgather_basics(self):
+    def _test_allgather_basics(self, fn):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
 
         # Run with N input tensor per rank
         for n in [1, 2, 3]:
             input = [
-                torch.Tensor([n * self.rank + i]) for i in range(n)
+                fn(torch.Tensor([n * self.rank + i])) for i in range(n)
             ]
             output = [
                 [
-                    torch.Tensor([-1]) for _ in range(n * self.world_size)
+                    fn(torch.Tensor([-1])) for _ in range(n * self.world_size)
                 ] for _ in range(n)
             ]
             expected_output = [
@@ -961,6 +961,48 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             work.wait()
             self.assertEqual(expected_output, output)
 
+    def test_allgather_basics(self):
+        self._test_allgather_basics(lambda t: t.clone())
+
+    @skip_if_not_multigpu
+    def test_allgather_basics_cuda(self):
+        self._test_allgather_basics(lambda t: t.clone().cuda())
+
+    def _test_allgather_stress(self, inputs, fn):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8))
+        work_handles = []
+        outputs = [
+            [
+                [fn(torch.Tensor([-1])) for _ in range(self.world_size)]
+            ] for _ in range(len(inputs))
+        ]
+        expected_outputs = [
+            [
+                [torch.Tensor([i + j]) for j in range(self.world_size)]
+            ] for i in range(len(inputs))
+        ]
+        for i in range(len(inputs)):
+            work = pg.allgather(outputs[i], [fn(inputs[i])])
+            work_handles.append(work)
+
+        for i, work_handle in enumerate(work_handles):
+            work_handle.wait()
+            self.assertEqual(
+                expected_outputs[i],
+                outputs[i],
+                "Mismatch in iteration %d" % i
+            )
+
+    def test_allgather_stress(self):
+        inputs = [torch.Tensor([i + self.rank]) for i in range(1000)]
+        self._test_allgather_stress(inputs, lambda t: t.clone())
+
+    @skip_if_not_multigpu
+    def test_allgather_stress_cuda(self):
+        inputs = [torch.Tensor([i + self.rank]).cuda() for i in range(1000)]
+        self._test_allgather_stress(inputs, lambda t: t.clone().cuda())
+
     def test_reduce_checks(self):
         store = c10d.FileStore(self.file.name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
index c12c0d8..293b080 100644 (file)
@@ -128,26 +128,63 @@ at::Tensor pinnedLike(at::Tensor& tensor) {
 }
 
 // This function initializes a vector of CUDA streams, one for every
-// tensor in the input vector, and ensures that these streams are
+// tensor in the input tensor vector, and ensures that these streams are
 // synchronized with the current default streams. This is needed so
 // that new work on the new streams is serialized w.r.t. all operations
-// on the tensors in the input vector.
+// on the tensors.
 void initializeStreamsEvents(
-    std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& tensors,
+    std::vector<at::cuda::CUDAStream>& streams,
+    std::vector<at::cuda::CUDAEvent>& events) {
+  at::cuda::OptionalCUDAGuard guard;
+  streams.reserve(tensors.size());
+  events.resize(tensors.size());
+  for (size_t i = 0; i < tensors.size(); i++) {
+    guard.set_index(tensors[i].device().index());
+    // Record event on current stream
+    events[i].record(at::cuda::getCurrentCUDAStream());
+    // Get a non-default stream to execute asynchronous CUDA operations
+    // on for this device. This ensures that the default stream used
+    // by the caller is not occupied by c10d related operations.
+    streams.push_back(at::cuda::getStreamFromPool(
+        /* isHighPriority */ true, tensors[i].device().index()));
+    // Ensure the new stream is synchronized with the current stream.
+    events[i].block(streams[i]);
+  }
+}
+
+// This function initializes a vector of CUDA streams, one per device,
+// and ensures that these streams are synchronized with the current default
+// streams. It is assumed that the tensors in the nested tensor vectors are
+// on the same device.
+void initializeStreamsEvents(
+    std::vector<std::vector<at::Tensor>>& tensors,
     std::vector<at::cuda::CUDAStream>& streams,
     std::vector<at::cuda::CUDAEvent>& events) {
+  // Ensure that the tensors in the nested tensor vectors are on the same
+  // device.
+  for (size_t i = 0; i < tensors.size(); i++) {
+    auto device_id = tensors[i][0].device().index();
+    for (size_t j = 1; j < tensors[i].size(); j++) {
+      if (tensors[i][j].device().index() != device_id) {
+        throw std::runtime_error(
+            "tensors in the nested tensor vectors need to be on the same device");
+      }
+    }
+  }
+
   at::cuda::OptionalCUDAGuard guard;
-  streams.reserve(inputs.size());
-  events.resize(inputs.size());
-  for (size_t i = 0; i < inputs.size(); i++) {
-    guard.set_index(inputs[i].get_device());
+  streams.reserve(tensors.size());
+  events.resize(tensors.size());
+  for (size_t i = 0; i < tensors.size(); i++) {
+    guard.set_index(tensors[i][0].device().index());
     // Record event on current stream
     events[i].record(at::cuda::getCurrentCUDAStream());
     // Get a non-default stream to execute asynchronous CUDA operations
-    // on for this input. This ensures that the default stream used
+    // on for this output. This ensures that the default stream used
     // by the caller is not occupied by c10d related operations.
     streams.push_back(at::cuda::getStreamFromPool(
-        /* isHighPriority */ true, inputs[i].get_device()));
+        /* isHighPriority */ true, tensors[i][0].device().index()));
     // Ensure the new stream is synchronized with the current stream.
     events[i].block(streams[i]);
   }
@@ -379,7 +416,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork {
 
     // Synchronize with the copy back to CUDA tensors.
     for (size_t i = 0; i < inputs.size(); i++) {
-      guard.set_index(inputs[i].get_device());
+      guard.set_index(inputs[i].device().index());
       events[i].block(at::cuda::getCurrentCUDAStream());
     }
   }
@@ -510,7 +547,7 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork {
     // Synchronize with copy operations.
     at::cuda::OptionalCUDAGuard device_guard;
     for (size_t i = 0; i < inputs.size(); i++) {
-      device_guard.set_index(inputs[i].get_device());
+      device_guard.set_index(inputs[i].device().index());
       AT_CUDA_CHECK(cudaStreamSynchronize(streams[i]));
     }
 
@@ -534,7 +571,7 @@ class AsyncAllreduceCUDAWork : public AsyncAllreduceWork {
     // Synchronize with the copy back to CUDA tensors.
     at::cuda::OptionalCUDAGuard guard;
     for (size_t i = 0; i < inputs.size(); i++) {
-      guard.set_index(static_cast<at::DeviceIndex>(inputs[i].get_device()));
+      guard.set_index(inputs[i].device().index());
       events[i].block(at::cuda::getCurrentCUDAStream());
     }
   }
@@ -669,7 +706,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork {
     // Synchronize with copy operations.
     at::cuda::OptionalCUDAGuard device_guard;
     for (size_t i = 0; i < inputs.size(); i++) {
-      device_guard.set_index(inputs[i].get_device());
+      device_guard.set_index(inputs[i].device().index());
       AT_CUDA_CHECK(cudaStreamSynchronize(streams[i]));
     }
 
@@ -689,7 +726,7 @@ class AsyncReduceCUDAWork : public AsyncReduceWork {
     // Synchronize with the copy back to CUDA tensors.
     at::cuda::OptionalCUDAGuard guard;
     for (size_t i = 0; i < inputs.size(); i++) {
-      guard.set_index(static_cast<at::DeviceIndex>(inputs[i].get_device()));
+      guard.set_index(inputs[i].device().index());
       events[i].block(at::cuda::getCurrentCUDAStream());
     }
   }
@@ -769,7 +806,9 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
   std::vector<at::Tensor> inputs;
   const uint32_t tag;
 
-  void run() override {
+  void allgather(
+      std::vector<std::vector<at::Tensor>>& outputs,
+      std::vector<at::Tensor>& inputs) {
     const auto& scalarType = inputs[0].scalar_type();
     gloo::AllgatherOptions opts(context);
     opts.setTag(tag);
@@ -792,10 +831,95 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
       }
     }
   }
+
+  void run() override {
+    allgather(outputs, inputs);
+  }
+};
+
+#ifdef USE_CUDA
+
+// Note: current CUDA implementation holds the assumption that the
+// tensors in the nested output tensor vectors are on the same device.
+class AsyncAllgatherCUDAWork : public AsyncAllgatherWork {
+ public:
+  AsyncAllgatherCUDAWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<std::vector<at::Tensor>>& outputs,
+      std::vector<at::Tensor>& inputs,
+      uint32_t tag)
+      : AsyncAllgatherWork(context, outputs, inputs, tag) {
+    initializeStreamsEvents(inputs, inputStreams, inputEvents);
+    initializeStreamsEvents(outputs, outputStreams, outputEvents);
+
+    // Kick off copy from CUDA tensors to pinned CPU tensors.
+    tmpInputs.reserve(inputs.size());
+    at::cuda::OptionalCUDAStreamGuard guard;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      guard.reset_stream(inputStreams[i]);
+      tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
+    }
+
+    tmpOutputs.resize(outputs.size());
+    for (size_t i = 0; i < outputs.size(); i++) {
+      tmpOutputs[i].reserve(outputs[i].size());
+      for (size_t j = 0; j < outputs[i].size(); j++) {
+        tmpOutputs[i].push_back(pinnedLike(outputs[i][j]));
+      }
+    }
+  }
+
+  void run() override {
+    // Synchronize with copy operations.
+    at::cuda::OptionalCUDAGuard device_guard;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      device_guard.set_index(inputs[i].device().index());
+      AT_CUDA_CHECK(cudaStreamSynchronize(inputStreams[i]));
+    }
+
+    for (size_t i = 0; i < outputs.size(); i++) {
+      device_guard.set_index(outputs[i][0].device().index());
+      AT_CUDA_CHECK(cudaStreamSynchronize(outputStreams[i]));
+    }
+
+    // Run allgather on host side tensors.
+    allgather(tmpOutputs, tmpInputs);
+
+    // Kick off copy back to the CUDA tensors.
+    at::cuda::OptionalCUDAStreamGuard stream_guard;
+    for (size_t i = 0; i < outputs.size(); i++) {
+      stream_guard.reset_stream(outputStreams[i]);
+      for (size_t j = 0; j < outputs[i].size(); j++) {
+        outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true);
+      }
+      outputEvents[i].record(outputStreams[i]);
+    }
+  }
+
+  void synchronize() override {
+    // Synchronize with the copy back to CUDA tensors.
+    at::cuda::OptionalCUDAGuard guard;
+    for (size_t i = 0; i < outputs.size(); i++) {
+      guard.set_index(outputs[i][0].device().index());
+      outputEvents[i].block(at::cuda::getCurrentCUDAStream());
+    }
+  }
+
+  std::vector<at::Tensor> tmpInputs;
+  std::vector<at::cuda::CUDAStream> inputStreams;
+  std::vector<at::cuda::CUDAEvent> inputEvents;
+
+  std::vector<std::vector<at::Tensor>> tmpOutputs;
+  std::vector<at::cuda::CUDAStream> outputStreams;
+  std::vector<at::cuda::CUDAEvent> outputEvents;
 };
 
+#endif
+
 } // namespace
 
+// Note: current CUDA implementation holds the assumption that the
+// tensors in the nested output tensor vectors are on the same device.
 std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather(
     std::vector<std::vector<at::Tensor>>& outputs,
     std::vector<at::Tensor>& inputs,
@@ -825,7 +949,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather(
   }
 
   assertDense(invalidArgument, inputs);
-  assertCPU(invalidArgument, inputs);
 
   // Expect all input/output tensors to have the same type and sizes
   const auto& type = inputs[0].type();
@@ -835,8 +958,30 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather(
     assertTypeAndSizesMatch(invalidArgument, outputs[i], type, sizes);
   }
 
-  auto work = std::make_shared<AsyncAllgatherWork>(
-      contexts_[0], outputs, inputs, nextTag());
+  const auto& device = inputs[0].device();
+  switch (device.type()) {
+    case at::kCPU:
+#ifdef USE_CUDA
+    case at::kCUDA:
+#endif
+      break;
+    default:
+      invalidArgument("unsupported device type");
+  }
+
+  std::shared_ptr<AsyncAllgatherWork> work;
+  auto& context = contexts_[0];
+  if (device.type() == at::kCPU) {
+    work = std::make_shared<AsyncAllgatherWork>(
+        context, outputs, inputs, nextTag());
+#ifdef USE_CUDA
+  } else if (device.type() == at::kCUDA) {
+    work = std::make_shared<AsyncAllgatherCUDAWork>(
+        context, outputs, inputs, nextTag());
+#endif
+  } else {
+    throw std::runtime_error("Invalid backend");
+  }
   enqueue(work);
   return work;
 }