void start_manager() {
int pipe_ends[2];
- SYSCHECK(pipe(pipe_ends));
+ SYSCHECK_ERR_RETURN_NEG1(pipe(pipe_ends));
pid_t pid;
- SYSCHECK(pid = fork());
+ SYSCHECK_ERR_RETURN_NEG1(pid = fork());
if (!pid) {
- close(pipe_ends[0]);
- dup2(pipe_ends[1], 1); // Replace stdout
- close(pipe_ends[1]);
+ SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
+ SYSCHECK_ERR_RETURN_NEG1(dup2(pipe_ends[1], 1)); // Replace stdout
+ SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
execl(manager_executable_path.c_str(), "torch_shm_manager", NULL);
exit(1);
}
- SYSCHECK(close(pipe_ends[1]));
+ SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
ssize_t bytes_read;
char buffer[1000];
std::string handle;
for (;;) {
- SYSCHECK(bytes_read = read(pipe_ends[0], buffer, sizeof(buffer)));
+ SYSCHECK_ERR_RETURN_NEG1(bytes_read = read(pipe_ends[0], buffer, sizeof(buffer)));
handle.append(buffer, bytes_read);
if (bytes_read == 0 || handle[handle.length() - 1] == '\n') {
break;
}
}
- SYSCHECK(close(pipe_ends[0]));
+ SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
if (handle.length() == 0) {
std::string msg("error executing torch_shm_manager at \"");
msg += manager_executable_path;
#pragma once
#include <system_error>
+#include <cerrno>
-#define SYSCHECK(call) { auto __ret = (call); if (__ret < 0) { throw std::system_error(errno, std::system_category()); } }
+// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets
+// `errno` to `EINVAL` in child process on some macos
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occured.
+//
+// All functions used in `libshm` (so far) indicate error by returning `-1`. If
+// you want to use a function with a different error reporting mechanism, you
+// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr) \
+while (true) { \
+ if ((expr) == -1) { \
+ if (errno == EINTR) { \
+ continue; \
+ } else { \
+ throw std::system_error(errno, std::system_category()); \
+ } \
+ } else { \
+ break; \
+ } \
+}
protected:
Socket() {
- SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
+ SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
}
Socket(const Socket& other) = delete;
Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; };
pfd.fd = socket_fd;
pfd.events = POLLIN;
while (bytes_received < num_bytes) {
- SYSCHECK(poll(&pfd, 1, 1000));
+ SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
if (pfd.revents & POLLIN) {
- SYSCHECK(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
+ SYSCHECK_ERR_RETURN_NEG1(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
if (step_received == 0)
throw std::runtime_error("Other end has closed the connection");
bytes_received += step_received;
size_t bytes_sent = 0;
ssize_t step_sent;
while (bytes_sent < num_bytes) {
- SYSCHECK(step_sent = ::write(socket_fd, buffer, num_bytes));
+ SYSCHECK_ERR_RETURN_NEG1(step_sent = ::write(socket_fd, buffer, num_bytes));
bytes_sent += step_sent;
buffer += step_sent;
}
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
- SYSCHECK(bind(socket_fd, (struct sockaddr *)&address, len));
- SYSCHECK(listen(socket_fd, 10));
+ SYSCHECK_ERR_RETURN_NEG1(bind(socket_fd, (struct sockaddr *)&address, len));
+ SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
} catch(std::exception &e) {
- close(socket_fd);
+ SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
}
int client_fd;
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
- SYSCHECK(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
+ SYSCHECK_ERR_RETURN_NEG1(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
return ManagerSocket(client_fd);
}
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
- SYSCHECK(connect(socket_fd, (struct sockaddr *)&address, len));
+ SYSCHECK_ERR_RETURN_NEG1(connect(socket_fd, (struct sockaddr *)&address, len));
} catch(std::exception &e) {
- close(socket_fd);
+ SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
}