Facilitate for multiple hooks
authorJohannes Gorset <jgorset@gmail.com>
Sat, 21 Jan 2012 11:05:59 +0000 (12:05 +0100)
committerJohannes Gorset <jgorset@gmail.com>
Sat, 21 Jan 2012 11:14:55 +0000 (12:14 +0100)
requests/hooks.py
test_requests.py

index 37f87d9..dee4b1f 100644 (file)
@@ -31,10 +31,15 @@ def dispatch_hook(key, hooks, hook_data):
     hooks = hooks or dict()
 
     if key in hooks:
-        try:
-            return hooks.get(key).__call__(hook_data) or hook_data
+        hooks = hooks.get(key)
 
-        except Exception:
-            traceback.print_exc()
+        if hasattr(hooks, '__call__'):
+            hooks = [hooks]
+
+        for hook in hooks:
+            try:
+                hook_data = hook(hook_data) or hook_data
+            except Exception:
+                traceback.print_exc()
 
     return hook_data
index 172b1ed..8915425 100755 (executable)
@@ -516,6 +516,65 @@ class RequestsTestSuite(unittest.TestCase):
 
         self.assertEqual(r2.status_code, 200)
 
+    def test_single_hook(self):
+
+        def add_foo_header(args):
+            if not args.get('headers'):
+                args['headers'] = {}
+
+            args['headers'].update({
+                'X-Foo': 'foo'
+            })
+
+            return args
+
+        for service in SERVICES:
+            url = service('headers')
+
+            response = requests.get(
+                url = url,
+                hooks = {
+                    'args': add_foo_header
+                }
+            )
+
+            assert 'foo' in response.content
+
+    def test_multiple_hooks(self):
+
+        def add_foo_header(args):
+            if not args.get('headers'):
+                args['headers'] = {}
+
+            args['headers'].update({
+                'X-Foo': 'foo'
+            })
+
+            return args
+
+        def add_bar_header(args):
+            if not args.get('headers'):
+                args['headers'] = {}
+
+            args['headers'].update({
+                'X-Bar': 'bar'
+            })
+
+            return args
+
+        for service in SERVICES:
+            url = service('headers')
+
+            response = requests.get(
+                url = url,
+                hooks = {
+                    'args': [add_foo_header, add_bar_header]
+                }
+            )
+
+            assert 'foo' in response.content
+            assert 'bar' in response.content
+
     def test_session_persistent_cookies(self):
 
         s = requests.session()