"""
import os
+import collections
from itertools import islice
return "".join(islice(self.i, None, n))
-class CaseInsensitiveDict(dict):
- """Case-insensitive Dictionary
+class CaseInsensitiveDict(collections.MutableMapping):
+ """
+ A case-insensitive ``dict``-like object.
+
+ Implements all methods and operations of
+ ``collections.MutableMapping`` as well as dict's ``copy``. Also
+ provides ``lower_items``.
+
+ All keys are expected to be strings. The structure remembers the
+ case of the last key to be set, and ``iter(instance)``,
+ ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
+ will contain case-sensitive keys. However, querying and contains
+ testing is case insensitive:
+
+ cid = CaseInsensitiveDict()
+ cid['Accept'] = 'application/json'
+ cid['aCCEPT'] == 'application/json' # True
+ list(cid) == ['Accept'] # True
For example, ``headers['content-encoding']`` will return the
- value of a ``'Content-Encoding'`` response header."""
+ value of a ``'Content-Encoding'`` response header, regardless
+ of how the header name was originally stored.
- @property
- def lower_keys(self):
- if not hasattr(self, '_lower_keys') or not self._lower_keys:
- self._lower_keys = dict((k.lower(), k) for k in list(self.keys()))
- return self._lower_keys
+ If the constructor, ``.update``, or equality comparison
+ operations are given keys that have equal ``.lower()``s, the
+ behavior is undefined.
- def _clear_lower_keys(self):
- if hasattr(self, '_lower_keys'):
- self._lower_keys.clear()
+ """
+ def __init__(self, data=None, **kwargs):
+ self._store = dict()
+ if data is None:
+ data = {}
+ self.update(data, **kwargs)
def __setitem__(self, key, value):
- dict.__setitem__(self, key, value)
- self._clear_lower_keys()
+ # Use the lowercased key for lookups, but store the actual
+ # key alongside the value.
+ self._store[key.lower()] = (key, value)
- def __delitem__(self, key):
- dict.__delitem__(self, self.lower_keys.get(key.lower(), key))
- self._lower_keys.clear()
+ def __getitem__(self, key):
+ return self._store[key.lower()][1]
- def __contains__(self, key):
- return key.lower() in self.lower_keys
+ def __delitem__(self, key):
+ del self._store[key.lower()]
- def __getitem__(self, key):
- # We allow fall-through here, so values default to None
- if key in self:
- return dict.__getitem__(self, self.lower_keys[key.lower()])
+ def __iter__(self):
+ return (casedkey for casedkey, mappedvalue in self._store.values())
- def get(self, key, default=None):
- if key in self:
- return self[key]
+ def __len__(self):
+ return len(self._store)
+
+ def lower_items(self):
+ """Like iteritems(), but with all lowercase keys."""
+ return (
+ (lowerkey, keyval[1])
+ for (lowerkey, keyval)
+ in self._store.items()
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, collections.Mapping):
+ other = CaseInsensitiveDict(other)
else:
- return default
+ return NotImplemented
+ # Compare insensitively
+ return dict(self.lower_items()) == dict(other.lower_items())
+
+ # Copy is required
+ def copy(self):
+ return CaseInsensitiveDict(self._store.values())
+
+ def __repr__(self):
+ return '%s(%r)' % (self.__class__.__name__, dict(self.items()))
class LookupDict(dict):
from requests.auth import HTTPDigestAuth
from requests.compat import str, cookielib
from requests.cookies import cookiejar_from_dict
+from requests.structures import CaseInsensitiveDict
try:
import StringIO
r = s.send(r.prepare())
self.assertEqual(r.status_code, 200)
+ def test_fixes_1329(self):
+ s = requests.Session()
+ s.headers.update({'accept': 'application/json'})
+ r = s.get(httpbin('get'))
+ headers = r.request.headers
+ # ASCII encode because of key comparison changes in py3
+ self.assertEqual(
+ headers['accept'.encode('ascii')],
+ 'application/json'
+ )
+ self.assertEqual(
+ headers['Accept'.encode('ascii')],
+ 'application/json'
+ )
+
+
+class TestCaseInsensitiveDict(unittest.TestCase):
+
+ def test_mapping_init(self):
+ cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
+ self.assertEqual(len(cid), 2)
+ self.assertTrue('foo' in cid)
+ self.assertTrue('bar' in cid)
+
+ def test_iterable_init(self):
+ cid = CaseInsensitiveDict([('Foo', 'foo'), ('BAr', 'bar')])
+ self.assertEqual(len(cid), 2)
+ self.assertTrue('foo' in cid)
+ self.assertTrue('bar' in cid)
+
+ def test_kwargs_init(self):
+ cid = CaseInsensitiveDict(FOO='foo', BAr='bar')
+ self.assertEqual(len(cid), 2)
+ self.assertTrue('foo' in cid)
+ self.assertTrue('bar' in cid)
+
+ def test_docstring_example(self):
+ cid = CaseInsensitiveDict()
+ cid['Accept'] = 'application/json'
+ self.assertEqual(cid['aCCEPT'], 'application/json')
+ self.assertEqual(list(cid), ['Accept'])
+
+ def test_len(self):
+ cid = CaseInsensitiveDict({'a': 'a', 'b': 'b'})
+ cid['A'] = 'a'
+ self.assertEqual(len(cid), 2)
+
+ def test_getitem(self):
+ cid = CaseInsensitiveDict({'Spam': 'blueval'})
+ self.assertEqual(cid['spam'], 'blueval')
+ self.assertEqual(cid['SPAM'], 'blueval')
+
+ def test_fixes_649(self):
+ cid = CaseInsensitiveDict()
+ cid['spam'] = 'oneval'
+ cid['Spam'] = 'twoval'
+ cid['sPAM'] = 'redval'
+ cid['SPAM'] = 'blueval'
+ self.assertEqual(cid['spam'], 'blueval')
+ self.assertEqual(cid['SPAM'], 'blueval')
+ self.assertEqual(list(cid.keys()), ['SPAM'])
+
+ def test_delitem(self):
+ cid = CaseInsensitiveDict()
+ cid['Spam'] = 'someval'
+ del cid['sPam']
+ self.assertFalse('spam' in cid)
+ self.assertEqual(len(cid), 0)
+
+ def test_contains(self):
+ cid = CaseInsensitiveDict()
+ cid['Spam'] = 'someval'
+ self.assertTrue('Spam' in cid)
+ self.assertTrue('spam' in cid)
+ self.assertTrue('SPAM' in cid)
+ self.assertTrue('sPam' in cid)
+ self.assertFalse('notspam' in cid)
+
+ def test_get(self):
+ cid = CaseInsensitiveDict()
+ cid['spam'] = 'oneval'
+ cid['SPAM'] = 'blueval'
+ self.assertEqual(cid.get('spam'), 'blueval')
+ self.assertEqual(cid.get('SPAM'), 'blueval')
+ self.assertEqual(cid.get('sPam'), 'blueval')
+ self.assertEqual(cid.get('notspam', 'default'), 'default')
+
+ def test_update(self):
+ cid = CaseInsensitiveDict()
+ cid['spam'] = 'blueval'
+ cid.update({'sPam': 'notblueval'})
+ self.assertEqual(cid['spam'], 'notblueval')
+ cid = CaseInsensitiveDict({'Foo': 'foo','BAr': 'bar'})
+ cid.update({'fOO': 'anotherfoo', 'bAR': 'anotherbar'})
+ self.assertEqual(len(cid), 2)
+ self.assertEqual(cid['foo'], 'anotherfoo')
+ self.assertEqual(cid['bar'], 'anotherbar')
+
+ def test_update_retains_unchanged(self):
+ cid = CaseInsensitiveDict({'foo': 'foo', 'bar': 'bar'})
+ cid.update({'foo': 'newfoo'})
+ self.assertEquals(cid['bar'], 'bar')
+
+ def test_iter(self):
+ cid = CaseInsensitiveDict({'Spam': 'spam', 'Eggs': 'eggs'})
+ keys = frozenset(['Spam', 'Eggs'])
+ self.assertEqual(frozenset(iter(cid)), keys)
+
+ def test_equality(self):
+ cid = CaseInsensitiveDict({'SPAM': 'blueval', 'Eggs': 'redval'})
+ othercid = CaseInsensitiveDict({'spam': 'blueval', 'eggs': 'redval'})
+ self.assertEqual(cid, othercid)
+ del othercid['spam']
+ self.assertNotEqual(cid, othercid)
+ self.assertEqual(cid, {'spam': 'blueval', 'eggs': 'redval'})
+
+ def test_setdefault(self):
+ cid = CaseInsensitiveDict({'Spam': 'blueval'})
+ self.assertEqual(
+ cid.setdefault('spam', 'notblueval'),
+ 'blueval'
+ )
+ self.assertEqual(
+ cid.setdefault('notspam', 'notblueval'),
+ 'notblueval'
+ )
+
+ def test_lower_items(self):
+ cid = CaseInsensitiveDict({
+ 'Accept': 'application/json',
+ 'user-Agent': 'requests',
+ })
+ keyset = frozenset(lowerkey for lowerkey, v in cid.lower_items())
+ lowerkeyset = frozenset(['accept', 'user-agent'])
+ self.assertEqual(keyset, lowerkeyset)
+
+ def test_preserve_key_case(self):
+ cid = CaseInsensitiveDict({
+ 'Accept': 'application/json',
+ 'user-Agent': 'requests',
+ })
+ keyset = frozenset(['Accept', 'user-Agent'])
+ self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
+ self.assertEqual(frozenset(cid.keys()), keyset)
+ self.assertEqual(frozenset(cid), keyset)
+
+ def test_preserve_last_key_case(self):
+ cid = CaseInsensitiveDict({
+ 'Accept': 'application/json',
+ 'user-Agent': 'requests',
+ })
+ cid.update({'ACCEPT': 'application/json'})
+ cid['USER-AGENT'] = 'requests'
+ keyset = frozenset(['ACCEPT', 'USER-AGENT'])
+ self.assertEqual(frozenset(i[0] for i in cid.items()), keyset)
+ self.assertEqual(frozenset(cid.keys()), keyset)
+ self.assertEqual(frozenset(cid), keyset)
+
+
if __name__ == '__main__':
unittest.main()