Rewrite CaseInsensitiveDict to work correctly/sanely
authorColin Dunklau <colin.dunklau@gmail.com>
Tue, 30 Apr 2013 19:52:27 +0000 (14:52 -0500)
committerColin Dunklau <colin.dunklau@gmail.com>
Tue, 30 Apr 2013 19:52:27 +0000 (14:52 -0500)
Fixes #649 and #1329 by making Session.headers a CaseInsensitiveDict,
and fixing the implementation of CID. Credit for the brilliant idea
to map `lowercased_key -> (cased_key, mapped_value)` goes to
@gazpachoking, thanks a bunch.

Changes from original implementation of CaseInsensitiveDict:

1.  CID is rewritten as a subclass of `collections.MutableMapping`.
2.  CID remembers the case of the last-set key, but `__setitem__`
    and `__delitem__` will handle keys without respect to case.
3.  CID returns the key case as remembered for the `keys`, `items`,
    and `__iter__` methods.
4.  Query operations (`__getitem__` and `__contains__`) are done in
    a case-insensitive manner: `cid['foo']` and `cid['FOO']` will
    return the same value.
5.  The constructor as well as `update` and `__eq__` have undefined
    behavior when given multiple keys that have the same `lower()`.
6.  The new method `lower_items` is like `iteritems`, but keys are
    all lowercased.
7.  CID raises `KeyError` for `__getitem__` as normal dicts do. The
    old implementation returned
6.  The `__repr__` now makes it obvious that it's not a normal dict.

See PR #1333 for the discussions that lead up to this implementation

AUTHORS.rst
requests/structures.py
requests/utils.py
test_requests.py

index 2010cae..2fae296 100644 (file)
@@ -124,3 +124,4 @@ Patches and Suggestions
 - Wilfred Hughes <me@wilfred.me.uk> @dontYetKnow
 - Dmitry Medvinsky <me@dmedvinsky.name>
 - Bryce Boe <bbzbryce@gmail.com> @bboe
+- Colin Dunklau <colin.dunklau@gmail.com> @cdunklau
index 05f5ac1..8d02ea6 100644 (file)
@@ -9,6 +9,7 @@ Data structures that power Requests.
 """
 
 import os
+import collections
 from itertools import islice
 
 
@@ -33,43 +34,79 @@ class IteratorProxy(object):
         return "".join(islice(self.i, None, n))
 
 
-class CaseInsensitiveDict(dict):
-    """Case-insensitive Dictionary
+class CaseInsensitiveDict(collections.MutableMapping):
+    """
+    A case-insensitive ``dict``-like object.
+
+    Implements all methods and operations of
+    ``collections.MutableMapping`` as well as dict's ``copy``. Also
+    provides ``lower_items``.
+
+    All keys are expected to be strings. The structure remembers the
+    case of the last key to be set, and ``iter(instance)``,
+    ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
+    will contain case-sensitive keys. However, querying and contains
+    testing is case insensitive:
+
+        cid = CaseInsensitiveDict()
+        cid['Accept'] = 'application/json'
+        cid['aCCEPT'] == 'application/json'  # True
+        list(cid) == ['Accept']  # True
 
     For example, ``headers['content-encoding']`` will return the
-    value of a ``'Content-Encoding'`` response header."""
+    value of a ``'Content-Encoding'`` response header, regardless
+    of how the header name was originally stored.
 
-    @property
-    def lower_keys(self):
-        if not hasattr(self, '_lower_keys') or not self._lower_keys:
-            self._lower_keys = dict((k.lower(), k) for k in list(self.keys()))
-        return self._lower_keys
+    If the constructor, ``.update``, or equality comparison
+    operations are given keys that have equal ``.lower()``s, the
+    behavior is undefined.
 
