import contextlib
import os
import socket
+import errno
-_DEFAULT_SOCK_OPTION = socket.SO_REUSEADDR if os.name == 'nt' else socket.SO_REUSEPORT
+_DEFAULT_SOCK_OPTIONS = (socket.SO_REUSEADDR,
+ socket.SO_REUSEPORT) if os.name != 'nt' else (
+ socket.SO_REUSEADDR,)
+_UNRECOVERABLE_ERRNOS = (errno.EADDRINUSE, errno.ENOSR)
def get_socket(bind_address='localhost',
+ port=0,
listen=True,
- sock_options=(_DEFAULT_SOCK_OPTION,)):
- """Opens a socket bound to an arbitrary port.
+ sock_options=_DEFAULT_SOCK_OPTIONS):
+ """Opens a socket.
Useful for reserving a port for a system-under-test.
Args:
bind_address: The host to which to bind.
+ port: The port to which to bind.
listen: A boolean value indicating whether or not to listen on the socket.
sock_options: A sequence of socket options to apply to the socket.
sock = socket.socket(address_family, socket.SOCK_STREAM)
for sock_option in _sock_options:
sock.setsockopt(socket.SOL_SOCKET, sock_option, 1)
- sock.bind((bind_address, 0))
+ sock.bind((bind_address, port))
if listen:
sock.listen(1)
return bind_address, sock.getsockname()[1], sock
- except socket.error:
+ except OSError as os_error:
+ sock.close()
+ if os_error.errno in _UNRECOVERABLE_ERRNOS:
+ raise
+ else:
+ continue
+ # For PY2, socket.error is a child class of IOError; for PY3, it is
+ # pointing to OSError. We need this catch to make it 2/3 agnostic.
+ except socket.error: # pylint: disable=duplicate-except
sock.close()
continue
raise RuntimeError("Failed to bind to {} with sock_options {}".format(
@contextlib.contextmanager
def bound_socket(bind_address='localhost',
+ port=0,
listen=True,
- sock_options=(_DEFAULT_SOCK_OPTION,)):
+ sock_options=_DEFAULT_SOCK_OPTIONS):
"""Opens a socket bound to an arbitrary port.
Useful for reserving a port for a system-under-test.
Args:
bind_address: The host to which to bind.
+ port: The port to which to bind.
listen: A boolean value indicating whether or not to listen on the socket.
sock_options: A sequence of socket options to apply to the socket.
- the address to which the socket is bound
- the port to which the socket is bound
"""
- host, port, sock = get_socket(
- bind_address=bind_address, listen=listen, sock_options=sock_options)
+ host, port, sock = get_socket(bind_address=bind_address,
+ port=port,
+ listen=listen,
+ sock_options=sock_options)
try:
yield host, port
finally: