Fix c10d checking errno unconditionally (#15986)
authorSsnL <tongzhou.wang.1994@gmail.com>
Mon, 14 Jan 2019 23:59:29 +0000 (15:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 00:02:05 +0000 (16:02 -0800)
Summary:
In #15964, I learned that `errno` is only meaningful if the function call fails. E.g., on some macos, a successful `fork()` sets `errno` to `EINVAL` in child process. This commit changes the `SYSCALL` macro so error checking is only done when an error happens. This means checking whether `rv == -1` for most calls, but is checking `rv == nullptr` for `inet_ntop`.

Now `SYSCALL` accepts a second argument `success_cond`, which should be an expression returning whether the call succeeded. `SYSCHECK_ERR_RETURN_NEG1` is the shorthand for checking if rv is `-1`.

Any suggestion on better macro names is welcomed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15986

Reviewed By: janewangfb

Differential Revision: D13661790

Pulled By: pietern

fbshipit-source-id: 9551b14b9f88805454a7bfb8e4d39e0f3aed8131

torch/lib/c10d/TCPStore.cpp
torch/lib/c10d/Utils.cpp
torch/lib/c10d/Utils.hpp

index d1ecf06..523ef3f 100644 (file)
@@ -67,7 +67,7 @@ void TCPStoreDaemon::run() {
       fds[i].revents = 0;
     }
 
-    SYSCHECK(::poll(fds.data(), fds.size(), -1));
+    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
 
     // TCPStore's listening socket has an event and it should now be able to
     // accept new connections.
@@ -351,7 +351,7 @@ void TCPStore::wait(
   if (timeout != kNoTimeout) {
     struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
                                 .tv_usec = (timeout.count() % 1000) * 1000};
-    SYSCHECK(::setsockopt(
+    SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
         storeSocket_,
         SOL_SOCKET,
         SO_RCVTIMEO,
index 48b1f74..4cecc59 100644 (file)
@@ -26,14 +26,14 @@ constexpr int LISTEN_QUEUE_SIZE = 64;
 void setSocketNoDelay(int socket) {
   int flag = 1;
   socklen_t optlen = sizeof(flag);
-  SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
+  SYSCHECK_ERR_RETURN_NEG1(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
 }
 
 PortType getSocketPort(int fd) {
   PortType listenPort;
   struct ::sockaddr_storage addrStorage;
   socklen_t addrLen = sizeof(addrStorage);
-  SYSCHECK(getsockname(
+  SYSCHECK_ERR_RETURN_NEG1(getsockname(
       fd, reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen));
 
   if (addrStorage.ss_family == AF_INET) {
@@ -58,11 +58,11 @@ std::string sockaddrToString(struct ::sockaddr* addr) {
   char address[INET6_ADDRSTRLEN + 1];
   if (addr->sa_family == AF_INET) {
     struct ::sockaddr_in* s = reinterpret_cast<struct ::sockaddr_in*>(addr);
-    SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN))
+    SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN), __output != nullptr)
     address[INET_ADDRSTRLEN] = '\0';
   } else if (addr->sa_family == AF_INET6) {
     struct ::sockaddr_in6* s = reinterpret_cast<struct ::sockaddr_in6*>(addr);
-    SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN))
+    SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN), __output != nullptr)
     address[INET6_ADDRSTRLEN] = '\0';
   } else {
     throw std::runtime_error("unsupported protocol");
@@ -94,18 +94,18 @@ std::pair<int, PortType> listen(PortType port) {
   int socket;
   while (true) {
     try {
-      SYSCHECK(
+      SYSCHECK_ERR_RETURN_NEG1(
           socket = ::socket(
               nextAddr->ai_family,
               nextAddr->ai_socktype,
               nextAddr->ai_protocol))
 
       int optval = 1;
-      SYSCHECK(
+      SYSCHECK_ERR_RETURN_NEG1(
           ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
 
-      SYSCHECK(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
-      SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE))
+      SYSCHECK_ERR_RETURN_NEG1(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
+      SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE))
       break;
 
     } catch (const std::system_error& e) {
@@ -155,7 +155,7 @@ int connect(
   bool anyRefused = false;
   while (true) {
     try {
-      SYSCHECK(
+      SYSCHECK_ERR_RETURN_NEG1(
           socket = ::socket(
               nextAddr->ai_family,
               nextAddr->ai_socktype,
@@ -164,7 +164,7 @@ int connect(
       ResourceGuard socketGuard([socket]() { ::close(socket); });
 
       // We need to connect in non-blocking mode, so we can use a timeout
-      SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK));
+      SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK));
 
       int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
 
@@ -198,8 +198,8 @@ int connect(
 
       // Disable non-blocking mode
       int flags;
-      SYSCHECK(flags = ::fcntl(socket, F_GETFL));
-      SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
+      SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL));
+      SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
       socketGuard.release();
       break;
 
@@ -255,12 +255,12 @@ std::tuple<int, std::string> accept(
   }
 
   int socket;
-  SYSCHECK(socket = ::accept(listenSocket, NULL, NULL))
+  SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL))
 
   // Get address of the connecting process
   struct ::sockaddr_storage addr;
   socklen_t addrLen = sizeof(addr);
-  SYSCHECK(::getpeername(
+  SYSCHECK_ERR_RETURN_NEG1(::getpeername(
       socket, reinterpret_cast<struct ::sockaddr*>(&addr), &addrLen))
 
   setSocketNoDelay(socket);
index 1756ebf..59c1bfa 100644 (file)
@@ -289,16 +289,32 @@ using RankType = uint32_t;
 using PortType = uint16_t;
 using SizeType = uint64_t;
 
-#define SYSCHECK(expr)                                        \
-  {                                                           \
-    do {                                                      \
-      errno = 0;                                              \
-      auto ___output = (expr);                                \
-      (void)___output;                                        \
-    } while (errno == EINTR);                                 \
-    if (errno != 0)                                           \
+// `errno` is only meaningful when it fails. E.g., a  successful `fork()` sets
+// `errno` to `EINVAL` in child process on some macos
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occured.
+//
+// `success_cond` is an expression used to check if an error has happend. So for
+// `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output
+// is stored in variable `__output` and may be used in `success_cond`.
+#define SYSCHECK(expr, success_cond)                          \
+while (true) {                                                \
+  auto __output = (expr);                                     \
+  (void) __output;                                            \
+  if (!(success_cond)) {                                      \
+    if (errno == EINTR) {                                     \
+      continue;                                               \
+    } else {                                                  \
       throw std::system_error(errno, std::system_category()); \
-  }
+    }                                                         \
+  } else {                                                    \
+    break;                                                    \
+  }                                                           \
+}
+
+// Most functions indicate error by returning `-1`. This is a helper macro for
+// this common case with `SYSCHECK`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
 
 // Helper resource guard class
 class ResourceGuard {
@@ -350,7 +366,7 @@ void sendBytes(
 
   while (bytesToSend > 0) {
     ssize_t bytesSent;
-    SYSCHECK(bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
+    SYSCHECK_ERR_RETURN_NEG1(bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
     if (bytesSent == 0) {
       throw std::system_error(ECONNRESET, std::system_category());
     }
@@ -372,7 +388,7 @@ void recvBytes(int socket, T* buffer, size_t length) {
 
   while (bytesToReceive > 0) {
     ssize_t bytesReceived;
-    SYSCHECK(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0))
+    SYSCHECK_ERR_RETURN_NEG1(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0))
     if (bytesReceived == 0) {
       throw std::system_error(ECONNRESET, std::system_category());
     }