Refactor merge_kwargs for clarity and to fix a few bugs
authorChase Sterling <chase.sterling@gmail.com>
Thu, 2 May 2013 16:46:59 +0000 (12:46 -0400)
committerChase Sterling <chase.sterling@gmail.com>
Tue, 21 May 2013 01:20:51 +0000 (21:20 -0400)
requests/sessions.py
requests/utils.py
test_requests.py

index a881924..f4aeeee 100644 (file)
@@ -9,14 +9,16 @@ requests (cookies, auth, proxies).
 
 """
 import os
+from collections import Mapping
 from datetime import datetime
 
 from .compat import cookielib, OrderedDict, urljoin, urlparse
 from .cookies import cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar
 from .models import Request, PreparedRequest
 from .hooks import default_hooks, dispatch_hook
-from .utils import from_key_val_list, default_headers
+from .utils import to_key_val_list, default_headers
 from .exceptions import TooManyRedirects, InvalidSchema
+from .structures import CaseInsensitiveDict
 
 from .adapters import HTTPAdapter
 
@@ -32,49 +34,35 @@ REDIRECT_STATI = (
 DEFAULT_REDIRECT_LIMIT = 30
 
 
-def merge_kwargs(local_kwarg, default_kwarg):
-    """Merges kwarg dictionaries.
-
-    If a local key in the dictionary is set to None, it will be removed.
+def merge_setting(request_setting, session_setting, dict_class=OrderedDict):
+    """
+    Determines appropriate setting for a given request, taking into account the
+    explicit setting on that request, and the setting in the session. If a
+    setting is a dictionary, they will be merged together using `dict_class`
     """
 
-    if default_kwarg is None:
-        return local_kwarg
-
-    if isinstance(local_kwarg, str):
-        return local_kwarg
-
-    if local_kwarg is None:
-        return default_kwarg
-
-    # Bypass if not a dictionary (e.g. timeout)
-    if not hasattr(default_kwarg, 'items'):
-        return local_kwarg
+    if session_setting is None:
+        return request_setting
 
-    default_kwarg = from_key_val_list(default_kwarg)
-    local_kwarg = from_key_val_list(local_kwarg)
+    if request_setting is None:
+        return session_setting
 
-    # Update new values in a case-insensitive way
-    def get_original_key(original_keys, new_key):
-        """
-        Finds the key from original_keys that case-insensitive matches new_key.
-        """
-        for original_key in original_keys:
-            if key.lower() == original_key.lower():
-                return original_key
-        return new_key
+    # Bypass if not a dictionary (e.g. verify)
+    if not (
+            isinstance(session_setting, Mapping) and
+            isinstance(request_setting, Mapping)
+    ):
+        return request_setting
 
-    kwargs = default_kwarg.copy()
-    original_keys = kwargs.keys()
-    for key, value in local_kwarg.items():
-        kwargs[get_original_key(original_keys, key)] = value
+    merged_setting = dict_class(to_key_val_list(session_setting))
+    merged_setting.update(to_key_val_list(request_setting))
 
     # Remove keys that are set to None.
-    for (k, v) in local_kwarg.items():
+    for (k, v) in request_setting.items():
         if v is None:
-            del kwargs[k]
+            del merged_setting[k]
 
-    return kwargs
+    return merged_setting
 
 
 class SessionRedirectMixin(object):
@@ -311,14 +299,14 @@ class Session(SessionRedirectMixin):
                 verify = os.environ.get('CURL_CA_BUNDLE')
 
         # Merge all the kwargs.
-        params = merge_kwargs(params, self.params)
-        headers = merge_kwargs(headers, self.headers)
-        auth = merge_kwargs(auth, self.auth)
-        proxies = merge_kwargs(proxies, self.proxies)
-        hooks = merge_kwargs(hooks, self.hooks)
-        stream = merge_kwargs(stream, self.stream)
-        verify = merge_kwargs(verify, self.verify)
-        cert = merge_kwargs(cert, self.cert)
+        params = merge_setting(params, self.params)
+        headers = merge_setting(headers, self.headers, dict_class=CaseInsensitiveDict)
+        auth = merge_setting(auth, self.auth)
+        proxies = merge_setting(proxies, self.proxies)
+        hooks = merge_setting(hooks, self.hooks)
+        stream = merge_setting(stream, self.stream)
+        verify = merge_setting(verify, self.verify)
+        cert = merge_setting(cert, self.cert)
 
         # Create the Request.
         req = Request()
index d690559..b21bf8f 100644 (file)
@@ -11,11 +11,11 @@ that are also useful for external consumption.
 
 import cgi
 import codecs
+import collections
 import os
 import platform
 import re
 import sys
-import zlib
 from netrc import netrc, NetrcParseError
 
 from . import __version__
@@ -135,7 +135,7 @@ def to_key_val_list(value):
     if isinstance(value, (str, bytes, bool, int)):
         raise ValueError('cannot encode objects that are not 2-tuples')
 
-    if isinstance(value, dict):
+    if isinstance(value, collections.Mapping):
         value = value.items()
 
     return list(value)
index 60e4498..aabce29 100644 (file)
@@ -521,6 +521,20 @@ class RequestsTestCase(unittest.TestCase):
         self.assertTrue('http://' in s2.adapters)
         self.assertTrue('https://' in s2.adapters)
 
+    def test_header_remove_is_case_insensitive(self):
+        # From issue #1321
+        s = requests.Session()
+        s.headers['foo'] = 'bar'
+        r = s.get(httpbin('get'), headers={'FOO': None})
+        assert 'foo' not in r.request.headers
+
+    def test_params_are_merged_case_sensitive(self):
+        s = requests.Session()
+        s.params['foo'] = 'bar'
+        r = s.get(httpbin('get'), params={'FOO': 'bar'})
+        assert r.json()['args'] == {'foo': 'bar', 'FOO': 'bar'}
+
+
     def test_long_authinfo_in_url(self):
         url = 'http://{0}:{1}@{2}:9000/path?query#frag'.format(
             'E8A3BE87-9E3F-4620-8858-95478E385B5B',