Imported Upstream version 12.1.0
[contrib/python-twisted.git] / doc / core / examples / dbcred.py
1 #!/usr/bin/env python
2
3 # Copyright (c) Twisted Matrix Laboratories.
4 # See LICENSE for details.
5
6 """
7 Simple example of a db checker: define a L{ICredentialsChecker} implementation
8 that deals with a database backend to authenticate a user.
9 """
10
11 from twisted.cred import error
12 from twisted.cred.credentials import IUsernameHashedPassword, IUsernamePassword
13 from twisted.cred.checkers import ICredentialsChecker
14 from twisted.internet.defer import Deferred
15
16 from zope.interface import implements
17
18
class DBCredentialsChecker(object):
    """
    This class checks the credentials of incoming connections
    against a user table in a database.

    @ivar credentialInterfaces: the credential interfaces this checker
        accepts; it depends on whether a C{customCheckFunc} was given.
    """
    implements(ICredentialsChecker)

    def __init__(self, runQuery,
        query="SELECT username, password FROM user WHERE username = %s",
        customCheckFunc=None, caseSensitivePasswords=True):
        """
        @param runQuery: This will be called to get the info from the db.
            Generally you'd want to create a
            L{twisted.enterprise.adbapi.ConnectionPool} and pass its runQuery
            method here. Otherwise pass a function with the same prototype.
            It is called as C{runQuery(query, (username,))} and must return
            a L{Deferred} firing with a sequence of rows.
        @type runQuery: C{callable}

        @param query: query used to authenticate the user. It must select
            the username and the password, in that order, and take the
            username as its only parameter.
        @type query: C{str}

        @param customCheckFunc: Use this if the passwords in the db are stored
            as hashes. We'll just call this, so you can do the checking
            yourself. It takes the following params:
            (username, suppliedPass, dbPass) and must return a boolean.
        @type customCheckFunc: C{callable}

        @param caseSensitivePasswords: If true requires that every letter in
            C{credentials.password} is exactly the same case as its
            counterpart letter in the database; if false the comparison
            ignores case.
            This is only relevant if C{customCheckFunc} is not used.
        @type caseSensitivePasswords: C{bool}
        """
        self.runQuery = runQuery
        self.caseSensitivePasswords = caseSensitivePasswords
        self.customCheckFunc = customCheckFunc
        # customCheckFunc is handed credentials.password (the plain supplied
        # password), so in that mode only credentials carrying a plain
        # password can be accepted; hashed-password credentials cannot.
        if customCheckFunc:
            self.credentialInterfaces = (IUsernamePassword,)
        else:
            self.credentialInterfaces = (
                IUsernamePassword, IUsernameHashedPassword,)

        self.sql = query

    def requestAvatarId(self, credentials):
        """
        Authenticate C{credentials} against the database.

        @param credentials: an object providing one of
            C{self.credentialInterfaces}.
        @raise error.UnhandledCredentials: (synchronously) if C{credentials}
            provides none of C{self.credentialInterfaces}.
        @return: a L{Deferred} firing with C{credentials.username} on
            success, or failing with L{error.UnauthorizedLogin} (unknown
            user or wrong password), L{error.LoginFailed} (the database
            query failed), or whatever exception the password check raised.
        """
        # Check that the credentials instance implements at least one of our
        # interfaces
        if not any(iface.providedBy(credentials)
                   for iface in self.credentialInterfaces):
            raise error.UnhandledCredentials()
        # Ask the database for the username and password
        dbDeferred = self.runQuery(self.sql, (credentials.username,))
        # Setup our deferred result
        deferred = Deferred()
        dbDeferred.addCallbacks(self._cbAuthenticate, self._ebAuthenticate,
                callbackArgs=(credentials, deferred),
                errbackArgs=(credentials, deferred))
        return deferred

    def _checkRow(self, result, credentials):
        """
        Decide whether C{credentials} match the rows returned by the query.

        @param result: the rows returned by C{runQuery}; only the first row
            (a C{(username, password)} pair) is used.
        @return: the avatar ID, i.e. C{credentials.username}.
        @raise error.UnauthorizedLogin: if the username is unknown or the
            password does not match.
        @raise error.UnhandledCredentials: if C{credentials} provides neither
            supported interface.
        """
        if not result:
            # Username not found in db
            raise error.UnauthorizedLogin('Username unknown')
        username, password = result[0]
        if self.customCheckFunc:
            # Let the owner do the checking
            passOk = self.customCheckFunc(
                username, credentials.password, password)
        elif IUsernameHashedPassword.providedBy(credentials):
            # Let the hashed password checker do the checking
            passOk = credentials.checkPassword(password)
        elif IUsernamePassword.providedBy(credentials):
            # Compare the passwords, deciding whether or not to use
            # case sensitivity
            if self.caseSensitivePasswords:
                passOk = password == credentials.password
            else:
                passOk = password.lower() == credentials.password.lower()
        else:
            # OK, we don't know how to check this
            raise error.UnhandledCredentials()
        if not passOk:
            raise error.UnauthorizedLogin('Password mismatch')
        return credentials.username

    def _cbAuthenticate(self, result, credentials, deferred):
        """
        Checks to see if authentication was good. Called once the info has
        been retrieved from the DB.

        Fires C{deferred} exactly once: with the avatar ID on success, or
        with the exception raised while checking. Forwarding every exception,
        including ones raised by C{customCheckFunc}, keeps the login from
        hanging on an unfired Deferred.
        """
        try:
            avatarId = self._checkRow(result, credentials)
        except Exception:
            # errback() with no argument wraps the exception being handled.
            deferred.errback()
        else:
            deferred.callback(avatarId)

    def _ebAuthenticate(self, message, credentials, deferred):
        """
        The database lookup failed for some reason; report it as a
        L{error.LoginFailed} on C{deferred}.
        """
        deferred.errback(error.LoginFailed(message))
