hooks and things
authorKenneth Reitz <me@kennethreitz.com>
Mon, 17 Dec 2012 09:31:26 +0000 (04:31 -0500)
committerKenneth Reitz <me@kennethreitz.com>
Mon, 17 Dec 2012 09:31:26 +0000 (04:31 -0500)
requests/auth.py
requests/hooks.py
requests/models.py

index 9724c7f..493198b 100644 (file)
@@ -97,7 +97,7 @@ class OAuth1(AuthBase):
                                                    unicode(r.method),
                                                    None,
                                                    r.headers)
-        elif (decoded_body is not None and 
+        elif (decoded_body is not None and
                                   contenttype == CONTENT_TYPE_FORM_URLENCODED):
             # If the Content-Type header is urlencoded and there are no
             # illegal characters in the body, assume that the content actually
@@ -269,127 +269,4 @@ class HTTPDigestAuth(AuthBase):
         if self.last_nonce:
             r.headers['Authorization'] = self.build_digest_header(r.method, r.url)
         r.register_hook('response', self.handle_401)
-        return r
-
-
-def _negotiate_value(r):
-    """Extracts the gssapi authentication token from the appropriate header"""
-
-    authreq = r.headers.get('www-authenticate', None)
-
-    if authreq:
-        rx = re.compile('(?:.*,)*\s*Negotiate\s*([^,]*),?', re.I)
-        mo = rx.search(authreq)
-        if mo:
-            return mo.group(1)
-
-    return None
-
-
-class HTTPKerberosAuth(AuthBase):
-    """Attaches HTTP GSSAPI/Kerberos Authentication to the given Request object."""
-    def __init__(self, require_mutual_auth=True):
-        if k is None:
-            raise Exception("Kerberos libraries unavailable")
-        self.context = None
-        self.require_mutual_auth = require_mutual_auth
-
-    def generate_request_header(self, r):
-        """Generates the gssapi authentication token with kerberos"""
-
-        host = urlparse(r.url).netloc
-        tail, _, head = host.rpartition(':')
-        domain = tail if tail else head
-
-        result, self.context = k.authGSSClientInit("HTTP@%s" % domain)
-
-        if result < 1:
-            raise Exception("authGSSClientInit failed")
-
-        result = k.authGSSClientStep(self.context, _negotiate_value(r))
-
-        if result < 0:
-            raise Exception("authGSSClientStep failed")
-
-        response = k.authGSSClientResponse(self.context)
-
-        return "Negotiate %s" % response
-
-    def authenticate_user(self, r):
-        """Handles user authentication with gssapi/kerberos"""
-
-        auth_header = self.generate_request_header(r)
-        log.debug("authenticate_user(): Authorization header: %s" % auth_header)
-        r.request.headers['Authorization'] = auth_header
-        r.request.send(anyway=True)
-        _r = r.request.response
-        _r.history.append(r)
-        log.debug("authenticate_user(): returning %s" % _r)
-        return _r
-
-    def handle_401(self, r):
-        """Handles 401's, attempts to use gssapi/kerberos authentication"""
-
-        log.debug("handle_401(): Handling: 401")
-        if _negotiate_value(r) is not None:
-            _r = self.authenticate_user(r)
-            log.debug("handle_401(): returning %s" % _r)
-            return _r
-        else:
-            log.debug("handle_401(): Kerberos is not supported")
-            log.debug("handle_401(): returning %s" % r)
-            return r
-
-    def handle_other(self, r):
-        """Handles all responses with the exception of 401s.
-
-        This is necessary so that we can authenticate responses if requested"""
-
-        log.debug("handle_other(): Handling: %d" % r.status_code)
-        self.deregister(r)
-        if self.require_mutual_auth:
-            if _negotiate_value(r) is not None:
-                log.debug("handle_other(): Authenticating the server")
-                _r = self.authenticate_server(r)
-                log.debug("handle_other(): returning %s" % _r)
-                return _r
-            else:
-                log.error("handle_other(): Mutual authentication failed")
-                raise Exception("Mutual authentication failed")
-        else:
-            log.debug("handle_other(): returning %s" % r)
-            return r
-
-    def authenticate_server(self, r):
-        """Uses GSSAPI to authenticate the server"""
-
-        log.debug("authenticate_server(): Authenticate header: %s" % _negotiate_value(r))
-        result = k.authGSSClientStep(self.context, _negotiate_value(r))
-        if  result < 1:
-            raise Exception("authGSSClientStep failed")
-        _r = r.request.response
-        log.debug("authenticate_server(): returning %s" % _r)
-        return _r
-
-    def handle_response(self, r):
-        """Takes the given response and tries kerberos-auth, as needed."""
-
-        if r.status_code == 401:
-            _r = self.handle_401(r)
-            log.debug("handle_response returning %s" % _r)
-            return _r
-        else:
-            _r = self.handle_other(r)
-            log.debug("handle_response returning %s" % _r)
-            return _r
-
-        log.debug("handle_response returning %s" % r)
-        return r
-
-    def deregister(self, r):
-        """Deregisters the response handler"""
-        r.request.deregister_hook('response', self.handle_response)
-
-    def __call__(self, r):
-        r.register_hook('response', self.handle_response)
-        return r
+        return r
\ No newline at end of file
index 75d0b9c..6135033 100644 (file)
@@ -14,7 +14,14 @@ Available hooks:
 """
 
 
-HOOKS = ('response')
+HOOKS = ['response']
+
+def default_hooks():
+    hooks = {}
+    for event in HOOKS:
+        hooks[event] = []
+    return hooks
+
 # TODO: response is the only one
 
 def dispatch_hook(key, hooks, hook_data):
index 316dc40..49c8820 100644 (file)
@@ -9,13 +9,13 @@ This module contains the primary objects that power Requests.
 
 # import os
 # import socket
-import collections
+import collections
 import logging
 
 # from datetime import datetime
 from io import BytesIO
 
-from .hooks import dispatch_hook, HOOKS
+from .hooks import dispatch_hook, default_hooks
 from .structures import CaseInsensitiveDict
 from .status_codes import codes
 
@@ -46,7 +46,7 @@ CONTENT_CHUNK_SIZE = 10 * 1024
 log = logging.getLogger(__name__)
 
 
-class RequestMixin(object):
+class RequestEncodingMixin(object):
 
     @property
     def path_url(self):
@@ -143,7 +143,30 @@ class RequestMixin(object):
         return body, content_type
 
 
-class Request(object):
+class RequestHooksMixin(object):
+    def register_hook(self, event, hook):
+        """Properly register a hook."""
+        print self
+        print event
+        print hook
+        if isinstance(hook, collections.Callable):
+            self.hooks[event].append(hook)
+        elif hasattr(hook, '__iter__'):
+            self.hooks[event].extend(h for h in hook if isinstance(h, collections.Callable))
+
+    def deregister_hook(self, event, hook):
+        """Deregister a previously registered hook.
+        Returns True if the hook existed, False if not.
+        """
+
+        try:
+            self.hooks[event].remove(hook)
+            return True
+        except ValueError:
+            return False
+
+
+class Request(RequestHooksMixin):
     """A user-created :class:`Request <Request>` object."""
     def __init__(self,
         method=None,
@@ -155,12 +178,18 @@ class Request(object):
         auth=None,
         cookies=None,
         timeout=None,
-        allow_redirects=False,
-        proxies=None,
-        hooks=None,
-        prefetch=True,
-        verify=None,
-        cert=None):
+        hooks=None):
+
+        # Default empty dicts for dict params.
+        data = [] if data is None else data
+        files = [] if files is None else files
+        headers = {} if headers is None else headers
+        params = {} if params is None else params
+        hooks = {} if hooks is None else hooks
+
+        self.hooks = default_hooks()
+        for (k, v) in list(hooks.items()):
+            self.register_hook(event=k, hook=v)
 
         self.method = method
         self.url = url
@@ -170,8 +199,8 @@ class Request(object):
         self.params = params
         self.auth = auth
         self.cookies = cookies
-        self.allow_redirects = allow_redirects
-        self.proxies = proxies
+        self.allow_redirects = allow_redirects
+        self.proxies = proxies
         self.hooks = hooks
 
     def __repr__(self):
@@ -191,7 +220,7 @@ class Request(object):
         return p
 
 
-class PreparedRequest(RequestMixin):
+class PreparedRequest(RequestEncodingMixin, RequestHooksMixin):
     """The :class:`PreparedRequest <PreparedRequest>` object."""
 
     def __init__(self):
@@ -203,7 +232,7 @@ class PreparedRequest(RequestMixin):
         self.auth = None
         self.allow_redirects = None
         self.proxies = None
-        self.hooks = None
+        self.hooks = default_hooks()
 
     def __repr__(self):
         return '<PreparedRequest [%s]>' % (self.method)
@@ -313,12 +342,16 @@ class PreparedRequest(RequestMixin):
     def prepare_auth(self, auth):
         """Prepares the given HTTP auth data."""
         if auth:
+            # print auth
             if isinstance(auth, tuple) and len(auth) == 2:
                 # special-case basic HTTP auth
                 auth = HTTPBasicAuth(*auth)
 
             # Allow auth to make its changes.
             r = auth(self)
+            # print r
+            # print r.__dict__
+            # print self.__dict__
 
             # Update self to reflect the auth changes.
             self.__dict__.update(r.__dict__)