From 98114245c686ef7cd62c9932071fd972ad3efedb Mon Sep 17 00:00:00 2001 From: Chase Sterling Date: Thu, 2 May 2013 12:46:59 -0400 Subject: [PATCH] Refactor merge_kwargs for clarity and to fix a few bugs --- requests/sessions.py | 74 ++++++++++++++++++++++------------------------------ requests/utils.py | 4 +-- test_requests.py | 14 ++++++++++ 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/requests/sessions.py b/requests/sessions.py index a881924..f4aeeee 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -9,14 +9,16 @@ requests (cookies, auth, proxies). """ import os +from collections import Mapping from datetime import datetime from .compat import cookielib, OrderedDict, urljoin, urlparse from .cookies import cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar from .models import Request, PreparedRequest from .hooks import default_hooks, dispatch_hook -from .utils import from_key_val_list, default_headers +from .utils import to_key_val_list, default_headers from .exceptions import TooManyRedirects, InvalidSchema +from .structures import CaseInsensitiveDict from .adapters import HTTPAdapter @@ -32,49 +34,35 @@ REDIRECT_STATI = ( DEFAULT_REDIRECT_LIMIT = 30 -def merge_kwargs(local_kwarg, default_kwarg): - """Merges kwarg dictionaries. - - If a local key in the dictionary is set to None, it will be removed. +def merge_setting(request_setting, session_setting, dict_class=OrderedDict): + """ + Determines appropriate setting for a given request, taking into account the + explicit setting on that request, and the setting in the session. If a + setting is a dictionary, they will be merged together using `dict_class` """ - if default_kwarg is None: - return local_kwarg - - if isinstance(local_kwarg, str): - return local_kwarg - - if local_kwarg is None: - return default_kwarg - - # Bypass if not a dictionary (e.g. timeout) - if not hasattr(default_kwarg, 'items'): - return local_kwarg + if session_setting is None: + return request_setting - default_kwarg = from_key_val_list(default_kwarg) - local_kwarg = from_key_val_list(local_kwarg) + if request_setting is None: + return session_setting - # Update new values in a case-insensitive way - def get_original_key(original_keys, new_key): - """ - Finds the key from original_keys that case-insensitive matches new_key. - """ - for original_key in original_keys: - if key.lower() == original_key.lower(): - return original_key - return new_key + # Bypass if not a dictionary (e.g. verify) + if not ( + isinstance(session_setting, Mapping) and + isinstance(request_setting, Mapping) + ): + return request_setting - kwargs = default_kwarg.copy() - original_keys = kwargs.keys() - for key, value in local_kwarg.items(): - kwargs[get_original_key(original_keys, key)] = value + merged_setting = dict_class(to_key_val_list(session_setting)) + merged_setting.update(to_key_val_list(request_setting)) # Remove keys that are set to None. - for (k, v) in local_kwarg.items(): + for (k, v) in request_setting.items(): if v is None: - del kwargs[k] + del merged_setting[k] - return kwargs + return merged_setting class SessionRedirectMixin(object): @@ -311,14 +299,14 @@ class Session(SessionRedirectMixin): verify = os.environ.get('CURL_CA_BUNDLE') # Merge all the kwargs. - params = merge_kwargs(params, self.params) - headers = merge_kwargs(headers, self.headers) - auth = merge_kwargs(auth, self.auth) - proxies = merge_kwargs(proxies, self.proxies) - hooks = merge_kwargs(hooks, self.hooks) - stream = merge_kwargs(stream, self.stream) - verify = merge_kwargs(verify, self.verify) - cert = merge_kwargs(cert, self.cert) + params = merge_setting(params, self.params) + headers = merge_setting(headers, self.headers, dict_class=CaseInsensitiveDict) + auth = merge_setting(auth, self.auth) + proxies = merge_setting(proxies, self.proxies) + hooks = merge_setting(hooks, self.hooks) + stream = merge_setting(stream, self.stream) + verify = merge_setting(verify, self.verify) + cert = merge_setting(cert, self.cert) # Create the Request. req = Request() diff --git a/requests/utils.py b/requests/utils.py index d690559..b21bf8f 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -11,11 +11,11 @@ that are also useful for external consumption. import cgi import codecs +import collections import os import platform import re import sys -import zlib from netrc import netrc, NetrcParseError from . import __version__ @@ -135,7 +135,7 @@ def to_key_val_list(value): if isinstance(value, (str, bytes, bool, int)): raise ValueError('cannot encode objects that are not 2-tuples') - if isinstance(value, dict): + if isinstance(value, collections.Mapping): value = value.items() return list(value) diff --git a/test_requests.py b/test_requests.py index 60e4498..aabce29 100644 --- a/test_requests.py +++ b/test_requests.py @@ -521,6 +521,20 @@ class RequestsTestCase(unittest.TestCase): self.assertTrue('http://' in s2.adapters) self.assertTrue('https://' in s2.adapters) + def test_header_remove_is_case_insensitive(self): + # From issue #1321 + s = requests.Session() + s.headers['foo'] = 'bar' + r = s.get(httpbin('get'), headers={'FOO': None}) + assert 'foo' not in r.request.headers + + def test_params_are_merged_case_sensitive(self): + s = requests.Session() + s.params['foo'] = 'bar' + r = s.get(httpbin('get'), params={'FOO': 'bar'}) + assert r.json()['args'] == {'foo': 'bar', 'FOO': 'bar'} + + def test_long_authinfo_in_url(self): url = 'http://{0}:{1}@{2}:9000/path?query#frag'.format( 'E8A3BE87-9E3F-4620-8858-95478E385B5B', -- 2.7.4