HTTP Basic recursion. Fixes #31
authorKenneth Reitz <me@kennethreitz.com>
Sat, 14 May 2011 18:02:36 +0000 (14:02 -0400)
committerKenneth Reitz <me@kennethreitz.com>
Sat, 14 May 2011 18:02:36 +0000 (14:02 -0400)
requests/core.py
test_requests.py

index a90cdd7d8736bfccf8ff4719ef53fd3fc4e82076..7a92c5fad269df9bd6a549a89f86678fa5847c50 100644 (file)
@@ -55,6 +55,28 @@ class _Request(urllib2.Request):
         return urllib2.Request.get_method(self)
 
 
+class _HTTPBasicAuthHandler(urllib2.HTTPBasicAuthHandler):
+    # from mercurial
+
+    def __init__(self, *args, **kwargs):
+        urllib2.HTTPBasicAuthHandler.__init__(self, *args, **kwargs)
+        self.retried_req = None
+
+    def reset_retry_count(self):
+        # Python 2.6.5 will call this on 401 or 407 errors and thus loop
+        # forever. We disable reset_retry_count completely and reset in
+        # http_error_auth_reqed instead.
+        pass
+
+    def http_error_auth_reqed(self, auth_header, host, req, headers):
+        # Reset the retry counter once for each request.
+        if req is not self.retried_req:
+            self.retried_req = req
+            self.retried = 0
+        return urllib2.HTTPBasicAuthHandler.http_error_auth_reqed(
+                        self, auth_header, host, req, headers)
+
+
 class Request(object):
     """The :class:`Request` object. It carries out all functionality of
     Requests. Recommended interface is with the Requests functions.
@@ -153,10 +175,17 @@ class Request(object):
 
     def _build_response(self, resp):
         """Build internal Response object from given response."""
+        if isinstance(resp, HTTPError):
+            # print resp.__dict__
+            pass
 
         self.response.status_code = getattr(resp, 'code', None)
-        self.response.headers = getattr(resp.info(), 'dict', None)
-        self.response.content = resp.read()
+
+        try:
+            self.response.headers = getattr(resp.info(), 'dict', None)
+            self.response.content = resp.read()
+        except AttributeError, why:
+            pass
 
         if self.response.headers.get('content-encoding', None) == 'gzip':
             try:
@@ -431,7 +460,7 @@ class AuthObject(object):
     """
 
     _handlers = {
-        'basic': urllib2.HTTPBasicAuthHandler,
+        'basic': _HTTPBasicAuthHandler,
         'digest': urllib2.HTTPDigestAuthHandler,
         'proxy_basic': urllib2.ProxyBasicAuthHandler,
         'proxy_digest': urllib2.ProxyDigestAuthHandler
index ecd017728f21aecdb36f1bd92249285ac015d859..3bc2e092a9d2d58fde11b67ccf766759d2250940 100755 (executable)
@@ -149,6 +149,13 @@ class RequestsTestSuite(unittest.TestCase):
         requests.get('http://google.com', params={'foo': u'foo'})
         requests.get('http://google.com/ΓΈ', params={'foo': u'foo'})
 
+    def test_httpauth_recursion(self):
+        conv_auth = ('requeststest', 'bad_password')
+
+        r = requests.get('https://convore.com/api/account/verify.json', auth=conv_auth)
+        self.assertEquals(r.status_code, 401)
+        print r.__dict__
+
 
 if __name__ == '__main__':
     unittest.main()