new local merging of keyword arguments for sessions
authorKenneth Reitz <me@kennethreitz.com>
Wed, 19 Oct 2011 05:09:55 +0000 (01:09 -0400)
committerKenneth Reitz <me@kennethreitz.com>
Wed, 19 Oct 2011 05:09:55 +0000 (01:09 -0400)
requests/sessions.py

index 50b09f6..4e7022d 100644 (file)
@@ -16,21 +16,60 @@ from .utils import add_dict_to_cookiejar
 
 
 
+def merge_kwargs(local_kwarg, default_kwarg):
+    """Merges kwarg dictionaries.
+
+    If a local key in the dictionary is set to None, it will be removed.
+    """
+
+    if default_kwarg is None:
+        return local_kwarg
+
+    if local_kwarg is None:
+        return default_kwarg
+
+    # Bypass if not a dictionary (e.g. timeout)
+    if not hasattr(default_kwarg, 'items'):
+        return local_kwarg
+
+
+
+    # Update new values.
+    kwargs = default_kwarg.copy()
+    kwargs.update(local_kwarg)
+
+    # Remove keys that are set to None.
+    for (k,v) in local_kwarg.items():
+        if v is None:
+            del kwargs[k]
+
+    return kwargs
+
+
 class Session(object):
     """A Requests session."""
 
     __attrs__ = ['headers', 'cookies', 'auth', 'timeout', 'proxies', 'hooks']
 
 
-    def __init__(self, **kwargs):
+    def __init__(self,
+        headers=None,
+        cookies=None,
+        auth=None,
+        timeout=None,
+        proxies=None,
+        hooks=None):
+
+        self.headers = headers or {}
+        self.cookies = cookies or {}
+        self.auth = auth
+        self.timeout = timeout
+        self.proxies = proxies or {}
+        self.hooks = hooks or {}
 
         # Set up a CookieJar to be used by default
         self.cookies = cookielib.FileCookieJar()
 
-        # Map args from kwargs to instance-local variables
-        map(lambda k, v: (k in self.__attrs__) and setattr(self, k, v),
-                kwargs.iterkeys(), kwargs.itervalues())
-
         # Map and wrap requests.api methods
         self._map_api_methods()
 
@@ -42,10 +81,8 @@ class Session(object):
         return self
 
     def __exit__(self, *args):
-        # print args
         pass
 
-
     def _map_api_methods(self):
         """Reads each available method from requests.api and decorates
         them with a wrapper, which inserts any instance-local attributes
@@ -54,23 +91,35 @@ class Session(object):
 
         def pass_args(func):
             def wrapper_func(*args, **kwargs):
-                inst_attrs = dict((k, v) for k, v in self.__dict__.iteritems()
-                        if k in self.__attrs__)
-                # Combine instance-local values with kwargs values, with
-                # priority to values in kwargs
-                kwargs = dict(inst_attrs.items() + kwargs.items())
+
+                # Argument collector.
+                _kwargs = {}
 
                 # If a session request has a cookie_dict, inject the
                 # values into the existing CookieJar instead.
                 if isinstance(kwargs.get('cookies', None), dict):
                     kwargs['cookies'] = add_dict_to_cookiejar(
-                        inst_attrs['cookies'], kwargs['cookies']
+                        self.cookies, kwargs['cookies']
                     )
 
-                if kwargs.get('headers', None) and inst_attrs.get('headers', None):
-                    kwargs['headers'].update(inst_attrs['headers'])
+                for attr in self.__attrs__:
+                # for attr in ['headers',]:
+                    s_val = self.__dict__.get(attr)
+                    r_val = kwargs.get(attr)
+
+                    new_attr = merge_kwargs(r_val, s_val)
+
+                    # Skip attributes that were set to None.
+                    if new_attr is not None:
+                        _kwargs[attr] = new_attr
+
+                # Make sure we didn't miss anything.
+                for (k, v) in kwargs.items():
+                    if k not in _kwargs:
+                        _kwargs[k] = v
+
+                return func(*args, **_kwargs)
 
-                return func(*args, **kwargs)
             return wrapper_func
 
         # Map and decorate each function available in requests.api