Propagate ProcessGroup timeout to Store (#16571)
authorShen Li <shenli@fb.com>
Tue, 9 Apr 2019 19:06:04 +0000 (12:06 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 19:36:28 +0000 (12:36 -0700)
Summary:
closes #16520

Hi pietern, I am not sure if this is the expected way to pass timeout to `Store`, could you please help take a look? Thanks!

Questions:
1. How do I write tests for this? I wanted to do something like `test_barrier_timeout_global`, but it seems I need to set the pg's timeout larger than the `Store`'s default timeout (3 min) to see a difference, which is too long for a unit test. And I do not want to change the `Store`'s default timeout either. Any suggestion?
2. Should I also propagate timeout configuration down to `PrefixStore` in `_new_process_group_helper`?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16571

Differential Revision: D13954527

Pulled By: mrshenli

fbshipit-source-id: 77f2653903f24255207233eb298f7c0321119a87

test/test_c10d.py
torch/distributed/distributed_c10d.py
torch/lib/c10d/Utils.hpp

index f18cf51..6b88fa9 100644 (file)
@@ -5,6 +5,7 @@ import os
 import random
 import sys
 import tempfile
+import threading
 import time
 import unittest
 from datetime import timedelta
@@ -501,6 +502,52 @@ class MultiProcessTestCase(TestCase):
         return self.rank == 0
 
 
+class TimeoutTest(TestCase):
+    def _test_store_timeout(self, backend, init_method, c2p):
+        c10d.distributed_c10d.init_process_group(
+            backend=backend, init_method=init_method, world_size=1, rank=0,
+            timeout=timedelta(seconds=1))
+        default_store = c10d.distributed_c10d._get_default_store()
+        tik = time.time()
+        with self.assertRaisesRegex(RuntimeError, "Timeout"):
+            default_store.get("nonexistent key")
+        tok = time.time()
+        c10d.destroy_process_group()
+        c2p.append(tok - tik)
+
+    def _init_methods(self):
+        f = tempfile.NamedTemporaryFile(delete=False)
+        yield "file://%s" % f.name
+        f.close()
+        yield "tcp://127.0.0.1:%d" % common.find_free_port()
+
+    def _test_default_store_timeout(self, backend):
+        for init_method in self._init_methods():
+            c2p = []
+            t = threading.Thread(
+                target=self._test_store_timeout,
+                args=(backend, init_method, c2p))
+            t.daemon = True
+            t.start()
+            t.join(5)
+
+            self.assertEqual(1, len(c2p))
+            # waiting time should be 1s, use 3s to rule out false alarm
+            self.assertGreater(3, c2p[0])
+
+    @retry_on_address_already_in_use_error
+    def test_default_store_timeout_nccl(self):
+        # TODO remove this hack
+        if not hasattr(c10d, "ProcessGroupNCCL"):
+            raise unittest.SkipTest("C10D is not built with NCCL process group,"
+                                    " skipping test")
+        self._test_default_store_timeout('nccl')
+
+    @retry_on_address_already_in_use_error
+    def test_default_store_timeout_gloo(self):
+        self._test_default_store_timeout('gloo')
+
+
 class ProcessGroupGlooTest(MultiProcessTestCase):
     def opts(self, threads=2):
         opts = c10d.ProcessGroupGloo.Options()
index ddd3554..6ff0e66 100644 (file)
@@ -276,6 +276,18 @@ def _get_default_group():
     return _default_pg
 
 
+def _get_default_store():
+    """
+    Getting the default store created by init_process_group
+
+    """
+    if not is_initialized():
+        raise RuntimeError("Default process group has not been initialized, "
+                           "please make sure to call init_process_group.")
+    _, default_store = _pg_map[_default_pg]
+    return default_store
+
+
 def get_backend(group=group.WORLD):
     """
     Returns the backend of the given process group.
@@ -378,6 +390,8 @@ def init_process_group(backend,
 
         if store is None:
             store, rank, world_size = next(rendezvous(url))
+            store.set_timeout(timeout)
+
         if backend == Backend.GLOO:
             _default_pg = ProcessGroupGloo(
                 store,
index b9b3a73..0947028 100644 (file)
@@ -304,6 +304,8 @@ while (true) {                                                \
   if (!(success_cond)) {                                      \
     if (errno == EINTR) {                                     \
       continue;                                               \
+    } else if (errno == EAGAIN || errno == EWOULDBLOCK) {     \
+      throw std::runtime_error("Socket Timeout");             \
     } else {                                                  \
       throw std::system_error(errno, std::system_category()); \
     }                                                         \