From 2c5b0207f5d37d14c54fcc690e2f2bc7b64618f2 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Mon, 17 Dec 2012 04:08:36 -0500 Subject: [PATCH] new urllib3 --- requests/packages/urllib3/connectionpool.py | 55 ++++++++++++++------- requests/packages/urllib3/exceptions.py | 16 ++++-- requests/packages/urllib3/filepost.py | 23 ++++++--- requests/packages/urllib3/packages/six.py | 53 ++++++++++++-------- requests/packages/urllib3/poolmanager.py | 31 +++++++----- requests/packages/urllib3/request.py | 26 +++++++--- requests/packages/urllib3/response.py | 4 +- requests/packages/urllib3/util.py | 54 ++++++++++++++++++-- 8 files changed, 190 insertions(+), 72 deletions(-) diff --git a/requests/packages/urllib3/connectionpool.py b/requests/packages/urllib3/connectionpool.py index 97da544..af8760d 100644 --- a/requests/packages/urllib3/connectionpool.py +++ b/requests/packages/urllib3/connectionpool.py @@ -6,8 +6,9 @@ import logging import socket +import errno -from socket import timeout as SocketTimeout +from socket import error as SocketError, timeout as SocketTimeout try: # Python 3 from http.client import HTTPConnection, HTTPException @@ -41,7 +42,7 @@ except (ImportError, AttributeError): # Platform-specific: No SSL. from .request import RequestMethods from .response import HTTPResponse -from .util import get_host, is_connection_dropped +from .util import get_host, is_connection_dropped, ssl_wrap_socket from .exceptions import ( ClosedPoolError, EmptyPoolError, @@ -76,6 +77,7 @@ class VerifiedHTTPSConnection(HTTPSConnection): """ cert_reqs = None ca_certs = None + ssl_version = None def set_cert(self, key_file=None, cert_file=None, cert_reqs='CERT_NONE', ca_certs=None): @@ -96,9 +98,12 @@ class VerifiedHTTPSConnection(HTTPSConnection): # Wrap socket using verification with the root certs in # trusted_root_certs - self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, + self.sock = ssl_wrap_socket(sock, self.key_file, self.cert_file, cert_reqs=self.cert_reqs, - ca_certs=self.ca_certs) + ca_certs=self.ca_certs, + server_hostname=self.host, + ssl_version=self.ssl_version) + if self.ca_certs: match_hostname(self.sock.getpeercert(), self.host) @@ -166,13 +171,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1, block=False, headers=None): - super(HTTPConnectionPool, self).__init__(host, port) + ConnectionPool.__init__(self, host, port) + RequestMethods.__init__(self, headers) self.strict = strict self.timeout = timeout self.pool = self.QueueCls(maxsize) self.block = block - self.headers = headers or {} # Fill the queue up so that doing get() on it will block properly for _ in xrange(maxsize): @@ -189,7 +194,9 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): self.num_connections += 1 log.info("Starting new HTTP connection (%d): %s" % (self.num_connections, self.host)) - return HTTPConnection(host=self.host, port=self.port) + return HTTPConnection(host=self.host, + port=self.port, + strict=self.strict) def _get_conn(self, timeout=None): """ @@ -449,12 +456,15 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): # Name mismatch raise SSLError(e) - except HTTPException as e: + except (HTTPException, SocketError) as e: # Connection broken, discard. It will be replaced next _get_conn(). conn = None # This is necessary so we can access e below err = e + if retries == 0: + raise MaxRetryError(self, url, e) + finally: if release_conn: # Put the connection back to be reused. If the connection is @@ -491,11 +501,11 @@ class HTTPSConnectionPool(HTTPConnectionPool): When Python is compiled with the :mod:`ssl` module, then :class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates, - instead of :class:httplib.HTTPSConnection`. + instead of :class:`httplib.HTTPSConnection`. - The ``key_file``, ``cert_file``, ``cert_reqs``, and ``ca_certs`` parameters + The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, and ``ssl_version`` are only used if :mod:`ssl` is available and are fed into - :meth:`ssl.wrap_socket` to upgrade the connection socket into an SSL socket. + :meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket into an SSL socket. """ scheme = 'https' @@ -504,15 +514,16 @@ class HTTPSConnectionPool(HTTPConnectionPool): strict=False, timeout=None, maxsize=1, block=False, headers=None, key_file=None, cert_file=None, - cert_reqs='CERT_NONE', ca_certs=None): + cert_reqs='CERT_NONE', ca_certs=None, ssl_version=None): - super(HTTPSConnectionPool, self).__init__(host, port, - strict, timeout, maxsize, - block, headers) + HTTPConnectionPool.__init__(self, host, port, + strict, timeout, maxsize, + block, headers) self.key_file = key_file self.cert_file = cert_file self.cert_reqs = cert_reqs self.ca_certs = ca_certs + self.ssl_version = ssl_version def _new_conn(self): """ @@ -527,11 +538,21 @@ class HTTPSConnectionPool(HTTPConnectionPool): raise SSLError("Can't connect to HTTPS URL because the SSL " "module is not available.") - return HTTPSConnection(host=self.host, port=self.port) + return HTTPSConnection(host=self.host, + port=self.port, + strict=self.strict) - connection = VerifiedHTTPSConnection(host=self.host, port=self.port) + connection = VerifiedHTTPSConnection(host=self.host, + port=self.port, + strict=self.strict) connection.set_cert(key_file=self.key_file, cert_file=self.cert_file, cert_reqs=self.cert_reqs, ca_certs=self.ca_certs) + + if self.ssl_version is None: + connection.ssl_version = ssl.PROTOCOL_SSLv23 + else: + connection.ssl_version = self.ssl_version + return connection diff --git a/requests/packages/urllib3/exceptions.py b/requests/packages/urllib3/exceptions.py index 99ebb67..c5eb962 100644 --- a/requests/packages/urllib3/exceptions.py +++ b/requests/packages/urllib3/exceptions.py @@ -18,6 +18,10 @@ class PoolError(HTTPError): self.pool = pool HTTPError.__init__(self, "%s: %s" % (pool, message)) + def __reduce__(self): + # For pickling purposes. + return self.__class__, (None, self.url) + class SSLError(HTTPError): "Raised when SSL certificate fails in an HTTPS connection." @@ -34,10 +38,16 @@ class DecodeError(HTTPError): class MaxRetryError(PoolError): "Raised when the maximum number of retries is exceeded." - def __init__(self, pool, url): + def __init__(self, pool, url, reason=None): + self.reason = reason + message = "Max retries exceeded with url: %s" % url - PoolError.__init__(self, pool, message) + if reason: + message += " (Caused by %s: %s)" % (type(reason), reason) + else: + message += " (Caused by redirect)" + PoolError.__init__(self, pool, message) self.url = url @@ -72,6 +82,6 @@ class LocationParseError(ValueError, HTTPError): def __init__(self, location): message = "Failed to parse: %s" % location - super(LocationParseError, self).__init__(self, message) + HTTPError.__init__(self, message) self.location = location diff --git a/requests/packages/urllib3/filepost.py b/requests/packages/urllib3/filepost.py index e679b93..8d900bd 100644 --- a/requests/packages/urllib3/filepost.py +++ b/requests/packages/urllib3/filepost.py @@ -41,13 +41,16 @@ def iter_fields(fields): def encode_multipart_formdata(fields, boundary=None): """ - Encode a dictionary of ``fields`` using the multipart/form-data mime format. + Encode a dictionary of ``fields`` using the multipart/form-data MIME format. :param fields: - Dictionary of fields or list of (key, value) field tuples. The key is - treated as the field name, and the value as the body of the form-data - bytes. If the value is a tuple of two elements, then the first element - is treated as the filename of the form-data section. + Dictionary of fields or list of (key, value) or (key, value, MIME type) + field tuples. The key is treated as the field name, and the value as + the body of the form-data bytes. If the value is a tuple of two + elements, then the first element is treated as the filename of the + form-data section and a suitable MIME type is guessed based on the + filename. If the value is a tuple of three elements, then the third + element is treated as an explicit MIME type of the form-data section. Field names and filenames must be unicode. @@ -63,16 +66,20 @@ def encode_multipart_formdata(fields, boundary=None): body.write(b('--%s\r\n' % (boundary))) if isinstance(value, tuple): - filename, data = value + if len(value) == 3: + filename, data, content_type = value + else: + filename, data = value + content_type = get_content_type(filename) writer(body).write('Content-Disposition: form-data; name="%s"; ' 'filename="%s"\r\n' % (fieldname, filename)) body.write(b('Content-Type: %s\r\n\r\n' % - (get_content_type(filename)))) + (content_type,))) else: data = value writer(body).write('Content-Disposition: form-data; name="%s"\r\n' % (fieldname)) - body.write(b'Content-Type: text/plain\r\n\r\n') + body.write(b'\r\n') if isinstance(data, int): data = str(data) # Backwards compatibility diff --git a/requests/packages/urllib3/packages/six.py b/requests/packages/urllib3/packages/six.py index a64f6fb..27d8011 100644 --- a/requests/packages/urllib3/packages/six.py +++ b/requests/packages/urllib3/packages/six.py @@ -24,7 +24,7 @@ import sys import types __author__ = "Benjamin Peterson " -__version__ = "1.1.0" +__version__ = "1.2.0" # Revision 41c74fef2ded # True if we are running on Python 3. @@ -45,19 +45,23 @@ else: text_type = unicode binary_type = str - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). - class X(object): - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit + if sys.platform.startswith("java"): + # Jython always uses 32 bits. MAXSIZE = int((1 << 31) - 1) else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - del X + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X def _add_doc(func, doc): @@ -132,6 +136,7 @@ class _MovedItems(types.ModuleType): _moved_attributes = [ MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), MovedAttribute("map", "itertools", "builtins", "imap", "map"), MovedAttribute("reload_module", "__builtin__", "imp", "reload"), MovedAttribute("reduce", "__builtin__", "functools"), @@ -178,7 +183,7 @@ for attr in _moved_attributes: setattr(_MovedItems, attr.name, attr) del attr -moves = sys.modules["six.moves"] = _MovedItems("moves") +moves = sys.modules[__name__ + ".moves"] = _MovedItems("moves") def add_move(move): @@ -219,12 +224,19 @@ else: _iteritems = "iteritems" +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + if PY3: def get_unbound_function(unbound): return unbound - - advance_iterator = next + Iterator = object def callable(obj): return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) @@ -232,9 +244,10 @@ else: def get_unbound_function(unbound): return unbound.im_func + class Iterator(object): - def advance_iterator(it): - return it.next() + def next(self): + return type(self).__next__(self) callable = callable _add_doc(get_unbound_function, @@ -249,15 +262,15 @@ get_function_defaults = operator.attrgetter(_func_defaults) def iterkeys(d): """Return an iterator over the keys of a dictionary.""" - return getattr(d, _iterkeys)() + return iter(getattr(d, _iterkeys)()) def itervalues(d): """Return an iterator over the values of a dictionary.""" - return getattr(d, _itervalues)() + return iter(getattr(d, _itervalues)()) def iteritems(d): """Return an iterator over the (key, value) pairs of a dictionary.""" - return getattr(d, _iteritems)() + return iter(getattr(d, _iteritems)()) if PY3: diff --git a/requests/packages/urllib3/poolmanager.py b/requests/packages/urllib3/poolmanager.py index 8f5b54c..a124202 100644 --- a/requests/packages/urllib3/poolmanager.py +++ b/requests/packages/urllib3/poolmanager.py @@ -30,8 +30,12 @@ class PoolManager(RequestMethods): necessary connection pools for you. :param num_pools: - Number of connection pools to cache before discarding the least recently - used pool. + Number of connection pools to cache before discarding the least + recently used pool. + + :param headers: + Headers to include with all requests, unless other headers are given + explicitly. :param \**connection_pool_kw: Additional parameters are used to create fresh @@ -40,15 +44,16 @@ class PoolManager(RequestMethods): Example: :: >>> manager = PoolManager(num_pools=2) - >>> r = manager.urlopen("http://google.com/") - >>> r = manager.urlopen("http://google.com/mail") - >>> r = manager.urlopen("http://yahoo.com/") + >>> r = manager.request('GET', 'http://google.com/') + >>> r = manager.request('GET', 'http://google.com/mail') + >>> r = manager.request('GET', 'http://yahoo.com/') >>> len(manager.pools) 2 """ - def __init__(self, num_pools=10, **connection_pool_kw): + def __init__(self, num_pools=10, headers=None, **connection_pool_kw): + RequestMethods.__init__(self, headers) self.connection_pool_kw = connection_pool_kw self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close()) @@ -113,6 +118,8 @@ class PoolManager(RequestMethods): kw['assert_same_host'] = False kw['redirect'] = False + if 'headers' not in kw: + kw['headers'] = self.headers response = conn.urlopen(method, u.request_uri, **kw) @@ -124,7 +131,7 @@ class PoolManager(RequestMethods): method = 'GET' log.info("Redirecting %s -> %s" % (url, redirect_location)) - kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown + kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown return self.urlopen(method, redirect_location, **kw) @@ -138,13 +145,11 @@ class ProxyManager(RequestMethods): self.proxy_pool = proxy_pool def _set_proxy_headers(self, headers=None): - headers = headers or {} - - # Same headers are curl passes for --proxy1.0 - headers['Accept'] = '*/*' - headers['Proxy-Connection'] = 'Keep-Alive' + headers_ = {'Accept': '*/*'} + if headers: + headers_.update(headers) - return headers + return headers_ def urlopen(self, method, url, **kw): "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute." diff --git a/requests/packages/urllib3/request.py b/requests/packages/urllib3/request.py index 569ac96..2b4704e 100644 --- a/requests/packages/urllib3/request.py +++ b/requests/packages/urllib3/request.py @@ -36,12 +36,20 @@ class RequestMethods(object): :meth:`.request` is for making any kind of request, it will look up the appropriate encoding format and use one of the above two methods to make the request. + + Initializer parameters: + + :param headers: + Headers to include with all requests, unless other headers are given + explicitly. """ _encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS']) - _encode_body_methods = set(['PATCH', 'POST', 'PUT', 'TRACE']) + def __init__(self, headers=None): + self.headers = headers or {} + def urlopen(self, method, url, body=None, headers=None, encode_multipart=True, multipart_boundary=None, **kw): # Abstract @@ -97,13 +105,16 @@ class RequestMethods(object): such as with OAuth. Supports an optional ``fields`` parameter of key/value strings AND - key/filetuple. A filetuple is a (filename, data) tuple. For example: :: + key/filetuple. A filetuple is a (filename, data, MIME type) tuple where + the MIME type is optional. For example: :: fields = { 'foo': 'bar', 'fakefile': ('foofile.txt', 'contents of foofile'), 'realfile': ('barfile.txt', open('realfile').read()), - 'nonamefile': ('contents of nonamefile field'), + 'typedfile': ('bazfile.bin', open('bazfile').read(), + 'image/jpeg'), + 'nonamefile': 'contents of nonamefile field', } When uploading a file, providing a filename (the first parameter of the @@ -121,8 +132,11 @@ class RequestMethods(object): body, content_type = (urlencode(fields or {}), 'application/x-www-form-urlencoded') - headers = headers or {} - headers.update({'Content-Type': content_type}) + if headers is None: + headers = self.headers + + headers_ = {'Content-Type': content_type} + headers_.update(headers) - return self.urlopen(method, url, body=body, headers=headers, + return self.urlopen(method, url, body=body, headers=headers_, **urlopen_kw) diff --git a/requests/packages/urllib3/response.py b/requests/packages/urllib3/response.py index 28537d3..833be62 100644 --- a/requests/packages/urllib3/response.py +++ b/requests/packages/urllib3/response.py @@ -130,7 +130,9 @@ class HTTPResponse(object): after having ``.read()`` the file object. (Overridden if ``amt`` is set.) """ - content_encoding = self.headers.get('content-encoding') + # Note: content-encoding value should be case-insensitive, per RFC 2616 + # Section 3.5 + content_encoding = self.headers.get('content-encoding', '').lower() decoder = self.CONTENT_DECODERS.get(content_encoding) if decode_content is None: decode_content = self._decode_content diff --git a/requests/packages/urllib3/util.py b/requests/packages/urllib3/util.py index 8ec990b..8d8654f 100644 --- a/requests/packages/urllib3/util.py +++ b/requests/packages/urllib3/util.py @@ -11,13 +11,24 @@ from socket import error as SocketError try: from select import poll, POLLIN -except ImportError: # `poll` doesn't exist on OSX and other platforms +except ImportError: # `poll` doesn't exist on OSX and other platforms poll = False try: from select import select - except ImportError: # `select` doesn't exist on AppEngine. + except ImportError: # `select` doesn't exist on AppEngine. select = False +try: # Test for SSL features + SSLContext = None + HAS_SNI = False + + from ssl import wrap_socket, CERT_NONE, SSLError, PROTOCOL_SSLv23 + from ssl import SSLContext # Modern SSL? + from ssl import HAS_SNI # Has SNI? +except ImportError: + pass + + from .packages import six from .exceptions import LocationParseError @@ -92,9 +103,9 @@ def parse_url(url): >>> parse_url('http://google.com/mail/') Url(scheme='http', host='google.com', port=None, path='/', ...) - >>> prase_url('google.com:80') + >>> parse_url('google.com:80') Url(scheme=None, host='google.com', port=80, path=None, ...) - >>> prase_url('/foo?bar') + >>> parse_url('/foo?bar') Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) """ @@ -250,3 +261,38 @@ def is_connection_dropped(conn): if fno == sock.fileno(): # Either data is buffered (bad), or the connection is dropped. return True + + +if SSLContext is not None: # Python 3.2+ + def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE, + ca_certs=None, server_hostname=None, + ssl_version=PROTOCOL_SSLv23): + """ + All arguments except `server_hostname` have the same meaning as for + :func:`ssl.wrap_socket` + + :param server_hostname: + Hostname of the expected certificate + """ + context = SSLContext(ssl_version) + context.verify_mode = cert_reqs + if ca_certs: + try: + context.load_verify_locations(ca_certs) + except TypeError as e: # Reraise as SSLError + # FIXME: This block needs a test. + raise SSLError(e) + if certfile: + # FIXME: This block needs a test. + context.load_cert_chain(certfile, keyfile) + if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI + return context.wrap_socket(sock, server_hostname=server_hostname) + return context.wrap_socket(sock) + +else: # Python 3.1 and earlier + def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=CERT_NONE, + ca_certs=None, server_hostname=None, + ssl_version=PROTOCOL_SSLv23): + return wrap_socket(sock, keyfile=keyfile, certfile=certfile, + ca_certs=ca_certs, cert_reqs=cert_reqs, + ssl_version=ssl_version) -- 2.34.1