-    def _clear_lower_keys(self):
-        if hasattr(self, '_lower_keys'):
-            self._lower_keys.clear()
+    """
+    def __init__(self, data=None, **kwargs):
+        self._store = dict()
+        if data is None:
+            data = {}
+        self.update(data, **kwargs)
 
     def __setitem__(self, key, value):
-        dict.__setitem__(self, key, value)
-        self._clear_lower_keys()
+        # Use the lowercased key for lookups, but store the actual
+        # key alongside the value.
+        self._store[key.lower()] = (key, value)
 
-    def __delitem__(self, key):
-        dict.__delitem__(self, self.lower_keys.get(key.lower(), key))
-        self._lower_keys.clear()
+    def __getitem__(self, key):
+        return self._store[key.lower()][1]
 
-    def __contains__(self, key):
-        return key.lower() in self.lower_keys
+    def __delitem__(self, key):
+        del self._store[key.lower()]
 
-    def __getitem__(self, key):
-        # We allow fall-through here, so values default to None
-        if key in self:
-            return dict.__getitem__(self, self.lower_keys[key.lower()])
+    def __iter__(self):
+        return (casedkey for casedkey, mappedvalue in self._store.values())
 
-    def get(self, key, default=None):
-        if key in self:
-            return self[key]
+    def __len__(self):
+        return len(self._store)
+
+    def lower_items(self):
+        """Like iteritems(), but with all lowercase keys."""
+        return (
+            (lowerkey, keyval[1])
+            for (lowerkey, keyval)
+            in self._store.items()
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, collections.Mapping):
+            other = CaseInsensitiveDict(other)
         else:
-            return default
+            return NotImplemented
+        # Compare insensitively
+        return dict(self.lower_items()) == dict(other.lower_items())
+
+    # Copy is required
+    def copy(self):
+         return CaseInsensitiveDict(self._store.values())
+
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, dict(self.items()))
 
 
 class LookupDict(dict):
index 68f3e62..d690559 100644 (file)
@@ -23,6 +23,7 @@ from . import certs
 from .compat import parse_http_list as _parse_list_header
 from .compat import quote, urlparse, bytes, str, OrderedDict, urlunparse
 from .cookies import RequestsCookieJar, cookiejar_from_dict
+from .structures import CaseInsensitiveDict
 
 _hush_pyflakes = (RequestsCookieJar,)
 
@@ -449,11 +450,11 @@ def default_user_agent():
 
 
 def default_headers():
-    return {
+    return CaseInsensitiveDict({
         'User-Agent': default_user_agent(),
         'Accept-Encoding': ', '.join(('gzip', 'deflate', 'compress')),
         'Accept': '*/*'
-    }
+    })
 
 
 def parse_header_links(value):
index 9339476..953ecca 100644 (file)
@@ -13,6 +13,7 @@ import requests
 from requests.auth import HTTPDigestAuth
 from requests.compat import str, cookielib
 from requests.cookies import cookiejar_from_dict
+from requests.structures import CaseInsensitiveDict
 
 try:
     import StringIO
@@ -458,6 +459,165 @@ class RequestsTestCase(unittest.TestCase):
         r = s.send(r.prepare())
         self.assertEqual(r.status_code, 200)
 
+    def test_fixes_1329(self):
+        s = requests.Session()
+        s.headers.update({'accept': 'application/json'})
+        r = s.get(httpbin('get'))
+        headers = r.request.headers
+        # ASCII encode because of key comparison changes in py3
+        self.assertEqual(
+            headers['accept'.encode('ascii')],
+            'application/json'
+        )
+        self.assertEqual(
+            headers['Accept'.encode('ascii')],
+            'application/json'
+        )
+
+
+class TestCaseInsensitiveDict(unittest.TestCase):
+
+    def test_mapping_init(self):
+        cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
+        self.assertEqual(len(cid), 2)
+        self.assertTrue('foo' in cid)
+        self.assertTrue('bar' in cid)
+
+    def test_iterable_init(self):
+        cid = CaseInsensitiveDict([('Foo', 'foo'), ('BAr', 'bar')])
+        self.assertEqual(len(cid), 2)
+        self.assertTrue('foo' in cid)
+        self.assertTrue('bar' in cid)
+
+    def test_kwargs_init(self):
+        cid = CaseInsensitiveDict(FOO='foo', BAr='bar')
+        self.assertEqual(len(cid), 2)
+        self.assertTrue('foo' in cid)
+        self.assertTrue('bar' in cid)
+
+    def test_docstring_example(self):
+        cid = CaseInsensitiveDict()
+        cid['Accept'] = 'application/json'
+        self.assertEqual(cid['aCCEPT'], 'application/json')
+        self.assertEqual(list(cid), ['Accept'])
+
+    def test_len(self):
+        cid = CaseInsensitiveDict({'a': 'a', 'b': 'b'})
+        cid['A'] = 'a'
+        self.assertEqual(len(cid), 2)
+
+    def test_getitem(self):
+        cid = CaseInsensitiveDict({'Spam': 'blueval'})
+        self.assertEqual(cid['spam'], 'blueval')
+        self.assertEqual(cid['SPAM'], 'blueval')
+
+    def test_fixes_649(self):
+        cid = CaseInsensitiveDict()
+        cid['spam'] = 'oneval'
+        cid['Spam'] = 'twoval'
+        cid['sPAM'] = 'redval'
+        cid['SPAM'] = 'blueval'
+        self.assertEqual(cid['spam'], 'blueval')
+        self.assertEqual(cid['SPAM'], 'blueval')
+        self.assertEqual(list(cid.keys()), ['SPAM'])
+
+    def test_delitem(self):
+        cid = CaseInsensitiveDict()
+        cid['Spam'] = 'someval'
+        del cid['sPam']
+        self.assertFalse('spam' in cid)
+        self.assertEqual(len(cid), 0)
+
+    def test_contains(self):
+        cid = CaseInsensitiveDict()
+        cid['Spam'] = 'someval'
+        self.assertTrue('Spam' in cid)
+        self.assertTrue('spam' in cid)
+        self.assertTrue('SPAM' in cid)
+        self.assertTrue('sPam' in cid)
+        self.assertFalse('notspam' in cid)
+
+    def test_get(self):
+        cid = CaseInsensitiveDict()
+        cid['spam'] = 'oneval'
+        cid['SPAM'] = 'blueval'
+        self.assertEqual(cid.get('spam'), 'blueval')
+        self.assertEqual(cid.get('SPAM'), 'blueval')
+        self.assertEqual(cid.get('sPam'), 'blueval')
+        self.assertEqual(cid.get('notspam', 'default'), 'default')
+
+    def test_update(self):
+        cid = CaseInsensitiveDict()
+        cid['spam'] = 'blueval'
+        cid.update({'sPam': 'notblueval'})
+        self.assertEqual(cid['spam'], 'notblueval')
+        cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
+        cid.update({'fOO': 'anotherfoo', 'bAR': 'anotherbar'})
+        self.assertEqual(len(cid), 2)
+        self.assertEqual(cid['foo'], 'anotherfoo')
+        self.assertEqual(cid['bar'], 'anotherbar')
+
+    def test_update_retains_unchanged(self):
+        cid = CaseInsensitiveDict({'foo': 'foo', 'bar': 'bar'})
+        cid.update({'foo': 'newfoo'})
+        self.assertEquals(cid['bar'], 'bar')
+
+    def test_iter(self):
+        cid = CaseInsensitiveDict({'Spam': 'spam', 'Eggs': 'eggs'})
+        keys = frozenset(['Spam', 'Eggs'])
+        self.assertEqual(frozenset(iter(cid)), keys)
+
+    def test_equality(self):
+        cid = CaseInsensitiveDict({'SPAM': 'blueval', 'Eggs': 'redval'})
+        othercid = CaseInsensitiveDict({'spam': 'blueval', 'eggs': 'redval'})
+        self.assertEqual(cid, othercid)
+        del othercid['spam']
+        self.assertNotEqual(cid, othercid)
+        self.assertEqual(cid, {'spam': 'blueval', 'eggs': 'redval'})
+
+    def test_setdefault(self):
+        cid = CaseInsensitiveDict({'Spam': 'blueval'})
+        self.assertEqual(
+            cid.setdefault('spam', 'notblueval'),
+            'blueval'
+        )
+        self.assertEqual(
+            cid.setdefault('notspam', 'notblueval'),
+            'notblueval'
+        )
+
+    def test_lower_items(self):
+        cid = CaseInsensitiveDict({
+            'Accept': 'application/json',
+            'user-Agent': 'requests',
+        })
+        keyset = frozenset(lowerkey for lowerkey, v in cid.lower_items())
+        lowerkeyset = frozenset(['accept', 'user-agent'])
+        self.assertEqual(keyset, lowerkeyset)
+
+    def test_preserve_key_case(self):
+        cid = CaseInsensitiveDict({
+            'Accept': 'application/json',
+            'user-Agent': 'requests',
+        })
+        keyset = frozenset(['Accept', 'user-Agent'])
+        self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
+        self.assertEqual(frozenset(cid.keys()), keyset)
+        self.assertEqual(frozenset(cid), keyset)
+
+    def test_preserve_last_key_case(self):
+        cid = CaseInsensitiveDict({
+            'Accept': 'application/json',
+            'user-Agent': 'requests',
+        })
+        cid.update({'ACCEPT': 'application/json'})
+        cid['USER-AGENT'] = 'requests'
+        keyset = frozenset(['ACCEPT', 'USER-AGENT'])
+        self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
+        self.assertEqual(frozenset(cid.keys()), keyset)
+        self.assertEqual(frozenset(cid), keyset)
+
+
 
 if __name__ == '__main__':
     unittest.main()