From 774705ba0509878cb2e8141b52164d090bfef5d7 Mon Sep 17 00:00:00 2001 From: SsnL Date: Mon, 14 Jan 2019 15:59:29 -0800 Subject: [PATCH] Fix c10d checking errno unconditionally (#15986) Summary: In #15964, I learned that `errno` is only meaningful if the function call fails. E.g., on some macos, a successful `fork()` sets `errno` to `EINVAL` in child process. This commit changes the `SYSCALL` macro so error checking is only done when an error happens. This means checking whether `rv == -1` for most calls, but is checking `rv == nullptr` for `inet_ntop`. Now `SYSCALL` accepts a second argument `success_cond`, which should be an expression returning whether the call succeeded. `SYSCHECK_ERR_RETURN_NEG1` is the shorthand for checking if rv is `-1`. Any suggestion on better macro names is welcomed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15986 Reviewed By: janewangfb Differential Revision: D13661790 Pulled By: pietern fbshipit-source-id: 9551b14b9f88805454a7bfb8e4d39e0f3aed8131 --- torch/lib/c10d/TCPStore.cpp | 4 ++-- torch/lib/c10d/Utils.cpp | 28 ++++++++++++++-------------- torch/lib/c10d/Utils.hpp | 38 +++++++++++++++++++++++++++----------- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index d1ecf06..523ef3f 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -67,7 +67,7 @@ void TCPStoreDaemon::run() { fds[i].revents = 0; } - SYSCHECK(::poll(fds.data(), fds.size(), -1)); + SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); // TCPStore's listening socket has an event and it should now be able to // accept new connections. 
@@ -351,7 +351,7 @@ void TCPStore::wait( if (timeout != kNoTimeout) { struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000, .tv_usec = (timeout.count() % 1000) * 1000}; - SYSCHECK(::setsockopt( + SYSCHECK_ERR_RETURN_NEG1(::setsockopt( storeSocket_, SOL_SOCKET, SO_RCVTIMEO, diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index 48b1f74..4cecc59 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -26,14 +26,14 @@ constexpr int LISTEN_QUEUE_SIZE = 64; void setSocketNoDelay(int socket) { int flag = 1; socklen_t optlen = sizeof(flag); - SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen)); + SYSCHECK_ERR_RETURN_NEG1(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen)); } PortType getSocketPort(int fd) { PortType listenPort; struct ::sockaddr_storage addrStorage; socklen_t addrLen = sizeof(addrStorage); - SYSCHECK(getsockname( + SYSCHECK_ERR_RETURN_NEG1(getsockname( fd, reinterpret_cast(&addrStorage), &addrLen)); if (addrStorage.ss_family == AF_INET) { @@ -58,11 +58,11 @@ std::string sockaddrToString(struct ::sockaddr* addr) { char address[INET6_ADDRSTRLEN + 1]; if (addr->sa_family == AF_INET) { struct ::sockaddr_in* s = reinterpret_cast(addr); - SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN)) + SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN), __output != nullptr) address[INET_ADDRSTRLEN] = '\0'; } else if (addr->sa_family == AF_INET6) { struct ::sockaddr_in6* s = reinterpret_cast(addr); - SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN)) + SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN), __output != nullptr) address[INET6_ADDRSTRLEN] = '\0'; } else { throw std::runtime_error("unsupported protocol"); @@ -94,18 +94,18 @@ std::pair listen(PortType port) { int socket; while (true) { try { - SYSCHECK( + SYSCHECK_ERR_RETURN_NEG1( socket = ::socket( nextAddr->ai_family, nextAddr->ai_socktype, 
nextAddr->ai_protocol)) int optval = 1; - SYSCHECK( + SYSCHECK_ERR_RETURN_NEG1( ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int))) - SYSCHECK(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen)) - SYSCHECK(::listen(socket, LISTEN_QUEUE_SIZE)) + SYSCHECK_ERR_RETURN_NEG1(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen)) + SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE)) break; } catch (const std::system_error& e) { @@ -155,7 +155,7 @@ int connect( bool anyRefused = false; while (true) { try { - SYSCHECK( + SYSCHECK_ERR_RETURN_NEG1( socket = ::socket( nextAddr->ai_family, nextAddr->ai_socktype, @@ -164,7 +164,7 @@ int connect( ResourceGuard socketGuard([socket]() { ::close(socket); }); // We need to connect in non-blocking mode, so we can use a timeout - SYSCHECK(::fcntl(socket, F_SETFL, O_NONBLOCK)); + SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK)); int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen); @@ -198,8 +198,8 @@ int connect( // Disable non-blocking mode int flags; - SYSCHECK(flags = ::fcntl(socket, F_GETFL)); - SYSCHECK(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK))); + SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL)); + SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK))); socketGuard.release(); break; @@ -255,12 +255,12 @@ std::tuple accept( } int socket; - SYSCHECK(socket = ::accept(listenSocket, NULL, NULL)) + SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL)) // Get address of the connecting process struct ::sockaddr_storage addr; socklen_t addrLen = sizeof(addr); - SYSCHECK(::getpeername( + SYSCHECK_ERR_RETURN_NEG1(::getpeername( socket, reinterpret_cast(&addr), &addrLen)) setSocketNoDelay(socket); diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index 1756ebf..59c1bfa 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -289,16 +289,32 @@ using RankType = uint32_t; using PortType = uint16_t; 
using SizeType = uint64_t; -#define SYSCHECK(expr) \ - { \ - do { \ - errno = 0; \ - auto ___output = (expr); \ - (void)___output; \ - } while (errno == EINTR); \ - if (errno != 0) \ +// `errno` is only meaningful when the call fails. E.g., a successful `fork()` sets +// `errno` to `EINVAL` in child process on some macos +// (https://stackoverflow.com/a/20295079), and thus `errno` should really only +// be inspected if an error occurred. +// +// `success_cond` is an expression used to check if an error has happened. So for +// `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output +// is stored in variable `__output` and may be used in `success_cond`. +#define SYSCHECK(expr, success_cond) \ +while (true) { \ + auto __output = (expr); \ + (void) __output; \ + if (!(success_cond)) { \ + if (errno == EINTR) { \ + continue; \ + } else { \ throw std::system_error(errno, std::system_category()); \ - } + } \ + } else { \ + break; \ + } \ +} + +// Most functions indicate error by returning `-1`. This is a helper macro for +// this common case with `SYSCHECK`. +#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) // Helper resource guard class class ResourceGuard { @@ -350,7 +366,7 @@ void sendBytes( while (bytesToSend > 0) { ssize_t bytesSent; - SYSCHECK(bytesSent = ::send(socket, currentBytes, bytesToSend, flags)) + SYSCHECK_ERR_RETURN_NEG1(bytesSent = ::send(socket, currentBytes, bytesToSend, flags)) if (bytesSent == 0) { throw std::system_error(ECONNRESET, std::system_category()); } @@ -372,7 +388,7 @@ void recvBytes(int socket, T* buffer, size_t length) { while (bytesToReceive > 0) { ssize_t bytesReceived; - SYSCHECK(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0)) + SYSCHECK_ERR_RETURN_NEG1(bytesReceived = ::recv(socket, currentBytes, bytesToReceive, 0)) if (bytesReceived == 0) { throw std::system_error(ECONNRESET, std::system_category()); } -- 2.7.4