Fixed so that safe_mode works for Sessions
authorJonatan Heyman <jonatan@heyman.info>
Fri, 23 Nov 2012 15:48:51 +0000 (16:48 +0100)
committerJonatan Heyman <jonatan@heyman.info>
Fri, 23 Nov 2012 15:48:51 +0000 (16:48 +0100)
AUTHORS.rst
requests/api.py
requests/safe_mode.py
requests/sessions.py
tests/test_requests.py

index 3c074d1be9d209eaef0bec661912eaa0b4f5aa94..e18d5751690dc07367f8b328560f26f0b0d1815a 100644 (file)
@@ -116,3 +116,4 @@ Patches and Suggestions
 - AndrĂ© Graf (dergraf)
 - Stephen Zhuang (everbird)
 - Martijn Pieters
+- Jonatan Heyman
index ded793525af87f80465e3d506bd45396d9e91560..297f4cbf8059bf6d1b8e025c76799b46b16f3e90 100644 (file)
@@ -12,10 +12,8 @@ This module implements the Requests API.
 """
 
 from . import sessions
-from .safe_mode import catch_exceptions_if_in_safe_mode
 
 
-@catch_exceptions_if_in_safe_mode
 def request(method, url, **kwargs):
     """Constructs and sends a :class:`Request <Request>`.
     Returns :class:`Response <Response>` object.
index 0fb8d7052adcb894d940db7874e633ed934dfcdb..18808d749e8eb2097de3f82e52e2fd4999cb6432 100644 (file)
@@ -18,17 +18,17 @@ import socket
 
 
 def catch_exceptions_if_in_safe_mode(function):
-    """New implementation of safe_mode. We catch all exceptions at the API level
+    """New implementation of safe_mode. We catch all exceptions at the Session level
     and then return a blank Response object with the error field filled. This decorator
-    wraps request() in api.py.
+    wraps Session._send_request() in sessions.py.
     """
 
-    def wrapped(method, url, **kwargs):
+    def wrapped(*args, **kwargs):
         # if save_mode, we catch exceptions and fill error field
         if (kwargs.get('config') and kwargs.get('config').get('safe_mode')) or (kwargs.get('session')
                                             and kwargs.get('session').config.get('safe_mode')):
             try:
-                return function(method, url, **kwargs)
+                return function(*args, **kwargs)
             except (RequestException, ConnectionError, HTTPError,
                     socket.timeout, socket.gaierror) as e:
                 r = Response()
@@ -36,5 +36,5 @@ def catch_exceptions_if_in_safe_mode(function):
                 r.raw = HTTPResponse()  # otherwise, tests fail
                 r.status_code = 0  # with this status_code, content returns None
                 return r
-        return function(method, url, **kwargs)
+        return function(*args, **kwargs)
     return wrapped
index 0962d8191d158bbe7af0dda7544742aea15435be..5d67b4d851b10970db0d30f85ca6e6550211a73d 100644 (file)
@@ -17,6 +17,7 @@ from .models import Request
 from .hooks import dispatch_hook
 from .utils import header_expand, from_key_val_list
 from .packages.urllib3.poolmanager import PoolManager
+from .safe_mode import catch_exceptions_if_in_safe_mode
 
 
 def merge_kwargs(local_kwarg, default_kwarg):
@@ -265,7 +266,12 @@ class Session(object):
             return r
 
         # Send the HTTP Request.
-        r.send(prefetch=prefetch)
+        return self._send_request(r, **args)
+
+    @catch_exceptions_if_in_safe_mode
+    def _send_request(self, r, **kwargs):
+        # Send the request.
+        r.send(prefetch=kwargs.get("prefetch"))
 
         # Return the response.
         return r.response
index 6615678ff4518c2a05787b19d4d9af85f1393895..cf326acf49945eb0082c04176f99fbe2fa0547a9 100755 (executable)
@@ -929,6 +929,19 @@ class RequestsTestSuite(TestSetup, TestBaseMixin, unittest.TestCase):
         ds2 = pickle.loads(pickle.dumps(requests.session(prefetch=False)))
         self.assertTrue(ds1.prefetch)
         self.assertFalse(ds2.prefetch)
+    
+    def test_session_connection_error_with_safe_mode(self):
+        config = {"safe_mode":True}
+
+        s = requests.session()
+        r = s.get("http://localhost:1/nope", timeout=0.1, config=config)
+        self.assertFalse(r.ok)
+        self.assertTrue(r.content is None)
+
+        s2 = requests.session(config=config)
+        r2 = s2.get("http://localhost:1/nope", timeout=0.1)
+        self.assertFalse(r2.ok)
+        self.assertTrue(r2.content is None)
 
     def test_connection_error(self):
         try: