libshm retry on EINTR (#15964)
authorSsnL <tongzhou.wang.1994@gmail.com>
Mon, 14 Jan 2019 12:24:50 +0000 (04:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 12:30:40 +0000 (04:30 -0800)
Summary:
fixes https://github.com/pytorch/pytorch/issues/14314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15964

Differential Revision: D13639034

Pulled By: soumith

fbshipit-source-id: 44592762aa46982e5d3616d55b5666a2c2ce9105

torch/lib/libshm/core.cpp
torch/lib/libshm/err.h
torch/lib/libshm/manager.cpp
torch/lib/libshm/socket.h

index b933cc3..6a33ff2 100644 (file)
@@ -24,30 +24,30 @@ AllocInfo get_alloc_info(const char* filename) {
 
 void start_manager() {
   int pipe_ends[2];
-  SYSCHECK(pipe(pipe_ends));
+  SYSCHECK_ERR_RETURN_NEG1(pipe(pipe_ends));
 
   pid_t pid;
-  SYSCHECK(pid = fork());
+  SYSCHECK_ERR_RETURN_NEG1(pid = fork());
   if (!pid) {
-    close(pipe_ends[0]);
-    dup2(pipe_ends[1], 1); // Replace stdout
-    close(pipe_ends[1]);
+    SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
+    SYSCHECK_ERR_RETURN_NEG1(dup2(pipe_ends[1], 1)); // Replace stdout
+    SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
     execl(manager_executable_path.c_str(), "torch_shm_manager", NULL);
     exit(1);
   }
-  SYSCHECK(close(pipe_ends[1]));
+  SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
 
   ssize_t bytes_read;
   char buffer[1000];
   std::string handle;
   for (;;) {
-    SYSCHECK(bytes_read = read(pipe_ends[0], buffer, sizeof(buffer)));
+    SYSCHECK_ERR_RETURN_NEG1(bytes_read = read(pipe_ends[0], buffer, sizeof(buffer)));
     handle.append(buffer, bytes_read);
     if (bytes_read == 0 || handle[handle.length() - 1] == '\n') {
       break;
     }
   }
-  SYSCHECK(close(pipe_ends[0]));
+  SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
   if (handle.length() == 0) {
     std::string msg("error executing torch_shm_manager at \"");
     msg += manager_executable_path;
index e2244b5..8f4a28f 100644 (file)
@@ -1,5 +1,25 @@
 #pragma once
 
 #include <system_error>
+#include <cerrno>
 
-#define SYSCHECK(call) { auto __ret = (call); if (__ret < 0) { throw std::system_error(errno, std::system_category()); } }
+// `errno` is only meaningful when it fails. E.g., a  successful `fork()` sets
+// `errno` to `EINVAL` in child process on some macos
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occured.
+//
+// All functions used in `libshm` (so far) indicate error by returning `-1`. If
+// you want to use a function with a different error reporting mechanism, you
+// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr)                        \
+while (true) {                                                \
+  if ((expr) == -1) {                                         \
+    if (errno == EINTR) {                                     \
+      continue;                                               \
+    } else {                                                  \
+      throw std::system_error(errno, std::system_category()); \
+    }                                                         \
+  } else {                                                    \
+    break;                                                    \
+  }                                                           \
+}
index c210cd9..c315979 100644 (file)
@@ -111,7 +111,7 @@ int main(int argc, char *argv[]) {
     int nevents;
     if (client_sessions.size() == 0)
       timeout = SHUTDOWN_TIMEOUT;
-    SYSCHECK(nevents = poll(pollfds.data(), pollfds.size(), timeout));
+    SYSCHECK_ERR_RETURN_NEG1(nevents = poll(pollfds.data(), pollfds.size(), timeout));
     timeout = -1;
     if (nevents == 0 && client_sessions.size() == 0)
       break;
index e43db56..3d1d285 100644 (file)
@@ -20,7 +20,7 @@ public:
 
 protected:
   Socket() {
-    SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
+    SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
   }
   Socket(const Socket& other) = delete;
   Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; };
@@ -50,9 +50,9 @@ protected:
     pfd.fd = socket_fd;
     pfd.events = POLLIN;
     while (bytes_received < num_bytes) {
-      SYSCHECK(poll(&pfd, 1, 1000));
+      SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
       if (pfd.revents & POLLIN) {
-        SYSCHECK(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
+        SYSCHECK_ERR_RETURN_NEG1(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
         if (step_received == 0)
           throw std::runtime_error("Other end has closed the connection");
         bytes_received += step_received;
@@ -70,7 +70,7 @@ protected:
     size_t bytes_sent = 0;
     ssize_t step_sent;
     while (bytes_sent < num_bytes) {
-      SYSCHECK(step_sent = ::write(socket_fd, buffer, num_bytes));
+      SYSCHECK_ERR_RETURN_NEG1(step_sent = ::write(socket_fd, buffer, num_bytes));
       bytes_sent += step_sent;
       buffer += step_sent;
     }
@@ -103,10 +103,10 @@ public:
     try {
       struct sockaddr_un address = prepare_address(path.c_str());
       size_t len = address_length(address);
-      SYSCHECK(bind(socket_fd, (struct sockaddr *)&address, len));
-      SYSCHECK(listen(socket_fd, 10));
+      SYSCHECK_ERR_RETURN_NEG1(bind(socket_fd, (struct sockaddr *)&address, len));
+      SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
     } catch(std::exception &e) {
-      close(socket_fd);
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
       throw;
     }
   }
@@ -119,7 +119,7 @@ public:
     int client_fd;
     struct sockaddr_un addr;
     socklen_t addr_len = sizeof(addr);
-    SYSCHECK(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
+    SYSCHECK_ERR_RETURN_NEG1(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
     return ManagerSocket(client_fd);
   }
 
@@ -132,9 +132,9 @@ public:
     try {
       struct sockaddr_un address = prepare_address(path.c_str());
       size_t len = address_length(address);
-      SYSCHECK(connect(socket_fd, (struct sockaddr *)&address, len));
+      SYSCHECK_ERR_RETURN_NEG1(connect(socket_fd, (struct sockaddr *)&address, len));
     } catch(std::exception &e) {
-      close(socket_fd);
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
       throw;
     }
   }