135
136
def main():
    """
    Run a simple echo PB server to test the checker. It defines a custom query
    using the C{?} placeholder style that sqlite expects, but otherwise it's a
    straightforward use of the object.

    The server only starts listening after the test database has been created
    and populated, so early logins cannot hit a missing table. Any failure
    during setup is logged.

    You can test it by running C{pbechoclient.py}.
    """
    import os
    import sys

    from twisted.cred.portal import Portal
    from twisted.enterprise import adbapi
    from twisted.internet import reactor
    from twisted.python import log
    from twisted.spread import pb

    import pbecho

    log.startLogging(sys.stdout)
    # Start each run from a fresh database file.
    if os.path.isfile('testcred'):
        os.remove('testcred')
    pool = adbapi.ConnectionPool('pysqlite2.dbapi2', 'testcred')
    # Create the table that will be used
    query1 = """CREATE TABLE user (
            username string,
            password string
        )"""
    # Insert a test user
    query2 = """INSERT INTO user VALUES ('guest', 'guest')"""

    # sqlite takes "?" placeholders instead of the default "%s".
    checker = DBCredentialsChecker(pool.runQuery,
        query="SELECT username, password FROM user WHERE username = ?")
    portal = Portal(pbecho.SimpleRealm())
    portal.registerChecker(checker)

    def startListening(ignored):
        reactor.listenTCP(pb.portno, pb.PBServerFactory(portal))

    # Chain the setup steps so each runs only after the previous one
    # succeeded, and log (rather than drop) any failure along the way.
    setup = pool.runOperation(query1)
    setup.addCallback(lambda ignored: pool.runOperation(query2))
    setup.addCallback(startListening)
    setup.addErrback(log.err)
173
174
if __name__ == "__main__":
    # Script entry point: arrange for main() to be invoked as soon as the
    # reactor is running, then hand control over to the reactor's loop.
    from twisted.internet import reactor
    reactor.callWhenRunning(main)
    reactor.run()
179