Fix infinite loop on wrong Digest Authentication
authorPierre Chapuis <catwell@archlinux.us>
Thu, 12 Apr 2012 14:33:15 +0000 (16:33 +0200)
committerPierre Chapuis <catwell@archlinux.us>
Thu, 12 Apr 2012 14:33:15 +0000 (16:33 +0200)
requests/auth.py
requests/models.py
tests/test_requests.py

index 4fc338c40120e4caf7290f51395b96b9929d5da5..353180afe90550bee33a9eca7a71f3ad6a444676 100644 (file)
@@ -56,6 +56,8 @@ class HTTPDigestAuth(AuthBase):
     def handle_401(self, r):
         """Takes the given response and tries digest-auth, if needed."""
 
+        r.request.deregister_hook('response', self.handle_401)
+
         s_auth = r.headers.get('www-authenticate', '')
 
         if 'digest' in s_auth.lower():
index e117cb0bdc7755cc6fb3a31a0af903de378748e9..402adb6fb76854512c3430b365018dfe0fa2ec9f 100644 (file)
@@ -408,7 +408,18 @@ class Request(object):
     def register_hook(self, event, hook):
         """Properly register a hook."""
 
-        return self.hooks[event].append(hook)
+        self.hooks[event].append(hook)
+
+    def deregister_hook(self,event,hook):
+        """Deregister a previously registered hook.
+        Returns True if the hook existed, False if not.
+        """
+
+        try:
+            self.hooks[event].remove(hook)
+            return True
+        except ValueError:
+            return False
 
     def send(self, anyway=False, prefetch=False):
         """Sends the request. Returns True of successful, False if not.
index 7ca30ba6ba81aa8c9b0996241af0681667da6518..45497ac1a06dfd9ceac4beb347b692399cd4afa6 100755 (executable)
@@ -272,6 +272,20 @@ class RequestsTestSuite(TestSetup, unittest.TestCase):
             r = get(url, session=s)
             self.assertEqual(r.status_code, 200)
 
+    def test_DIGESTAUTH_WRONG_HTTP_401_GET(self):
+
+        for service in SERVICES:
+
+            auth = HTTPDigestAuth('user', 'wrongpass')
+            url = service('digest-auth', 'auth', 'user', 'pass')
+
+            r = get(url, auth=auth)
+            self.assertEqual(r.status_code, 401)
+
+            s = requests.session(auth=auth)
+            r = get(url, session=s)
+            self.assertEqual(r.status_code, 401)
+
     def test_POSTBIN_GET_POST_FILES(self):
 
         for service in SERVICES: