new urllib3
authorKenneth Reitz <me@kennethreitz.com>
Mon, 17 Dec 2012 09:08:36 +0000 (04:08 -0500)
committerKenneth Reitz <me@kennethreitz.com>
Mon, 17 Dec 2012 09:08:36 +0000 (04:08 -0500)
requests/packages/urllib3/connectionpool.py
requests/packages/urllib3/exceptions.py
requests/packages/urllib3/filepost.py
requests/packages/urllib3/packages/six.py
requests/packages/urllib3/poolmanager.py
requests/packages/urllib3/request.py
requests/packages/urllib3/response.py
requests/packages/urllib3/util.py

index 97da544..af8760d 100644 (file)
@@ -6,8 +6,9 @@
 
 import logging
 import socket
+import errno
 
-from socket import timeout as SocketTimeout
+from socket import error as SocketError, timeout as SocketTimeout
 
 try: # Python 3
     from http.client import HTTPConnection, HTTPException
@@ -41,7 +42,7 @@ except (ImportError, AttributeError): # Platform-specific: No SSL.
 
 from .request import RequestMethods
 from .response import HTTPResponse
-from .util import get_host, is_connection_dropped
+from .util import get_host, is_connection_dropped, ssl_wrap_socket
 from .exceptions import (
     ClosedPoolError,
     EmptyPoolError,
@@ -76,6 +77,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
     """
     cert_reqs = None
     ca_certs = None
+    ssl_version = None
 
     def set_cert(self, key_file=None, cert_file=None,
                  cert_reqs='CERT_NONE', ca_certs=None):
@@ -96,9 +98,12 @@ class VerifiedHTTPSConnection(HTTPSConnection):
 
         # Wrap socket using verification with the root certs in
         # trusted_root_certs
-        self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
+        self.sock = ssl_wrap_socket(sock, self.key_file, self.cert_file,
                                     cert_reqs=self.cert_reqs,
-                                    ca_certs=self.ca_certs)
+                                    ca_certs=self.ca_certs,
+                                    server_hostname=self.host,
+                                    ssl_version=self.ssl_version)
+
         if self.ca_certs:
             match_hostname(self.sock.getpeercert(), self.host)
 
@@ -166,13 +171,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
 
     def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1,
                  block=False, headers=None):
-        super(HTTPConnectionPool, self).__init__(host, port)
+        ConnectionPool.__init__(self, host, port)
+        RequestMethods.__init__(self, headers)
 
         self.strict = strict
         self.timeout = timeout
         self.pool = self.QueueCls(maxsize)
         self.block = block
-        self.headers = headers or {}
 
         # Fill the queue up so that doing get() on it will block properly
         for _ in xrange(maxsize):
@@ -189,7 +194,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
         self.num_connections += 1
         log.info("Starting new HTTP connection (%d): %s" %
                  (self.num_connections, self.host))
-        return HTTPConnection(host=self.host, port=self.port)
+        return HTTPConnection(host=self.host,
+                              port=self.port,
+                              strict=self.strict)
 
     def _get_conn(self, timeout=None):
         """
@@ -449,12 +456,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
             # Name mismatch
             raise SSLError(e)
 
-        except HTTPException as e:
+        except (HTTPException, SocketError) as e:
             # Connection broken, discard. It will be replaced next _get_conn().
             conn = None
             # This is necessary so we can access e below
             err = e
 
+            if retries == 0:
+                raise MaxRetryError(self, url, e)
+
         finally:
             if release_conn:
                 # Put the connection back to be reused. If the connection is
@@ -491,11 +501,11 @@ class HTTPSConnectionPool(HTTPConnectionPool):
 
     When Python is compiled with the :mod:`ssl` module, then
     :class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates,
-    instead of :class:httplib.HTTPSConnection`.
+    instead of :class:`httplib.HTTPSConnection`.
 
-    The ``key_file``, ``cert_file``, ``cert_reqs``, and ``ca_certs`` parameters
+    The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, and ``ssl_version``
     are only used if :mod:`ssl` is available and are fed into
-    :meth:`ssl.wrap_socket` to upgrade the connection socket into an SSL socket.
+    :meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket into an SSL socket.
     """
 
     scheme = 'https'
@@ -504,15 +514,16 @@ class HTTPSConnectionPool(HTTPConnectionPool):
                  strict=False, timeout=None, maxsize=1,
                  block=False, headers=None,
                  key_file=None, cert_file=None,
-                 cert_reqs='CERT_NONE', ca_certs=None):
+                 cert_reqs='CERT_NONE', ca_certs=None, ssl_version=None):
 
-        super(HTTPSConnectionPool, self).__init__(host, port,
-                                                  strict, timeout, maxsize,
-                                                  block, headers)
+        HTTPConnectionPool.__init__(self, host, port,
+                                    strict, timeout, maxsize,
+                                    block, headers)
         self.key_file = key_file
         self.cert_file = cert_file
         self.cert_reqs = cert_reqs
         self.ca_certs = ca_certs
+        self.ssl_version = ssl_version
 
     def _new_conn(self):
         """
@@ -527,11 +538,21 @@ class HTTPSConnectionPool(HTTPConnectionPool):
                 raise SSLError("Can't connect to HTTPS URL because the SSL "
                                "module is not available.")
 
-            return HTTPSConnection(host=self.host, port=self.port)
+            return HTTPSConnection(host=self.host,
+                                   port=self.port,
+                                   strict=self.strict)
 
-        connection = VerifiedHTTPSConnection(host=self.host, port=self.port)
+        connection = VerifiedHTTPSConnection(host=self.host,
+                                             port=self.port,
+                                             strict=self.strict)
         connection.set_cert(key_file=self.key_file, cert_file=self.cert_file,
                             cert_reqs=self.cert_reqs, ca_certs=self.ca_certs)
+
+        if self.ssl_version is None:
+            connection.ssl_version = ssl.PROTOCOL_SSLv23
+        else:
+            connection.ssl_version = self.ssl_version
+
         return connection
 
 
index 99ebb67..c5eb962 100644 (file)
@@ -18,6 +18,10 @@ class PoolError(HTTPError):
         self.pool = pool
         HTTPError.__init__(self, "%s: %s" % (pool, message))
 
+    def __reduce__(self):
+        # For pickling purposes.
+        return self.__class__, (None, self.url)
+
 
 class SSLError(HTTPError):
     "Raised when SSL certificate fails in an HTTPS connection."
@@ -34,10 +38,16 @@ class DecodeError(HTTPError):
 class MaxRetryError(PoolError):
     "Raised when the maximum number of retries is exceeded."
 
-    def __init__(self, pool, url):
+    def __init__(self, pool, url, reason=None):
+        self.reason = reason
+
         message = "Max retries exceeded with url: %s" % url
-        PoolError.__init__(self, pool, message)
+        if reason:
+            message += " (Caused by %s: %s)" % (type(reason), reason)
+        else:
+            message += " (Caused by redirect)"
 
+        PoolError.__init__(self, pool, message)
         self.url = url
 
 
@@ -72,6 +82,6 @@ class LocationParseError(ValueError, HTTPError):
 
     def __init__(self, location):
         message = "Failed to parse: %s" % location
-        super(LocationParseError, self).__init__(self, message)
+        HTTPError.__init__(self, message)
 
         self.location = location
index e679b93..8d900bd 100644 (file)
@@ -41,13 +41,16 @@ def iter_fields(fields):
 
 def encode_multipart_formdata(fields, boundary=None):
     """
-    Encode a dictionary of ``fields`` using the multipart/form-data mime format.
+    Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
 
     :param fields:
-        Dictionary of fields or list of (key, value) field tuples.  The key is
-        treated as the field name, and the value as the body of the form-data
-        bytes. If the value is a tuple of two elements, then the first element
-        is treated as the filename of the form-data section.
+        Dictionary of fields or list of (key, value) or (key, value, MIME type)
+        field tuples.  The key is treated as the field name, and the value as
+        the body of the form-data bytes. If the value is a tuple of two
+        elements, then the first element is treated as the filename of the
+        form-data section and a suitable MIME type is guessed based on the
+        filename. If the value is a tuple of three elements, then the third
+        element is treated as an explicit MIME type of the form-data section.
 
         Field names and filenames must be unicode.
 
@@ -63,16 +66,20 @@ def encode_multipart_formdata(fields, boundary=None):
         body.write(b('--%s\r\n' % (boundary)))
 
         if isinstance(value, tuple):
-            filename, data = value
+            if len(value) == 3:
+                filename, data, content_type = value
+            else:
+                filename, data = value
+                content_type = get_content_type(filename)
             writer(body).write('Content-Disposition: form-data; name="%s"; '
                                'filename="%s"\r\n' % (fieldname, filename))
             body.write(b('Content-Type: %s\r\n\r\n' %
-                       (get_content_type(filename))))
+                       (content_type,)))
         else:
             data = value
             writer(body).write('Content-Disposition: form-data; name="%s"\r\n'
                                % (fieldname))
-            body.write(b'Content-Type: text/plain\r\n\r\n')
+            body.write(b'\r\n')
 
         if isinstance(data, int):
             data = str(data)  # Backwards compatibility
index a64f6fb..27d8011 100644 (file)
@@ -24,7 +24,7 @@ import sys
 import types
 
 __author__ = "Benjamin Peterson <benjamin@python.org>"
-__version__ = "1.1.0"
+__version__ = "1.2.0"  # Revision 41c74fef2ded
 
 
 # True if we are running on Python 3.
@@ -45,19 +45,23 @@ else:
     text_type = unicode
     binary_type = str
 
-    # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
-    class X(object):
-        def __len__(self):
-            return 1 << 31
-    try:
-        len(X())
-    except OverflowError:
-        # 32-bit
+    if sys.platform.startswith("java"):
+        # Jython always uses 32 bits.
         MAXSIZE = int((1 << 31) - 1)
     else:
-        # 64-bit
-        MAXSIZE = int((1 << 63) - 1)
-    del X
+        # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
+        class X(object):
+            def __len__(self):
+                return 1 << 31
+        try:
+            len(X())
+        except OverflowError:
+            # 32-bit
+            MAXSIZE = int((1 << 31) - 1)
+        else:
+            # 64-bit
+            MAXSIZE = int((1 << 63) - 1)
+            del X
 
 
 def _add_doc(func, doc):
@@ -132,6 +136,7 @@ class _MovedItems(types.ModuleType):
 _moved_attributes = [
     MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
     MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
+    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
     MovedAttribute("map", "itertools", "builtins", "imap", "map"),
     MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
     MovedAttribute("reduce", "__builtin__", "functools"),
@@ -178,7 +183,7 @@ for attr in _moved_attributes:
     setattr(_MovedItems, attr.name, attr)
 del attr
 
-moves = sys.modules["six.moves"] = _MovedItems("moves")
+moves = sys.modules[__name__ + ".moves"] = _MovedItems("moves")
 
 
 def add_move(move):
@@ -219,12 +224,19 @@ else:
     _iteritems = "iteritems"
 
 
+try:
+    advance_iterator = next
+except NameError:
+    def advance_iterator(it):
+        return it.next()
+next = advance_iterator
+
+
 if PY3:
     def get_unbound_function(unbound):
         return unbound
 
-
-    advance_iterator = next
+    Iterator = object
 
     def callable(obj):
         return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
@@ -232,9 +244,10 @@ else:
     def get_unbound_function(unbound):
         return unbound.im_func
 
+    class Iterator(object):
 
-    def advance_iterator(it):
-        return it.next()
+        def next(self):
+            return type(self).__next__(self)
 
     callable = callable
 _add_doc(get_unbound_function,
@@ -249,15 +262,15 @@ get_function_defaults = operator.attrgetter(_func_defaults)
 
 def iterkeys(d):
     """Return an iterator over the keys of a dictionary."""
-    return getattr(d, _iterkeys)()
+    return iter(getattr(d, _iterkeys)())
 
 def itervalues(d):
     """Return an iterator over the values of a dictionary."""
-    return getattr(d, _itervalues)()
+    return iter(getattr(d, _itervalues)())
 
 def iteritems(d):
     """Return an iterator over the (key, value) pairs of a dictionary."""
-    return getattr(d, _iteritems)()
+    return iter(getattr(d, _iteritems)())
 
 
 if PY3:
index 8f5b54c..a124202 100644 (file)
@@ -30,8 +30,12 @@ class PoolManager(RequestMethods):
     necessary connection pools for you.
 
     :param num_pools:
-        Number of connection pools to cache before discarding the least recently
-        used pool.
+        Number of connection pools to cache before discarding the least
+        recently used pool.
+
+    :param headers:
+        Headers to include with all requests, unless other headers are given
+        explicitly.
 
     :param \**connection_pool_kw:
         Additional parameters are used to create fresh
@@ -40,15 +44,16 @@ class PoolManager(RequestMethods):
     Example: ::
 
         >>> manager = PoolManager(num_pools=2)
-        >>> r = manager.urlopen("http://google.com/")
-        >>> r = manager.urlopen("http://google.com/mail")
-        >>> r = manager.urlopen("http://yahoo.com/")
+        >>> r = manager.request('GET', 'http://google.com/')
+        >>> r = manager.request('GET', 'http://google.com/mail')
+        >>> r = manager.request('GET', 'http://yahoo.com/')
         >>> len(manager.pools)
         2
 
     """
 
-    def __init__(self, num_pools=10, **connection_pool_kw):
+    def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
+        RequestMethods.__init__(self, headers)
         self.connection_pool_kw = connection_pool_kw
         self.pools = RecentlyUsedContainer(num_pools,
                                            dispose_func=lambda p: p.close())
@@ -113,6 +118,8 @@ class PoolManager(RequestMethods):
 
         kw['assert_same_host'] = False
         kw['redirect'] = False
+        if 'headers' not in kw:
+            kw['headers'] = self.headers
 
         response = conn.urlopen(method, u.request_uri, **kw)
 
@@ -124,7 +131,7 @@ class PoolManager(RequestMethods):
             method = 'GET'
 
         log.info("Redirecting %s -> %s" % (url, redirect_location))
-        kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown
+        kw['retries'] = kw.get('retries', 3) - 1  # Persist retries countdown
         return self.urlopen(method, redirect_location, **kw)
 
 
@@ -138,13 +145,11 @@ class ProxyManager(RequestMethods):
         self.proxy_pool = proxy_pool
 
     def _set_proxy_headers(self, headers=None):
-        headers = headers or {}
-
-        # Same headers are curl passes for --proxy1.0
-        headers['Accept'] = '*/*'
-        headers['Proxy-Connection'] = 'Keep-Alive'
+        headers_ = {'Accept': '*/*'}
+        if headers:
+            headers_.update(headers)
 
-        return headers
+        return headers_
 
     def urlopen(self, method, url, **kw):
         "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
index 569ac96..2b4704e 100644 (file)
@@ -36,12 +36,20 @@ class RequestMethods(object):
     :meth:`.request` is for making any kind of request, it will look up the
     appropriate encoding format and use one of the above two methods to make
     the request.
+
+    Initializer parameters:
+
+    :param headers:
+        Headers to include with all requests, unless other headers are given
+        explicitly.
     """
 
     _encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS'])
-
     _encode_body_methods = set(['PATCH', 'POST', 'PUT', 'TRACE'])
 
+    def __init__(self, headers=None):
+        self.headers = headers or {}
+
     def urlopen(self, method, url, body=None, headers=None,
                 encode_multipart=True, multipart_boundary=None,
                 **kw): # Abstract
@@ -97,13 +105,16 @@ class RequestMethods(object):
         such as with OAuth.
 
         Supports an optional ``fields`` parameter of key/value strings AND
-        key/filetuple. A filetuple is a (filename, data) tuple. For example: ::
+        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
+        the MIME type is optional. For example: ::
 
             fields = {
                 'foo': 'bar',
                 'fakefile': ('foofile.txt', 'contents of foofile'),
                 'realfile': ('barfile.txt', open('realfile').read()),
-                'nonamefile': ('contents of nonamefile field'),
+                'typedfile': ('bazfile.bin', open('bazfile').read(),
+                              'image/jpeg'),
+                'nonamefile': 'contents of nonamefile field',
             }
 
         When uploading a file, providing a filename (the first parameter of the
@@ -121,8 +132,11 @@ class RequestMethods(object):
             body, content_type = (urlencode(fields or {}),
                                     'application/x-www-form-urlencoded')
 
-        headers = headers or {}
-        headers.update({'Content-Type': content_type})
+        if headers is None:
+            headers = self.headers
+
+        headers_ = {'Content-Type': content_type}
+        headers_.update(headers)
 
-        return self.urlopen(method, url, body=body, headers=headers,
+        return self.urlopen(method, url, body=body, headers=headers_,
                             **urlopen_kw)
index 28537d3..833be62 100644 (file)
@@ -130,7 +130,9 @@ class HTTPResponse(object):
             after having ``.read()`` the file object. (Overridden if ``amt`` is
             set.)
         """
-        content_encoding = self.headers.get('content-encoding')
+        # Note: content-encoding value should be case-insensitive, per RFC 2616
+        # Section 3.5
+        content_encoding = self.headers.get('content-encoding', '').lower()
         decoder = self.CONTENT_DECODERS.get(content_encoding)
         if decode_content is None:
             decode_content = self._decode_content
index 8ec990b..8d8654f 100644 (file)
@@ -11,13 +11,24 @@ from socket import error as SocketError
 
 try:
     from select import poll, POLLIN
-except ImportError: # `poll` doesn't exist on OSX and other platforms
+except ImportError:  # `poll` doesn't exist on OSX and other platforms
     poll = False
     try:
         from select import select
-    except ImportError: # `select` doesn't exist on AppEngine.
+    except ImportError:  # `select` doesn't exist on AppEngine.
         select = False
 
+try:  # Test for SSL features
+    SSLContext = None
+    HAS_SNI = False
+
+    from ssl import wrap_socket, CERT_NONE, SSLError, PROTOCOL_SSLv23
+    from ssl import SSLContext  # Modern SSL?
+    from ssl import HAS_SNI  # Has SNI?
+except ImportError:
+    pass
+
+
 from .packages import six
 from .exceptions import LocationParseError
 
@@ -92,9 +103,9 @@ def parse_url(url):
 
         >>> parse_url('http://google.com/mail/')
         Url(scheme='http', host='google.com', port=None, path='/', ...)
-        >>> prase_url('google.com:80')
+        >>> parse_url('google.com:80')
         Url(scheme=None, host='google.com', port=80, path=None, ...)
-        >>> prase_url('/foo?bar')
+        >>> parse_url('/foo?bar')
         Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
     """
 
@@ -250,3 +261,38 @@ def is_connection_dropped(conn):
         if fno == sock.fileno():
             # Either data is buffered (bad), or the connection is dropped.
             return True
+
+
+if SSLContext is not None:  # Python 3.2+
+    def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE,
+                        ca_certs=None, server_hostname=None,
+                        ssl_version=PROTOCOL_SSLv23):
+        """
+        All arguments except `server_hostname` have the same meaning as for
+        :func:`ssl.wrap_socket`
+
+        :param server_hostname:
+            Hostname of the expected certificate
+        """
+        context = SSLContext(ssl_version)
+        context.verify_mode = cert_reqs
+        if ca_certs:
+            try:
+                context.load_verify_locations(ca_certs)
+            except TypeError as e:  # Reraise as SSLError
+                # FIXME: This block needs a test.
+                raise SSLError(e)
+        if certfile:
+            # FIXME: This block needs a test.
+            context.load_cert_chain(certfile, keyfile)
+        if HAS_SNI:  # Platform-specific: OpenSSL with enabled SNI
+            return context.wrap_socket(sock, server_hostname=server_hostname)
+        return context.wrap_socket(sock)
+
+else:  # Python 3.1 and earlier
+    def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE,
+                        ca_certs=None, server_hostname=None,
+                        ssl_version=PROTOCOL_SSLv23):
+        return wrap_socket(sock, keyfile=keyfile, certfile=certfile,
+                           ca_certs=ca_certs, cert_reqs=cert_reqs,
+                           ssl_version=ssl_version)