From f7596c75dce4e87ab83bdf74e8f120a4b1a5ff03 Mon Sep 17 00:00:00 2001 From: Colin Dunklau Date: Tue, 30 Apr 2013 14:52:27 -0500 Subject: [PATCH] Rewrite CaseInsensitiveDict to work correctly/sanely Fixes #649 and #1329 by making Session.headers a CaseInsensitiveDict, and fixing the implementation of CID. Credit for the brilliant idea to map `lowercased_key -> (cased_key, mapped_value)` goes to @gazpachoking, thanks a bunch. Changes from original implementation of CaseInsensitiveDict: 1. CID is rewritten as a subclass of `collections.MutableMapping`. 2. CID remembers the case of the last-set key, but `__setitem__` and `__delitem__` will handle keys without respect to case. 3. CID returns the key case as remembered for the `keys`, `items`, and `__iter__` methods. 4. Query operations (`__getitem__` and `__contains__`) are done in a case-insensitive manner: `cid['foo']` and `cid['FOO']` will return the same value. 5. The constructor as well as `update` and `__eq__` have undefined behavior when given multiple keys that have the same `lower()`. 6. The new method `lower_items` is like `iteritems`, but keys are all lowercased. 7. CID raises `KeyError` for `__getitem__` as normal dicts do. The old implementation returned 6. The `__repr__` now makes it obvious that it's not a normal dict. See PR #1333 for the discussions that lead up to this implementation --- AUTHORS.rst | 1 + requests/structures.py | 89 +++++++++++++++++++-------- requests/utils.py | 5 +- test_requests.py | 160 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 227 insertions(+), 28 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 2010cae..2fae296 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -124,3 +124,4 @@ Patches and Suggestions - Wilfred Hughes @dontYetKnow - Dmitry Medvinsky - Bryce Boe @bboe +- Colin Dunklau @cdunklau diff --git a/requests/structures.py b/requests/structures.py index 05f5ac1..8d02ea6 100644 --- a/requests/structures.py +++ b/requests/structures.py @@ -9,6 +9,7 @@ Data structures that power Requests. """ import os +import collections from itertools import islice @@ -33,43 +34,79 @@ class IteratorProxy(object): return "".join(islice(self.i, None, n)) -class CaseInsensitiveDict(dict): - """Case-insensitive Dictionary +class CaseInsensitiveDict(collections.MutableMapping): + """ + A case-insensitive ``dict``-like object. + + Implements all methods and operations of + ``collections.MutableMapping`` as well as dict's ``copy``. Also + provides ``lower_items``. + + All keys are expected to be strings. The structure remembers the + case of the last key to be set, and ``iter(instance)``, + ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` + will contain case-sensitive keys. However, querying and contains + testing is case insensitive: + + cid = CaseInsensitiveDict() + cid['Accept'] = 'application/json' + cid['aCCEPT'] == 'application/json' # True + list(cid) == ['Accept'] # True For example, ``headers['content-encoding']`` will return the - value of a ``'Content-Encoding'`` response header.""" + value of a ``'Content-Encoding'`` response header, regardless + of how the header name was originally stored. - @property - def lower_keys(self): - if not hasattr(self, '_lower_keys') or not self._lower_keys: - self._lower_keys = dict((k.lower(), k) for k in list(self.keys())) - return self._lower_keys + If the constructor, ``.update``, or equality comparison + operations are given keys that have equal ``.lower()``s, the + behavior is undefined. - def _clear_lower_keys(self): - if hasattr(self, '_lower_keys'): - self._lower_keys.clear() + """ + def __init__(self, data=None, **kwargs): + self._store = dict() + if data is None: + data = {} + self.update(data, **kwargs) def __setitem__(self, key, value): - dict.__setitem__(self, key, value) - self._clear_lower_keys() + # Use the lowercased key for lookups, but store the actual + # key alongside the value. + self._store[key.lower()] = (key, value) - def __delitem__(self, key): - dict.__delitem__(self, self.lower_keys.get(key.lower(), key)) - self._lower_keys.clear() + def __getitem__(self, key): + return self._store[key.lower()][1] - def __contains__(self, key): - return key.lower() in self.lower_keys + def __delitem__(self, key): + del self._store[key.lower()] - def __getitem__(self, key): - # We allow fall-through here, so values default to None - if key in self: - return dict.__getitem__(self, self.lower_keys[key.lower()]) + def __iter__(self): + return (casedkey for casedkey, mappedvalue in self._store.values()) - def get(self, key, default=None): - if key in self: - return self[key] + def __len__(self): + return len(self._store) + + def lower_items(self): + """Like iteritems(), but with all lowercase keys.""" + return ( + (lowerkey, keyval[1]) + for (lowerkey, keyval) + in self._store.items() + ) + + def __eq__(self, other): + if isinstance(other, collections.Mapping): + other = CaseInsensitiveDict(other) else: - return default + return NotImplemented + # Compare insensitively + return dict(self.lower_items()) == dict(other.lower_items()) + + # Copy is required + def copy(self): + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, dict(self.items())) class LookupDict(dict): diff --git a/requests/utils.py b/requests/utils.py index 68f3e62..d690559 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -23,6 +23,7 @@ from . import certs from .compat import parse_http_list as _parse_list_header from .compat import quote, urlparse, bytes, str, OrderedDict, urlunparse from .cookies import RequestsCookieJar, cookiejar_from_dict +from .structures import CaseInsensitiveDict _hush_pyflakes = (RequestsCookieJar,) @@ -449,11 +450,11 @@ def default_user_agent(): def default_headers(): - return { + return CaseInsensitiveDict({ 'User-Agent': default_user_agent(), 'Accept-Encoding': ', '.join(('gzip', 'deflate', 'compress')), 'Accept': '*/*' - } + }) def parse_header_links(value): diff --git a/test_requests.py b/test_requests.py index 9339476..953ecca 100644 --- a/test_requests.py +++ b/test_requests.py @@ -13,6 +13,7 @@ import requests from requests.auth import HTTPDigestAuth from requests.compat import str, cookielib from requests.cookies import cookiejar_from_dict +from requests.structures import CaseInsensitiveDict try: import StringIO @@ -458,6 +459,165 @@ class RequestsTestCase(unittest.TestCase): r = s.send(r.prepare()) self.assertEqual(r.status_code, 200) + def test_fixes_1329(self): + s = requests.Session() + s.headers.update({'accept': 'application/json'}) + r = s.get(httpbin('get')) + headers = r.request.headers + # ASCII encode because of key comparison changes in py3 + self.assertEqual( + headers['accept'.encode('ascii')], + 'application/json' + ) + self.assertEqual( + headers['Accept'.encode('ascii')], + 'application/json' + ) + + +class TestCaseInsensitiveDict(unittest.TestCase): + + def test_mapping_init(self): + cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'}) + self.assertEqual(len(cid), 2) + self.assertTrue('foo' in cid) + self.assertTrue('bar' in cid) + + def test_iterable_init(self): + cid = CaseInsensitiveDict([('Foo', 'foo'), ('BAr', 'bar')]) + self.assertEqual(len(cid), 2) + self.assertTrue('foo' in cid) + self.assertTrue('bar' in cid) + + def test_kwargs_init(self): + cid = CaseInsensitiveDict(FOO='foo', BAr='bar') + self.assertEqual(len(cid), 2) + self.assertTrue('foo' in cid) + self.assertTrue('bar' in cid) + + def test_docstring_example(self): + cid = CaseInsensitiveDict() + cid['Accept'] = 'application/json' + self.assertEqual(cid['aCCEPT'], 'application/json') + self.assertEqual(list(cid), ['Accept']) + + def test_len(self): + cid = CaseInsensitiveDict({'a': 'a', 'b': 'b'}) + cid['A'] = 'a' + self.assertEqual(len(cid), 2) + + def test_getitem(self): + cid = CaseInsensitiveDict({'Spam': 'blueval'}) + self.assertEqual(cid['spam'], 'blueval') + self.assertEqual(cid['SPAM'], 'blueval') + + def test_fixes_649(self): + cid = CaseInsensitiveDict() + cid['spam'] = 'oneval' + cid['Spam'] = 'twoval' + cid['sPAM'] = 'redval' + cid['SPAM'] = 'blueval' + self.assertEqual(cid['spam'], 'blueval') + self.assertEqual(cid['SPAM'], 'blueval') + self.assertEqual(list(cid.keys()), ['SPAM']) + + def test_delitem(self): + cid = CaseInsensitiveDict() + cid['Spam'] = 'someval' + del cid['sPam'] + self.assertFalse('spam' in cid) + self.assertEqual(len(cid), 0) + + def test_contains(self): + cid = CaseInsensitiveDict() + cid['Spam'] = 'someval' + self.assertTrue('Spam' in cid) + self.assertTrue('spam' in cid) + self.assertTrue('SPAM' in cid) + self.assertTrue('sPam' in cid) + self.assertFalse('notspam' in cid) + + def test_get(self): + cid = CaseInsensitiveDict() + cid['spam'] = 'oneval' + cid['SPAM'] = 'blueval' + self.assertEqual(cid.get('spam'), 'blueval') + self.assertEqual(cid.get('SPAM'), 'blueval') + self.assertEqual(cid.get('sPam'), 'blueval') + self.assertEqual(cid.get('notspam', 'default'), 'default') + + def test_update(self): + cid = CaseInsensitiveDict() + cid['spam'] = 'blueval' + cid.update({'sPam': 'notblueval'}) + self.assertEqual(cid['spam'], 'notblueval') + cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'}) + cid.update({'fOO': 'anotherfoo', 'bAR': 'anotherbar'}) + self.assertEqual(len(cid), 2) + self.assertEqual(cid['foo'], 'anotherfoo') + self.assertEqual(cid['bar'], 'anotherbar') + + def test_update_retains_unchanged(self): + cid = CaseInsensitiveDict({'foo': 'foo', 'bar': 'bar'}) + cid.update({'foo': 'newfoo'}) + self.assertEquals(cid['bar'], 'bar') + + def test_iter(self): + cid = CaseInsensitiveDict({'Spam': 'spam', 'Eggs': 'eggs'}) + keys = frozenset(['Spam', 'Eggs']) + self.assertEqual(frozenset(iter(cid)), keys) + + def test_equality(self): + cid = CaseInsensitiveDict({'SPAM': 'blueval', 'Eggs': 'redval'}) + othercid = CaseInsensitiveDict({'spam': 'blueval', 'eggs': 'redval'}) + self.assertEqual(cid, othercid) + del othercid['spam'] + self.assertNotEqual(cid, othercid) + self.assertEqual(cid, {'spam': 'blueval', 'eggs': 'redval'}) + + def test_setdefault(self): + cid = CaseInsensitiveDict({'Spam': 'blueval'}) + self.assertEqual( + cid.setdefault('spam', 'notblueval'), + 'blueval' + ) + self.assertEqual( + cid.setdefault('notspam', 'notblueval'), + 'notblueval' + ) + + def test_lower_items(self): + cid = CaseInsensitiveDict({ + 'Accept': 'application/json', + 'user-Agent': 'requests', + }) + keyset = frozenset(lowerkey for lowerkey, v in cid.lower_items()) + lowerkeyset = frozenset(['accept', 'user-agent']) + self.assertEqual(keyset, lowerkeyset) + + def test_preserve_key_case(self): + cid = CaseInsensitiveDict({ + 'Accept': 'application/json', + 'user-Agent': 'requests', + }) + keyset = frozenset(['Accept', 'user-Agent']) + self.assertEqual(frozenset(i[0] for i in cid.items()), keyset) + self.assertEqual(frozenset(cid.keys()), keyset) + self.assertEqual(frozenset(cid), keyset) + + def test_preserve_last_key_case(self): + cid = CaseInsensitiveDict({ + 'Accept': 'application/json', + 'user-Agent': 'requests', + }) + cid.update({'ACCEPT': 'application/json'}) + cid['USER-AGENT'] = 'requests' + keyset = frozenset(['ACCEPT', 'USER-AGENT']) + self.assertEqual(frozenset(i[0] for i in cid.items()), keyset) + self.assertEqual(frozenset(cid.keys()), keyset) + self.assertEqual(frozenset(cid), keyset) + + if __name__ == '__main__': unittest.main() -- 2.7.4