new urllib3
authorKenneth Reitz <me@kennethreitz.com>
Tue, 24 Sep 2013 17:59:38 +0000 (13:59 -0400)
committerKenneth Reitz <me@kennethreitz.com>
Tue, 24 Sep 2013 17:59:38 +0000 (13:59 -0400)
requests/packages/urllib3/__init__.py
requests/packages/urllib3/connectionpool.py
requests/packages/urllib3/contrib/pyopenssl.py
requests/packages/urllib3/exceptions.py
requests/packages/urllib3/fields.py [new file with mode: 0644]
requests/packages/urllib3/filepost.py
requests/packages/urllib3/packages/ssl_match_hostname/__init__.py
requests/packages/urllib3/response.py
requests/packages/urllib3/util.py

index bff80b8..73071f7 100644 (file)
@@ -23,7 +23,7 @@ from . import exceptions
 from .filepost import encode_multipart_formdata
 from .poolmanager import PoolManager, ProxyManager, proxy_from_url
 from .response import HTTPResponse
-from .util import make_headers, get_host
+from .util import make_headers, get_host, Timeout
 
 
 # Set default logging handler to avoid "No handler found" warnings.
index 93c0b4b..691d4e2 100644 (file)
@@ -4,12 +4,11 @@
 # This module is part of urllib3 and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import logging
-import socket
 import errno
+import logging
 
 from socket import error as SocketError, timeout as SocketTimeout
-from .util import resolve_cert_reqs, resolve_ssl_version, assert_fingerprint
+import socket
 
 try: # Python 3
     from http.client import HTTPConnection, HTTPException
@@ -22,6 +21,7 @@ try: # Python 3
     from queue import LifoQueue, Empty, Full
 except ImportError:
     from Queue import LifoQueue, Empty, Full
+    import Queue as _  # Platform-specific: Windows
 
 
 try: # Compiled with SSL?
@@ -44,21 +44,29 @@ except (ImportError, AttributeError): # Platform-specific: No SSL.
     pass
 
 
-from .request import RequestMethods
-from .response import HTTPResponse
-from .util import get_host, is_connection_dropped, ssl_wrap_socket
 from .exceptions import (
     ClosedPoolError,
+    ConnectTimeoutError,
     EmptyPoolError,
     HostChangedError,
     MaxRetryError,
     SSLError,
-    TimeoutError,
+    ReadTimeoutError,
+    ProxyError,
 )
-
-from .packages.ssl_match_hostname import match_hostname, CertificateError
+from .packages.ssl_match_hostname import CertificateError, match_hostname
 from .packages import six
-
+from .request import RequestMethods
+from .response import HTTPResponse
+from .util import (
+    assert_fingerprint,
+    get_host,
+    is_connection_dropped,
+    resolve_cert_reqs,
+    resolve_ssl_version,
+    ssl_wrap_socket,
+    Timeout,
+)
 
 xrange = six.moves.xrange
 
@@ -96,7 +104,14 @@ class VerifiedHTTPSConnection(HTTPSConnection):
 
     def connect(self):
         # Add certificate verification
-        sock = socket.create_connection((self.host, self.port), self.timeout)
+        try:
+            sock = socket.create_connection(
+                address=(self.host, self.port),
+                timeout=self.timeout)
+        except SocketTimeout:
+                raise ConnectTimeoutError(
+                    self, "Connection to %s timed out. (connect timeout=%s)" %
+                    (self.host, self.timeout))
 
         resolved_cert_reqs = resolve_cert_reqs(self.cert_reqs)
         resolved_ssl_version = resolve_ssl_version(self.ssl_version)
@@ -123,6 +138,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
                 match_hostname(self.sock.getpeercert(),
                                self.assert_hostname or self.host)
 
+
 ## Pool objects
 
 class ConnectionPool(object):
@@ -135,6 +151,9 @@ class ConnectionPool(object):
     QueueCls = LifoQueue
 
     def __init__(self, host, port=None):
+        # httplib doesn't like it when we include brackets in ipv6 addresses
+        host = host.strip('[]')
+
         self.host = host
         self.port = port
 
@@ -142,6 +161,8 @@ class ConnectionPool(object):
         return '%s(host=%r, port=%r)' % (type(self).__name__,
                                          self.host, self.port)
 
+# This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252
+_blocking_errnos = set([errno.EAGAIN, errno.EWOULDBLOCK])
 
 class HTTPConnectionPool(ConnectionPool, RequestMethods):
     """
@@ -160,9 +181,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
         as a valid HTTP/1.0 or 1.1 status line, passed into
         :class:`httplib.HTTPConnection`.
 
+        .. note::
+           Only works in Python 2. This parameter is ignored in Python 3.
+
     :param timeout:
-        Socket timeout in seconds for each individual connection, can be
-        a float. None disables timeout.
+        Socket timeout in seconds for each individual connection. This can
+        be a float or integer, which sets the timeout for the HTTP request,
+        or an instance of :class:`urllib3.util.Timeout` which gives you more
+        fine-grained control over request timeouts. After the constructor has
+        been parsed, this is always a `urllib3.util.Timeout` object.
 
     :param maxsize:
         Number of connections to save that can be reused. More than 1 is useful
@@ -192,13 +219,21 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
 
     scheme = 'http'
 
-    def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1,
-                 block=False, headers=None, _proxy=None, _proxy_headers=None):
+    def __init__(self, host, port=None, strict=False,
+                 timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, block=False,
+                 headers=None, _proxy=None, _proxy_headers=None):
         ConnectionPool.__init__(self, host, port)
         RequestMethods.__init__(self, headers)
 
         self.strict = strict
+
+        # This is for backwards compatibility and can be removed once a timeout
+        # can only be set to a Timeout object
+        if not isinstance(timeout, Timeout):
+            timeout = Timeout.from_float(timeout)
+
         self.timeout = timeout
+
         self.pool = self.QueueCls(maxsize)
         self.block = block
 
@@ -220,9 +255,14 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
         self.num_connections += 1
         log.info("Starting new HTTP connection (%d): %s" %
                  (self.num_connections, self.host))
-        return HTTPConnection(host=self.host,
-                              port=self.port,
-                              strict=self.strict)
+        extra_params = {}
+        if not six.PY3:  # Python 2
+            extra_params['strict'] = self.strict
+
+        return HTTPConnection(host=self.host, port=self.port,
+                              timeout=self.timeout.connect_timeout,
+                              **extra_params)
+
 
     def _get_conn(self, timeout=None):
         """
@@ -283,31 +323,89 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
                         % self.host)
 
         # Connection never got put back into the pool, close it.
-        conn.close()
+        if conn:
+            conn.close()
+
+    def _get_timeout(self, timeout):
+        """ Helper that always returns a :class:`urllib3.util.Timeout` """
+        if timeout is _Default:
+            return self.timeout.clone()
+
+        if isinstance(timeout, Timeout):
+            return timeout.clone()
+        else:
+            # User passed us an int/float. This is for backwards compatibility,
+            # can be removed later
+            return Timeout.from_float(timeout)
 
     def _make_request(self, conn, method, url, timeout=_Default,
                       **httplib_request_kw):
         """
         Perform a request on a given httplib connection object taken from our
         pool.
+
+        :param conn:
+            a connection from one of our connection pools
+
+        :param timeout:
+            Socket timeout in seconds for the request. This can be a
+            float or integer, which will set the same timeout value for
+            the socket connect and the socket read, or an instance of
+            :class:`urllib3.util.Timeout`, which gives you more fine-grained
+            control over your timeouts.
         """
         self.num_requests += 1
 
-        if timeout is _Default:
-            timeout = self.timeout
-
-        conn.timeout = timeout # This only does anything in Py26+
-        conn.request(method, url, **httplib_request_kw)
+        timeout_obj = self._get_timeout(timeout)
 
-        # Set timeout
-        sock = getattr(conn, 'sock', False) # AppEngine doesn't have sock attr.
-        if sock:
-            sock.settimeout(timeout)
+        try:
+            timeout_obj.start_connect()
+            conn.timeout = timeout_obj.connect_timeout
+            # conn.request() calls httplib.*.request, not the method in
+            # request.py. It also calls makefile (recv) on the socket
+            conn.request(method, url, **httplib_request_kw)
+        except SocketTimeout:
+            raise ConnectTimeoutError(
+                self, "Connection to %s timed out. (connect timeout=%s)" %
+                (self.host, timeout_obj.connect_timeout))
+
+        # Reset the timeout for the recv() on the socket
+        read_timeout = timeout_obj.read_timeout
+        log.debug("Setting read timeout to %s" % read_timeout)
+        # App Engine doesn't have a sock attr
+        if hasattr(conn, 'sock') and \
+            read_timeout is not None and \
+            read_timeout is not Timeout.DEFAULT_TIMEOUT:
+            # In Python 3 socket.py will catch EAGAIN and return None when you
+            # try and read into the file pointer created by http.client, which
+            # instead raises a BadStatusLine exception. Instead of catching
+            # the exception and assuming all BadStatusLine exceptions are read
+            # timeouts, check for a zero timeout before making the request.
+            if read_timeout == 0:
+                raise ReadTimeoutError(
+                    self, url,
+                    "Read timed out. (read timeout=%s)" % read_timeout)
+            conn.sock.settimeout(read_timeout)
+
+        # Receive the response from the server
+        try:
+            try: # Python 2.7+, use buffering of HTTP responses
+                httplib_response = conn.getresponse(buffering=True)
+            except TypeError: # Python 2.6 and older
+                httplib_response = conn.getresponse()
+        except SocketTimeout:
+            raise ReadTimeoutError(
+                self, url, "Read timed out. (read timeout=%s)" % read_timeout)
+
+        except SocketError as e: # Platform-specific: Python 2
+            # See the above comment about EAGAIN in Python 3. In Python 2 we
+            # have to specifically catch it and throw the timeout error
+            if e.errno in _blocking_errnos:
+                raise ReadTimeoutError(
+                    self, url,
+                    "Read timed out. (read timeout=%s)" % read_timeout)
+            raise
 
-        try: # Python 2.7+, use buffering of HTTP responses
-            httplib_response = conn.getresponse(buffering=True)
-        except TypeError: # Python 2.6 and older
-            httplib_response = conn.getresponse()
 
         # AppEngine doesn't have a version attr.
         http_version = getattr(conn, '_http_vsn_str', 'HTTP/?')
@@ -387,7 +485,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
 
         :param redirect:
             If True, automatically handle redirects (status codes 301, 302,
-            303, 307). Each redirect counts as a retry.
+            303, 307, 308). Each redirect counts as a retry.
 
         :param assert_same_host:
             If ``True``, will make sure that the host of the pool requests is
@@ -395,8 +493,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
             use the pool on an HTTP proxy and request foreign hosts.
 
         :param timeout:
-            If specified, overrides the default timeout for this one request.
-            It may be a float (in seconds).
+            If specified, overrides the default timeout for this one
+            request. It may be a float (in seconds) or an instance of
+            :class:`urllib3.util.Timeout`.
 
         :param pool_timeout:
             If set and the pool is set to block=True, then this method will
@@ -423,9 +522,6 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
         if retries < 0:
             raise MaxRetryError(self, url)
 
-        if timeout is _Default:
-            timeout = self.timeout
-
         if release_conn is None:
             release_conn = response_kw.get('preload_content', True)
 
@@ -461,20 +557,20 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
             #     ``response.release_conn()`` is called (implicitly by
             #     ``response.read()``)
 
-        except Empty as e:
+        except Empty:
             # Timed out by queue
-            raise TimeoutError(self, url,
-                               "Request timed out. (pool_timeout=%s)" %
-                               pool_timeout)
+            raise ReadTimeoutError(
+                self, url, "Read timed out, no pool connections are available.")
 
-        except SocketTimeout as e:
+        except SocketTimeout:
             # Timed out by socket
-            raise TimeoutError(self, url,
-                               "Request timed out. (timeout=%s)" %
-                               timeout)
+            raise ReadTimeoutError(self, url, "Read timed out.")
 
         except BaseSSLError as e:
             # SSL certificate error
+            if 'timed out' in str(e) or \
+               'did not complete (read)' in str(e): # Platform-specific: Python 2.6
+                raise ReadTimeoutError(self, url, "Read timed out.")
             raise SSLError(e)
 
         except CertificateError as e:
@@ -482,6 +578,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
             raise SSLError(e)
 
         except (HTTPException, SocketError) as e:
+            if isinstance(e, SocketError) and self.proxy is not None:
+                raise ProxyError('Cannot connect to proxy. '
+                                 'Socket error: %s.' % e)
+
             # Connection broken, discard. It will be replaced next _get_conn().
             conn = None
             # This is necessary so we can access e below
@@ -548,8 +648,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
                  ca_certs=None, ssl_version=None,
                  assert_hostname=None, assert_fingerprint=None):
 
-        HTTPConnectionPool.__init__(self, host, port,
-                                    strict, timeout, maxsize,
+        HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize,
                                     block, headers, _proxy, _proxy_headers)
         self.key_file = key_file
         self.cert_file = cert_file
@@ -609,8 +708,12 @@ class HTTPSConnectionPool(HTTPConnectionPool):
         else:
             connection_class = VerifiedHTTPSConnection
 
+        extra_params = {}
+        if not six.PY3:  # Python 2
+            extra_params['strict'] = self.strict
         connection = connection_class(host=actual_host, port=actual_port,
-                                      strict=self.strict)
+                                      timeout=self.timeout.connect_timeout,
+                                      **extra_params)
 
         return self._prepare_conn(connection)
 
index 6d0255f..d43bcd6 100644 (file)
@@ -20,13 +20,13 @@ Now you can use :mod:`urllib3` as you normally would, and it will support SNI
 when the required modules are installed.
 '''
 
-from ndg.httpsclient.ssl_peer_verification import (ServerSSLCertVerification,
-                                                   SUBJ_ALT_NAME_SUPPORT)
+from ndg.httpsclient.ssl_peer_verification import SUBJ_ALT_NAME_SUPPORT
 from ndg.httpsclient.subj_alt_name import SubjectAltName
 import OpenSSL.SSL
 from pyasn1.codec.der import decoder as der_decoder
 from socket import _fileobject
 import ssl
+from cStringIO import StringIO
 
 from .. import connectionpool
 from .. import util
@@ -99,6 +99,172 @@ def get_subj_alt_name(peer_cert):
     return dns_name
 
 
+class fileobject(_fileobject):
+
+    def read(self, size=-1):
+        # Use max, disallow tiny reads in a loop as they are very inefficient.
+        # We never leave read() with any leftover data from a new recv() call
+        # in our internal buffer.
+        rbufsize = max(self._rbufsize, self.default_bufsize)
+        # Our use of StringIO rather than lists of string objects returned by
+        # recv() minimizes memory usage and fragmentation that occurs when
+        # rbufsize is large compared to the typical return value of recv().
+        buf = self._rbuf
+        buf.seek(0, 2)  # seek end
+        if size < 0:
+            # Read until EOF
+            self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+            while True:
+                try:
+                    data = self._sock.recv(rbufsize)
+                except OpenSSL.SSL.WantReadError:
+                    continue
+                if not data:
+                    break
+                buf.write(data)
+            return buf.getvalue()
+        else:
+            # Read until size bytes or EOF seen, whichever comes first
+            buf_len = buf.tell()
+            if buf_len >= size:
+                # Already have size bytes in our buffer?  Extract and return.
+                buf.seek(0)
+                rv = buf.read(size)
+                self._rbuf = StringIO()
+                self._rbuf.write(buf.read())
+                return rv
+
+            self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+            while True:
+                left = size - buf_len
+                # recv() will malloc the amount of memory given as its
+                # parameter even though it often returns much less data
+                # than that.  The returned data string is short lived
+                # as we copy it into a StringIO and free it.  This avoids
+                # fragmentation issues on many platforms.
+                try:
+                    data = self._sock.recv(left)
+                except OpenSSL.SSL.WantReadError:
+                    continue
+                if not data:
+                    break
+                n = len(data)
+                if n == size and not buf_len:
+                    # Shortcut.  Avoid buffer data copies when:
+                    # - We have no data in our buffer.
+                    # AND
+                    # - Our call to recv returned exactly the
+                    #   number of bytes we were asked to read.
+                    return data
+                if n == left:
+                    buf.write(data)
+                    del data  # explicit free
+                    break
+                assert n <= left, "recv(%d) returned %d bytes" % (left, n)
+                buf.write(data)
+                buf_len += n
+                del data  # explicit free
+                #assert buf_len == buf.tell()
+            return buf.getvalue()
+
+    def readline(self, size=-1):
+        buf = self._rbuf
+        buf.seek(0, 2)  # seek end
+        if buf.tell() > 0:
+            # check if we already have it in our buffer
+            buf.seek(0)
+            bline = buf.readline(size)
+            if bline.endswith('\n') or len(bline) == size:
+                self._rbuf = StringIO()
+                self._rbuf.write(buf.read())
+                return bline
+            del bline
+        if size < 0:
+            # Read until \n or EOF, whichever comes first
+            if self._rbufsize <= 1:
+                # Speed up unbuffered case
+                buf.seek(0)
+                buffers = [buf.read()]
+                self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+                data = None
+                recv = self._sock.recv
+                while True:
+                    try:
+                        while data != "\n":
+                            data = recv(1)
+                            if not data:
+                                break
+                            buffers.append(data)
+                    except OpenSSL.SSL.WantReadError:
+                        continue
+                    break
+                return "".join(buffers)
+
+            buf.seek(0, 2)  # seek end
+            self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+            while True:
+                try:
+                    data = self._sock.recv(self._rbufsize)
+                except OpenSSL.SSL.WantReadError:
+                    continue
+                if not data:
+                    break
+                nl = data.find('\n')
+                if nl >= 0:
+                    nl += 1
+                    buf.write(data[:nl])
+                    self._rbuf.write(data[nl:])
+                    del data
+                    break
+                buf.write(data)
+            return buf.getvalue()
+        else:
+            # Read until size bytes or \n or EOF seen, whichever comes first
+            buf.seek(0, 2)  # seek end
+            buf_len = buf.tell()
+            if buf_len >= size:
+                buf.seek(0)
+                rv = buf.read(size)
+                self._rbuf = StringIO()
+                self._rbuf.write(buf.read())
+                return rv
+            self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+            while True:
+                try:
+                    data = self._sock.recv(self._rbufsize)
+                except OpenSSL.SSL.WantReadError:
+                        continue
+                if not data:
+                    break
+                left = size - buf_len
+                # did we just receive a newline?
+                nl = data.find('\n', 0, left)
+                if nl >= 0:
+                    nl += 1
+                    # save the excess data to _rbuf
+                    self._rbuf.write(data[nl:])
+                    if buf_len:
+                        buf.write(data[:nl])
+                        break
+                    else:
+                        # Shortcut.  Avoid data copy through buf when returning
+                        # a substring of our first recv().
+                        return data[:nl]
+                n = len(data)
+                if n == size and not buf_len:
+                    # Shortcut.  Avoid data copy through buf when
+                    # returning exactly all of our first recv().
+                    return data
+                if n >= left:
+                    buf.write(data[:left])
+                    self._rbuf.write(data[left:])
+                    break
+                buf.write(data)
+                buf_len += n
+                #assert buf_len == buf.tell()
+            return buf.getvalue()
+
+
 class WrappedSocket(object):
     '''API-compatibility wrapper for Python OpenSSL's Connection-class.'''
 
@@ -110,7 +276,7 @@ class WrappedSocket(object):
         return self.socket.fileno()
 
     def makefile(self, mode, bufsize=-1):
-        return _fileobject(self.connection, mode, bufsize)
+        return fileobject(self.connection, mode, bufsize)
 
     def settimeout(self, timeout):
         return self.socket.settimeout(timeout)
@@ -123,8 +289,9 @@ class WrappedSocket(object):
 
     def getpeercert(self, binary_form=False):
         x509 = self.connection.get_peer_certificate()
+
         if not x509:
-            raise ssl.SSLError('')
+            return x509
 
         if binary_form:
             return OpenSSL.crypto.dump_certificate(
@@ -165,9 +332,13 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
     cnx = OpenSSL.SSL.Connection(ctx, sock)
     cnx.set_tlsext_host_name(server_hostname)
     cnx.set_connect_state()
-    try:
-        cnx.do_handshake()
-    except OpenSSL.SSL.Error as e:
-        raise ssl.SSLError('bad handshake', e)
+    while True:
+        try:
+            cnx.do_handshake()
+        except OpenSSL.SSL.WantReadError:
+            continue
+        except OpenSSL.SSL.Error as e:
+            raise ssl.SSLError('bad handshake', e)
+        break
 
     return WrappedSocket(cnx, sock)
index 2e2a259..98ef9ab 100644 (file)
@@ -39,6 +39,11 @@ class SSLError(HTTPError):
     pass
 
 
+class ProxyError(HTTPError):
+    "Raised when the connection to a proxy fails."
+    pass
+
+
 class DecodeError(HTTPError):
     "Raised when automatic decoding based on Content-Type fails."
     pass
@@ -70,8 +75,29 @@ class HostChangedError(RequestError):
         self.retries = retries
 
 
-class TimeoutError(RequestError):
-    "Raised when a socket timeout occurs."
+class TimeoutStateError(HTTPError):
+    """ Raised when passing an invalid state to a timeout """
+    pass
+
+
+class TimeoutError(HTTPError):
+    """ Raised when a socket timeout error occurs.
+
+    Catching this error will catch both :exc:`ReadTimeoutErrors
+    <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
+    """
+    pass
+
+
+class ReadTimeoutError(TimeoutError, RequestError):
+    "Raised when a socket timeout occurs while receiving data from a server"
+    pass
+
+
+# This timeout error does not have a URL attached and needs to inherit from the
+# base HTTPError
+class ConnectTimeoutError(TimeoutError):
+    "Raised when a socket timeout occurs while connecting to a server"
     pass
 
 
diff --git a/requests/packages/urllib3/fields.py b/requests/packages/urllib3/fields.py
new file mode 100644 (file)
index 0000000..ed01765
--- /dev/null
@@ -0,0 +1,177 @@
+# urllib3/fields.py
+# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt)
+#
+# This module is part of urllib3 and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import email.utils
+import mimetypes
+
+from .packages import six
+
+
+def guess_content_type(filename, default='application/octet-stream'):
+    """
+    Guess the "Content-Type" of a file.
+
+    :param filename:
+        The filename to guess the "Content-Type" of using :mod:`mimetimes`.
+    :param default:
+        If no "Content-Type" can be guessed, default to `default`.
+    """
+    if filename:
+        return mimetypes.guess_type(filename)[0] or default
+    return default
+
+
+def format_header_param(name, value):
+    """
+    Helper function to format and quote a single header parameter.
+
+    Particularly useful for header parameters which might contain
+    non-ASCII values, like file names. This follows RFC 2231, as
+    suggested by RFC 2388 Section 4.4.
+
+    :param name:
+        The name of the parameter, a string expected to be ASCII only.
+    :param value:
+        The value of the parameter, provided as a unicode string.
+    """
+    if not any(ch in value for ch in '"\\\r\n'):
+        result = '%s="%s"' % (name, value)
+        try:
+            result.encode('ascii')
+        except UnicodeEncodeError:
+            pass
+        else:
+            return result
+    if not six.PY3:  # Python 2:
+        value = value.encode('utf-8')
+    value = email.utils.encode_rfc2231(value, 'utf-8')
+    value = '%s*=%s' % (name, value)
+    return value
+
+
+class RequestField(object):
+    """
+    A data container for request body parameters.
+
+    :param name:
+        The name of this request field.
+    :param data:
+        The data/value body.
+    :param filename:
+        An optional filename of the request field.
+    :param headers:
+        An optional dict-like object of headers to initially use for the field.
+    """
+    def __init__(self, name, data, filename=None, headers=None):
+        self._name = name
+        self._filename = filename
+        self.data = data
+        self.headers = {}
+        if headers:
+            self.headers = dict(headers)
+
+    @classmethod
+    def from_tuples(cls, fieldname, value):
+        """
+        A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
+
+        Supports constructing :class:`~urllib3.fields.RequestField` from parameter
+        of key/value strings AND key/filetuple. A filetuple is a (filename, data, MIME type)
+        tuple where the MIME type is optional. For example: ::
+
+            'foo': 'bar',
+            'fakefile': ('foofile.txt', 'contents of foofile'),
+            'realfile': ('barfile.txt', open('realfile').read()),
+            'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
+            'nonamefile': 'contents of nonamefile field',
+
+        Field names and filenames must be unicode.
+        """
+        if isinstance(value, tuple):
+            if len(value) == 3:
+                filename, data, content_type = value
+            else:
+                filename, data = value
+                content_type = guess_content_type(filename)
+        else:
+            filename = None
+            content_type = None
+            data = value
+
+        request_param = cls(fieldname, data, filename=filename)
+        request_param.make_multipart(content_type=content_type)
+
+        return request_param
+
+    def _render_part(self, name, value):
+        """
+        Overridable helper function to format a single header parameter.
+
+        :param name:
+            The name of the parameter, a string expected to be ASCII only.
+        :param value:
+            The value of the parameter, provided as a unicode string.
+        """
+        return format_header_param(name, value)
+
+    def _render_parts(self, header_parts):
+        """
+        Helper function to format and quote a single header.
+
+        Useful for single headers that are composed of multiple items. E.g.,
+        'Content-Disposition' fields.
+
+        :param header_parts:
+            A sequence of (k, v) typles or a :class:`dict` of (k, v) to format as
+            `k1="v1"; k2="v2"; ...`.
+        """
+        parts = []
+        iterable = header_parts
+        if isinstance(header_parts, dict):
+            iterable = header_parts.items()
+
+        for name, value in iterable:
+            if value:
+                parts.append(self._render_part(name, value))
+
+        return '; '.join(parts)
+
+    def render_headers(self):
+        """
+        Renders the headers for this request field.
+        """
+        lines = []
+
+        sort_keys = ['Content-Disposition', 'Content-Type', 'Content-Location']
+        for sort_key in sort_keys:
+            if self.headers.get(sort_key, False):
+                lines.append('%s: %s' % (sort_key, self.headers[sort_key]))
+
+        for header_name, header_value in self.headers.items():
+            if header_name not in sort_keys:
+                if header_value:
+                    lines.append('%s: %s' % (header_name, header_value))
+
+        lines.append('\r\n')
+        return '\r\n'.join(lines)
+
+    def make_multipart(self, content_disposition=None, content_type=None, content_location=None):
+        """
+        Makes this request field into a multipart request field.
+
+        This method overrides "Content-Disposition", "Content-Type" and
+        "Content-Location" headers to the request parameter.
+
+        :param content_type:
+            The 'Content-Type' of the request body.
+        :param content_location:
+            The 'Content-Location' of the request body.
+
+        """
+        self.headers['Content-Disposition'] = content_disposition or 'form-data'
+        self.headers['Content-Disposition'] += '; '.join(['', self._render_parts((('name', self._name), ('filename', self._filename)))])
+        self.headers['Content-Type'] = content_type
+        self.headers['Content-Location'] = content_location
index 526a740..4575582 100644 (file)
@@ -12,6 +12,7 @@ from io import BytesIO
 
 from .packages import six
 from .packages.six import b
+from .fields import RequestField
 
 writer = codecs.lookup('utf-8')[3]
 
@@ -23,15 +24,38 @@ def choose_boundary():
     return uuid4().hex
 
 
-def get_content_type(filename):
-    return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
+def iter_field_objects(fields):
+    """
+    Iterate over fields.
+
+    Supports list of (k, v) tuples and dicts, and lists of
+    :class:`~urllib3.fields.RequestField`.
+
+    """
+    if isinstance(fields, dict):
+        i = six.iteritems(fields)
+    else:
+        i = iter(fields)
+
+    for field in i:
+      if isinstance(field, RequestField):
+        yield field
+      else:
+        yield RequestField.from_tuples(*field)
 
 
 def iter_fields(fields):
     """
     Iterate over fields.
 
+    .. deprecated ::
+
+      The addition of `~urllib3.fields.RequestField` makes this function
+      obsolete. Instead, use :func:`iter_field_objects`, which returns
+      `~urllib3.fields.RequestField` objects, instead.
+
     Supports list of (k, v) tuples and dicts.
+
     """
     if isinstance(fields, dict):
         return ((k, v) for k, v in six.iteritems(fields))
@@ -44,15 +68,7 @@ def encode_multipart_formdata(fields, boundary=None):
     Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
 
     :param fields:
-        Dictionary of fields or list of (key, value) or (key, value, MIME type)
-        field tuples.  The key is treated as the field name, and the value as
-        the body of the form-data bytes. If the value is a tuple of two
-        elements, then the first element is treated as the filename of the
-        form-data section and a suitable MIME type is guessed based on the
-        filename. If the value is a tuple of three elements, then the third
-        element is treated as an explicit MIME type of the form-data section.
-
-        Field names and filenames must be unicode.
+        Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
 
     :param boundary:
         If not specified, then a random boundary will be generated using
@@ -62,24 +78,11 @@ def encode_multipart_formdata(fields, boundary=None):
     if boundary is None:
         boundary = choose_boundary()
 
-    for fieldname, value in iter_fields(fields):
+    for field in iter_field_objects(fields):
         body.write(b('--%s\r\n' % (boundary)))
 
-        if isinstance(value, tuple):
-            if len(value) == 3:
-                filename, data, content_type = value
-            else:
-                filename, data = value
-                content_type = get_content_type(filename)
-            writer(body).write('Content-Disposition: form-data; name="%s"; '
-                               'filename="%s"\r\n' % (fieldname, filename))
-            body.write(b('Content-Type: %s\r\n\r\n' %
-                       (content_type,)))
-        else:
-            data = value
-            writer(body).write('Content-Disposition: form-data; name="%s"\r\n'
-                               % (fieldname))
-            body.write(b'\r\n')
+        writer(body).write(field.render_headers())
+        data = field.data
 
         if isinstance(data, int):
             data = str(data)  # Backwards compatibility
index 9560b04..2d61ac2 100644 (file)
@@ -7,23 +7,60 @@ __version__ = '3.2.2'
 class CertificateError(ValueError):
     pass
 
-def _dnsname_to_pat(dn):
+def _dnsname_match(dn, hostname, max_wildcards=1):
+    """Matching according to RFC 6125, section 6.4.3
+
+    http://tools.ietf.org/html/rfc6125#section-6.4.3
+    """
     pats = []
-    for frag in dn.split(r'.'):
-        if frag == '*':
-            # When '*' is a fragment by itself, it matches a non-empty dotless
-            # fragment.
-            pats.append('[^.]+')
-        else:
-            # Otherwise, '*' matches any dotless fragment.
-            frag = re.escape(frag)
-            pats.append(frag.replace(r'\*', '[^.]*'))
-    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+    if not dn:
+        return False
+
+    parts = dn.split(r'.')
+    leftmost = parts[0]
+
+    wildcards = leftmost.count('*')
+    if wildcards > max_wildcards:
+        # Issue #17980: avoid denials of service by refusing more
+        # than one wildcard per fragment.  A survery of established
+        # policy among SSL implementations showed it to be a
+        # reasonable choice.
+        raise CertificateError(
+            "too many wildcards in certificate DNS name: " + repr(dn))
+
+    # speed up common case w/o wildcards
+    if not wildcards:
+        return dn.lower() == hostname.lower()
+
+    # RFC 6125, section 6.4.3, subitem 1.
+    # The client SHOULD NOT attempt to match a presented identifier in which
+    # the wildcard character comprises a label other than the left-most label.
+    if leftmost == '*':
+        # When '*' is a fragment by itself, it matches a non-empty dotless
+        # fragment.
+        pats.append('[^.]+')
+    elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
+        # RFC 6125, section 6.4.3, subitem 3.
+        # The client SHOULD NOT attempt to match a presented identifier
+        # where the wildcard character is embedded within an A-label or
+        # U-label of an internationalized domain name.
+        pats.append(re.escape(leftmost))
+    else:
+        # Otherwise, '*' matches any dotless string, e.g. www*
+        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))
+
+    # add the remaining fragments, ignore any wildcards
+    for frag in parts[1:]:
+        pats.append(re.escape(frag))
+
+    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+    return pat.match(hostname)
+
 
 def match_hostname(cert, hostname):
     """Verify that *cert* (in decoded format as returned by
-    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules
-    are mostly followed, but IP addresses are not accepted for *hostname*.
+    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 and RFC 6125
+    rules are followed, but IP addresses are not accepted for *hostname*.
 
     CertificateError is raised on failure. On success, the function
     returns nothing.
@@ -34,7 +71,7 @@ def match_hostname(cert, hostname):
     san = cert.get('subjectAltName', ())
     for key, value in san:
         if key == 'DNS':
-            if _dnsname_to_pat(value).match(hostname):
+            if _dnsname_match(value, hostname):
                 return
             dnsnames.append(value)
     if not dnsnames:
@@ -45,7 +82,7 @@ def match_hostname(cert, hostname):
                 # XXX according to RFC 2818, the most specific Common Name
                 # must be used.
                 if key == 'commonName':
-                    if _dnsname_to_pat(value).match(hostname):
+                    if _dnsname_match(value, hostname):
                         return
                     dnsnames.append(value)
     if len(dnsnames) > 1:
index c7f93b8..4efff5a 100644 (file)
@@ -74,6 +74,7 @@ class HTTPResponse(io.IOBase):
     """
 
     CONTENT_DECODERS = ['gzip', 'deflate']
+    REDIRECT_STATUSES = [301, 302, 303, 307, 308]
 
     def __init__(self, body='', headers=None, status=0, version=0, reason=None,
                  strict=0, preload_content=True, decode_content=True,
@@ -107,7 +108,7 @@ class HTTPResponse(io.IOBase):
             code and valid location. ``None`` if redirect status and no
             location. ``False`` if not a redirect status code.
         """
-        if self.status in [301, 302, 303, 307]:
+        if self.status in self.REDIRECT_STATUSES:
             return self.headers.get('location')
 
         return False
index 39bceab..266c9ed 100644 (file)
@@ -6,10 +6,11 @@
 
 
 from base64 import b64encode
+from binascii import hexlify, unhexlify
 from collections import namedtuple
-from socket import error as SocketError
 from hashlib import md5, sha1
-from binascii import hexlify, unhexlify
+from socket import error as SocketError, _GLOBAL_DEFAULT_TIMEOUT
+import time
 
 try:
     from select import poll, POLLIN
@@ -32,7 +33,233 @@ except ImportError:
     pass
 
 from .packages import six
-from .exceptions import LocationParseError, SSLError
+from .exceptions import LocationParseError, SSLError, TimeoutStateError
+
+
+_Default = object()
+# The default timeout to use for socket connections. This is the attribute used
+# by httplib to define the default timeout
+
+
+def current_time():
+    """
+    Retrieve the current time, this function is mocked out in unit testing.
+    """
+    return time.time()
+
+
+class Timeout(object):
+    """
+    Utility object for storing timeout values.
+
+    Example usage:
+
+    .. code-block:: python
+
+        timeout = urllib3.util.Timeout(connect=2.0, read=7.0)
+        pool = HTTPConnectionPool('www.google.com', 80, timeout=timeout)
+        pool.request(...) # Etc, etc
+
+    :param connect:
+        The maximum amount of time to wait for a connection attempt to a server
+        to succeed. Omitting the parameter will default the connect timeout to
+        the system default, probably `the global default timeout in socket.py
+        <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
+        None will set an infinite timeout for connection attempts.
+
+    :type connect: integer, float, or None
+
+    :param read:
+        The maximum amount of time to wait between consecutive
+        read operations for a response from the server. Omitting
+        the parameter will default the read timeout to the system
+        default, probably `the global default timeout in socket.py
+        <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
+        None will set an infinite timeout.
+
+    :type read: integer, float, or None
+
+    :param total:
+        The maximum amount of time to wait for an HTTP request to connect and
+        return. This combines the connect and read timeouts into one. In the
+        event that both a connect timeout and a total are specified, or a read
+        timeout and a total are specified, the shorter timeout will be applied.
+
+        Defaults to None.
+
+
+    :type total: integer, float, or None
+
+    .. note::
+
+        Many factors can affect the total amount of time for urllib3 to return
+        an HTTP response. Specifically, Python's DNS resolver does not obey the
+        timeout specified on the socket. Other factors that can affect total
+        request time include high CPU load, high swap, the program running at a
+        low priority level, or other behaviors. The observed running time for
+        urllib3 to return a response may be greater than the value passed to
+        `total`.
+
+        In addition, the read and total timeouts only measure the time between
+        read operations on the socket connecting the client and the server, not
+        the total amount of time for the request to return a complete response.
+        As an example, you may want a request to return within 7 seconds or
+        fail, so you set the ``total`` timeout to 7 seconds. If the server
+        sends one byte to you every 5 seconds, the request will **not** trigger
+        time out. This case is admittedly rare.
+    """
+
+    #: A sentinel object representing the default timeout value
+    DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
+
+    def __init__(self, connect=_Default, read=_Default, total=None):
+        self._connect = self._validate_timeout(connect, 'connect')
+        self._read = self._validate_timeout(read, 'read')
+        self.total = self._validate_timeout(total, 'total')
+        self._start_connect = None
+
+    def __str__(self):
+        return '%s(connect=%r, read=%r, total=%r)' % (
+            type(self).__name__, self._connect, self._read, self.total)
+
+
+    @classmethod
+    def _validate_timeout(cls, value, name):
+        """ Check that a timeout attribute is valid
+
+        :param value: The timeout value to validate
+        :param name: The name of the timeout attribute to validate. This is used
+            for clear error messages
+        :return: the value
+        :raises ValueError: if the type is not an integer or a float, or if it
+            is a numeric value less than zero
+        """
+        if value is _Default:
+            return cls.DEFAULT_TIMEOUT
+
+        if value is None or value is cls.DEFAULT_TIMEOUT:
+            return value
+
+        try:
+            float(value)
+        except (TypeError, ValueError):
+            raise ValueError("Timeout value %s was %s, but it must be an "
+                             "int or float." % (name, value))
+
+        try:
+            if value < 0:
+                raise ValueError("Attempted to set %s timeout to %s, but the "
+                                 "timeout cannot be set to a value less "
+                                 "than 0." % (name, value))
+        except TypeError: # Python 3
+            raise ValueError("Timeout value %s was %s, but it must be an "
+                             "int or float." % (name, value))
+
+        return value
+
+    @classmethod
+    def from_float(cls, timeout):
+        """ Create a new Timeout from a legacy timeout value.
+
+        The timeout value used by httplib.py sets the same timeout on the
+        connect(), and recv() socket requests. This creates a :class:`Timeout`
+        object that sets the individual timeouts to the ``timeout`` value passed
+        to this function.
+
+        :param timeout: The legacy timeout value
+        :type timeout: integer, float, sentinel default object, or None
+        :return: a Timeout object
+        :rtype: :class:`Timeout`
+        """
+        return Timeout(read=timeout, connect=timeout)
+
+    def clone(self):
+        """ Create a copy of the timeout object
+
+        Timeout properties are stored per-pool but each request needs a fresh
+        Timeout object to ensure each one has its own start/stop configured.
+
+        :return: a copy of the timeout object
+        :rtype: :class:`Timeout`
+        """
+        # We can't use copy.deepcopy because that will also create a new object
+        # for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
+        # detect the user default.
+        return Timeout(connect=self._connect, read=self._read,
+                       total=self.total)
+
+    def start_connect(self):
+        """ Start the timeout clock, used during a connect() attempt
+
+        :raises urllib3.exceptions.TimeoutStateError: if you attempt
+            to start a timer that has been started already.
+        """
+        if self._start_connect is not None:
+            raise TimeoutStateError("Timeout timer has already been started.")
+        self._start_connect = current_time()
+        return self._start_connect
+
+    def get_connect_duration(self):
+        """ Gets the time elapsed since the call to :meth:`start_connect`.
+
+        :return: the elapsed time
+        :rtype: float
+        :raises urllib3.exceptions.TimeoutStateError: if you attempt
+            to get duration for a timer that hasn't been started.
+        """
+        if self._start_connect is None:
+            raise TimeoutStateError("Can't get connect duration for timer "
+                                    "that has not started.")
+        return current_time() - self._start_connect
+
+    @property
+    def connect_timeout(self):
+        """ Get the value to use when setting a connection timeout.
+
+        This will be a positive float or integer, the value None
+        (never timeout), or the default system timeout.
+
+        :return: the connect timeout
+        :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None
+        """
+        if self.total is None:
+            return self._connect
+
+        if self._connect is None or self._connect is self.DEFAULT_TIMEOUT:
+            return self.total
+
+        return min(self._connect, self.total)
+
+    @property
+    def read_timeout(self):
+        """ Get the value for the read timeout.
+
+        This assumes some time has elapsed in the connection timeout and
+        computes the read timeout appropriately.
+
+        If self.total is set, the read timeout is dependent on the amount of
+        time taken by the connect timeout. If the connection time has not been
+        established, a :exc:`~urllib3.exceptions.TimeoutStateError` will be
+        raised.
+
+        :return: the value to use for the read timeout
+        :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None
+        :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
+            has not yet been called on this object.
+        """
+        if (self.total is not None and
+            self.total is not self.DEFAULT_TIMEOUT and
+            self._read is not None and
+            self._read is not self.DEFAULT_TIMEOUT):
+            # in case the connect timeout has not yet been established.
+            if self._start_connect is None:
+                return self._read
+            return max(0, min(self.total - self.get_connect_duration(),
+                              self._read))
+        elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
+            return max(0, self.total - self.get_connect_duration())
+        else:
+            return self._read
 
 
 class Url(namedtuple('Url', ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment'])):