Imported Upstream version 1.27.0
[platform/upstream/grpc.git] / src / python / grpcio_tests / tests / unit / framework / common / __init__.py
index c1ac762..8b58a0c 100644 (file)
 import contextlib
 import os
 import socket
+import errno
 
-_DEFAULT_SOCK_OPTION = socket.SO_REUSEADDR if os.name == 'nt' else socket.SO_REUSEPORT
+_DEFAULT_SOCK_OPTIONS = (socket.SO_REUSEADDR,
+                         socket.SO_REUSEPORT) if os.name != 'nt' else (
+                             socket.SO_REUSEADDR,)
+_UNRECOVERABLE_ERRNOS = (errno.EADDRINUSE, errno.ENOSR)
 
 
 def get_socket(bind_address='localhost',
+               port=0,
                listen=True,
-               sock_options=(_DEFAULT_SOCK_OPTION,)):
-    """Opens a socket bound to an arbitrary port.
+               sock_options=_DEFAULT_SOCK_OPTIONS):
+    """Opens a socket.
 
     Useful for reserving a port for a system-under-test.
 
     Args:
       bind_address: The host to which to bind.
+      port: The port to which to bind.
       listen: A boolean value indicating whether or not to listen on the socket.
       sock_options: A sequence of socket options to apply to the socket.
 
@@ -47,11 +53,19 @@ def get_socket(bind_address='localhost',
             sock = socket.socket(address_family, socket.SOCK_STREAM)
             for sock_option in _sock_options:
                 sock.setsockopt(socket.SOL_SOCKET, sock_option, 1)
-            sock.bind((bind_address, 0))
+            sock.bind((bind_address, port))
             if listen:
                 sock.listen(1)
             return bind_address, sock.getsockname()[1], sock
-        except socket.error:
+        except OSError as os_error:
+            sock.close()
+            if os_error.errno in _UNRECOVERABLE_ERRNOS:
+                raise
+            else:
+                continue
+        # For PY2, socket.error is a child class of IOError; for PY3, it is
+        # pointing to OSError. We need this catch to make it 2/3 agnostic.
+        except socket.error:  # pylint: disable=duplicate-except
             sock.close()
             continue
     raise RuntimeError("Failed to bind to {} with sock_options {}".format(
@@ -60,14 +74,16 @@ def get_socket(bind_address='localhost',
 
 @contextlib.contextmanager
 def bound_socket(bind_address='localhost',
+                 port=0,
                  listen=True,
-                 sock_options=(_DEFAULT_SOCK_OPTION,)):
+                 sock_options=_DEFAULT_SOCK_OPTIONS):
     """Opens a socket bound to an arbitrary port.
 
     Useful for reserving a port for a system-under-test.
 
     Args:
       bind_address: The host to which to bind.
+      port: The port to which to bind.
       listen: A boolean value indicating whether or not to listen on the socket.
       sock_options: A sequence of socket options to apply to the socket.
 
@@ -76,8 +92,10 @@ def bound_socket(bind_address='localhost',
         - the address to which the socket is bound
         - the port to which the socket is bound
     """
-    host, port, sock = get_socket(
-        bind_address=bind_address, listen=listen, sock_options=sock_options)
+    host, port, sock = get_socket(bind_address=bind_address,
+                                  port=port,
+                                  listen=listen,
+                                  sock_options=sock_options)
     try:
         yield host, port
     finally: