Upgrade urllib3 to 490d3a227fadb626cd54a240b9d0922f849914b4
authorJoe Smith <jsmith@twitter.com>
Wed, 11 Feb 2015 18:22:10 +0000 (10:22 -0800)
committerJoe Smith <jsmith@twitter.com>
Wed, 11 Feb 2015 19:04:25 +0000 (11:04 -0800)
requests/packages/urllib3/_collections.py
requests/packages/urllib3/connectionpool.py
requests/packages/urllib3/poolmanager.py
requests/packages/urllib3/response.py
requests/packages/urllib3/util/ssl_.py

index 784342a4eb5939ef7682c581dad219d3dee67316..92d07a4202f4102d00b0ff2afc4c27a8c85e1365 100644 (file)
@@ -1,7 +1,7 @@
 from collections import Mapping, MutableMapping
 try:
     from threading import RLock
-except ImportError: # Platform-specific: No threads available
+except ImportError:  # Platform-specific: No threads available
     class RLock:
         def __enter__(self):
             pass
@@ -10,7 +10,7 @@ except ImportError: # Platform-specific: No threads available
             pass
 
 
-try: # Python 2.7+
+try:  # Python 2.7+
     from collections import OrderedDict
 except ImportError:
     from .packages.ordered_dict import OrderedDict
@@ -161,7 +161,7 @@ class HTTPHeaderDict(MutableMapping):
     def getlist(self, key):
         """Returns a list of all the values for the named field. Returns an
         empty list if the key doesn't exist."""
-        return self[key].split(', ') if key in self else []
+        return [v for k, v in self._data.get(key.lower(), [])]
 
     def copy(self):
         h = HTTPHeaderDict()
@@ -196,3 +196,11 @@ class HTTPHeaderDict(MutableMapping):
 
     def __repr__(self):
         return '%s(%r)' % (self.__class__.__name__, dict(self.items()))
+
+    def update(self, *args, **kwds):
+        headers = args[0]
+        if isinstance(headers, HTTPHeaderDict):
+            for key in iterkeys(headers._data):
+                self._data.setdefault(key.lower(), []).extend(headers._data[key])
+        else:
+            super(HTTPHeaderDict, self).update(*args, **kwds)
index 8bdf228ffe1fcc308e7f9c3a98cac7bc19b17f3f..ce1784487ddbebfe97277b2f6719e5cc740a914f 100644 (file)
@@ -72,6 +72,21 @@ class ConnectionPool(object):
         return '%s(host=%r, port=%r)' % (type(self).__name__,
                                          self.host, self.port)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        # Return False to re-raise any potential exceptions
+        return False
+
+    def close():
+        """
+        Close all pooled connections and disable the pool.
+        """
+        pass
+
+
 # This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252
 _blocking_errnos = set([errno.EAGAIN, errno.EWOULDBLOCK])
 
@@ -353,7 +368,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
 
         # Receive the response from the server
         try:
-            try:  # Python 2.7+, use buffering of HTTP responses
+            try:  # Python 2.7, use buffering of HTTP responses
                 httplib_response = conn.getresponse(buffering=True)
             except TypeError:  # Python 2.6 and older
                 httplib_response = conn.getresponse()
@@ -558,6 +573,14 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
                 conn = None
             raise SSLError(e)
 
+        except SSLError:
+            # Treat SSLError separately from BaseSSLError to preserve
+            # traceback.
+            if conn:
+                conn.close()
+                conn = None
+            raise
+
         except (TimeoutError, HTTPException, SocketError, ConnectionError) as e:
             if conn:
                 # Discard the connection for these exceptions. It will be
index 515dc96219b187c20272456f1b9f1d5f325d021f..9a701e4561299cb3cbc8133acc1c263582b39003 100644 (file)
@@ -64,6 +64,14 @@ class PoolManager(RequestMethods):
         self.pools = RecentlyUsedContainer(num_pools,
                                            dispose_func=lambda p: p.close())
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.clear()
+        # Return False to re-raise any potential exceptions
+        return False
+
     def _new_pool(self, scheme, host, port):
         """
         Create a new :class:`ConnectionPool` based on host, port and scheme.
index e69de95733ad2cdcc2fc5fb019116e09f0ec4a86..709706857951ca09fd68f75a6f5009bb93ac2812 100644 (file)
@@ -2,6 +2,7 @@ import zlib
 import io
 from socket import timeout as SocketTimeout
 
+from .packages import six
 from ._collections import HTTPHeaderDict
 from .exceptions import ProtocolError, DecodeError, ReadTimeoutError
 from .packages.six import string_types as basestring, binary_type
@@ -9,7 +10,6 @@ from .connection import HTTPException, BaseSSLError
 from .util.response import is_fp_closed
 
 
-
 class DeflateDecoder(object):
 
     def __init__(self):
@@ -21,6 +21,9 @@ class DeflateDecoder(object):
         return getattr(self._obj, name)
 
     def decompress(self, data):
+        if not data:
+            return data
+
         if not self._first_try:
             return self._obj.decompress(data)
 
@@ -36,9 +39,23 @@ class DeflateDecoder(object):
                 self._data = None
 
 
+class GzipDecoder(object):
+
+    def __init__(self):
+        self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
+
+    def __getattr__(self, name):
+        return getattr(self._obj, name)
+
+    def decompress(self, data):
+        if not data:
+            return data
+        return self._obj.decompress(data)
+
+
 def _get_decoder(mode):
     if mode == 'gzip':
-        return zlib.decompressobj(16 + zlib.MAX_WBITS)
+        return GzipDecoder()
 
     return DeflateDecoder()
 
@@ -202,7 +219,7 @@ class HTTPResponse(io.IOBase):
 
             except BaseSSLError as e:
                 # FIXME: Is there a better way to differentiate between SSLErrors?
-                if not 'read operation timed out' in str(e):  # Defensive:
+                if 'read operation timed out' not in str(e):  # Defensive:
                     # This shouldn't happen but just in case we're missing an edge
                     # case, let's avoid swallowing SSL errors.
                     raise
@@ -270,7 +287,16 @@ class HTTPResponse(io.IOBase):
 
         headers = HTTPHeaderDict()
         for k, v in r.getheaders():
-            headers.add(k, v)
+            if k.lower() != 'set-cookie':
+                headers.add(k, v)
+
+        if six.PY3:  # Python 3:
+            cookies = r.msg.get_all('set-cookie') or tuple()
+        else:  # Python 2:
+            cookies = r.msg.getheaders('set-cookie')
+
+        for cookie in cookies:
+            headers.add('set-cookie', cookie)
 
         # HTTPResponse objects in Python 3 don't have a .strict attribute
         strict = getattr(r, 'strict', 0)
index a788b1b98c63150fe70021d96d2b8e1d45579019..7ad1b30589d52e78585907379f8723960538e817 100644 (file)
@@ -1,5 +1,5 @@
 from binascii import hexlify, unhexlify
-from hashlib import md5, sha1
+from hashlib import md5, sha1, sha256
 
 from ..exceptions import SSLError
 
@@ -96,7 +96,8 @@ def assert_fingerprint(cert, fingerprint):
     # this digest.
     hashfunc_map = {
         16: md5,
-        20: sha1
+        20: sha1,
+        32: sha256,
     }
 
     fingerprint = fingerprint.replace(':', '').lower()
@@ -211,7 +212,9 @@ def create_urllib3_context(ssl_version=None, cert_reqs=ssl.CERT_REQUIRED,
 
     context.verify_mode = cert_reqs
     if getattr(context, 'check_hostname', None) is not None:  # Platform-specific: Python 3.2
-        context.check_hostname = (context.verify_mode == ssl.CERT_REQUIRED)
+        # We do our own verification, including fingerprints and alternative
+        # hostnames. So disable it here
+        context.check_hostname = False
     return context