1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
5 Tests for L{twisted.cred}, now with 30% more starch.
10 from zope.interface import implements, Interface
12 from twisted.trial import unittest
13 from twisted.cred import portal, checkers, credentials, error
14 from twisted.python import components
15 from twisted.internet import defer
16 from twisted.internet.defer import deferredGenerator as dG, waitForDeferred as wFD
19 from crypt import crypt
24 from twisted.cred.pamauth import callIntoPAM
28 from twisted.cred import pamauth
31 class ITestable(Interface):
35 def __init__(self, name):
38 self.loggedOut = False
41 assert not self.loggedIn
47 class Testable(components.Adapter):
50 # components.Interface(TestAvatar).adaptWith(Testable, ITestable)
52 components.registerAdapter(Testable, TestAvatar, ITestable)
54 class IDerivedCredentials(credentials.IUsernamePassword):
class DerivedCredentials(object):
    """
    Username/password credentials providing L{IDerivedCredentials} (an
    interface inheriting from L{credentials.IUsernamePassword}) and
    L{ITestable}.  Used to log in with credentials whose interface is only
    derived from the one a checker was registered for.
    """
    implements(IDerivedCredentials, ITestable)

    def __init__(self, username, password):
        """
        @param username: The username these credentials are for.
        @param password: The plaintext password supplied by the client.
        """
        self.username, self.password = username, password

    def checkPassword(self, password):
        """
        Compare C{password} with the password given at construction.

        @return: The result of the equality comparison.
        """
        matches = (password == self.password)
        return matches
69 implements(portal.IRealm)
73 def requestAvatar(self, avatarId, mind, *interfaces):
74 if self.avatars.has_key(avatarId):
75 avatar = self.avatars[avatarId]
77 avatar = TestAvatar(avatarId)
78 self.avatars[avatarId] = avatar
80 return (interfaces[0], interfaces[0](avatar),
83 class NewCredTest(unittest.TestCase):
85 r = self.realm = TestRealm()
86 p = self.portal = portal.Portal(r)
87 up = self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
88 up.addUser("bob", "hello")
def testListCheckers(self):
    """
    L{portal.Portal.listCredentialsInterfaces} reports exactly the
    credentials interfaces of the in-memory username/password checker.
    """
    expected = [credentials.IUsernamePassword,
                credentials.IUsernameHashedPassword]
    got = self.portal.listCredentialsInterfaces()
    # Only membership matters here, not the order the portal lists them in,
    # so compare order-insensitively.
    self.assertEqual(sorted(got), sorted(expected))
def testBasicLogin(self):
    """
    Logging in with a correct username and password produces an
    C{(interface, avatar, logout)} triple; calling C{logout} logs the
    avatar out.
    """
    l = []
    f = []
    self.portal.login(credentials.UsernamePassword("bob", "hello"),
                      self, ITestable).addCallback(
        l.append).addErrback(f.append)
    if f:
        # Re-raise the login failure rather than hitting an IndexError below.
        f[0].raiseException()
    iface, impl, logout = l[0]
    # whitebox
    self.assertEqual(iface, ITestable)
    self.failUnless(iface.providedBy(impl),
                    "%s does not implement %s" % (impl, iface))
    # greybox
    self.failUnless(impl.original.loggedIn)
    self.failUnless(not impl.original.loggedOut)
    logout()
    self.failUnless(impl.original.loggedOut)
def test_derivedInterface(self):
    """
    Login with credentials implementing an interface inheriting from an
    interface registered with a checker (but not itself registered).
    """
    l = []
    f = []
    self.portal.login(DerivedCredentials("bob", "hello"), self, ITestable
        ).addCallback(l.append
        ).addErrback(f.append)
    if f:
        # Re-raise the login failure rather than hitting an IndexError below.
        f[0].raiseException()
    iface, impl, logout = l[0]
    # whitebox
    self.assertEqual(iface, ITestable)
    self.failUnless(iface.providedBy(impl),
                    "%s does not implement %s" % (impl, iface))
    # greybox
    self.failUnless(impl.original.loggedIn)
    self.failUnless(not impl.original.loggedOut)
    logout()
    self.failUnless(impl.original.loggedOut)
def testFailedLogin(self):
    """
    Logging in with a wrong password fails with
    L{error.UnauthorizedLogin}.
    """
    l = []
    self.portal.login(credentials.UsernamePassword("bob", "h3llo"),
                      self, ITestable).addErrback(
        lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
    # trap() hands back the matched exception class; any other failure is
    # re-raised and never reaches l.
    self.failUnless(l)
    self.assertEqual(error.UnauthorizedLogin, l[0])
def testFailedLoginName(self):
    """
    Logging in with an unknown username fails with
    L{error.UnauthorizedLogin}.
    """
    l = []
    self.portal.login(credentials.UsernamePassword("jay", "hello"),
                      self, ITestable).addErrback(
        lambda x: x.trap(error.UnauthorizedLogin)).addCallback(l.append)
    # trap() hands back the matched exception class; any other failure is
    # re-raised and never reaches l.
    self.failUnless(l)
    self.assertEqual(error.UnauthorizedLogin, l[0])
class CramMD5CredentialsTestCase(unittest.TestCase):
    """
    Tests for L{credentials.CramMD5Credentials}.
    """

    def testIdempotentChallenge(self):
        """
        Asking for the challenge twice returns the same value both times.
        """
        creds = credentials.CramMD5Credentials()
        first = creds.getChallenge()
        second = creds.getChallenge()
        self.assertEqual(first, second)

    def testCheckPassword(self):
        """
        A response equal to the hex HMAC of the challenge, keyed with the
        password, makes C{checkPassword} accept that password.
        """
        creds = credentials.CramMD5Credentials()
        challenge = creds.getChallenge()
        creds.response = hmac.HMAC('secret', challenge).hexdigest()
        self.failUnless(creds.checkPassword('secret'))

    def testWrongPassword(self):
        """
        Without a matching response having been set, C{checkPassword}
        rejects the password.
        """
        creds = credentials.CramMD5Credentials()
        self.failIf(creds.checkPassword('secret'))
173 class OnDiskDatabaseTestCase(unittest.TestCase):
def testUserLookup(self):
    """
    A case-sensitive L{checkers.FilePasswordDB} finds each user under the
    exact name written to the file, and raises C{KeyError} for the
    upper-cased name.
    """
    dbfile = self.mktemp()
    db = checkers.FilePasswordDB(dbfile)
    f = file(dbfile, 'w')
    try:
        for (u, p) in self.users:
            f.write('%s:%s\n' % (u, p))
    finally:
        # Close (and therefore flush) before the database reads the file.
        f.close()

    for (u, p) in self.users:
        self.failUnlessRaises(KeyError, db.getUser, u.upper())
        self.assertEqual(db.getUser(u), (u, p))
def testCaseInSensitivity(self):
    """
    A L{checkers.FilePasswordDB} created with C{caseSensitive=0} finds each
    user by the upper-cased name and returns the name as stored.
    """
    dbfile = self.mktemp()
    db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
    f = file(dbfile, 'w')
    try:
        for (u, p) in self.users:
            f.write('%s:%s\n' % (u, p))
    finally:
        # Close (and therefore flush) before the database reads the file.
        f.close()

    for (u, p) in self.users:
        self.assertEqual(db.getUser(u.upper()), (u, p))
def testRequestAvatarId(self):
    """
    C{requestAvatarId} returns the username for each correct plaintext
    username/password pair in the file.
    """
    dbfile = self.mktemp()
    db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
    f = file(dbfile, 'w')
    try:
        for (u, p) in self.users:
            f.write('%s:%s\n' % (u, p))
    finally:
        # Close (and therefore flush) before the database reads the file.
        f.close()
    creds = [credentials.UsernamePassword(u, p) for u, p in self.users]
    d = defer.gatherResults(
        [defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
    d.addCallback(self.assertEqual, [u for u, p in self.users])
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def testRequestAvatarId_hashed(self):
    """
    With no hash function configured, L{credentials.UsernameHashedPassword}
    carrying the stored password verbatim authenticates, and
    C{requestAvatarId} returns the username.
    """
    dbfile = self.mktemp()
    db = checkers.FilePasswordDB(dbfile, caseSensitive=0)
    f = file(dbfile, 'w')
    try:
        for (u, p) in self.users:
            f.write('%s:%s\n' % (u, p))
    finally:
        # Close (and therefore flush) before the database reads the file.
        f.close()
    creds = [credentials.UsernameHashedPassword(u, p) for u, p in self.users]
    d = defer.gatherResults(
        [defer.maybeDeferred(db.requestAvatarId, c) for c in creds])
    d.addCallback(self.assertEqual, [u for u, p in self.users])
    # Return the Deferred so trial waits for it and reports its failures.
    return d
232 class HashedPasswordOnDiskDatabaseTestCase(unittest.TestCase):
240 def hash(self, u, p, s):
244 dbfile = self.mktemp()
245 self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
246 f = file(dbfile, 'w')
247 for (u, p) in self.users:
248 f.write('%s:%s\n' % (u, crypt(p, u[:2])))
251 self.port = portal.Portal(r)
252 self.port.registerChecker(self.db)
def testGoodCredentials(self):
    """
    The hashing database accepts each user's correct plaintext password
    and returns the username as the avatar ID.
    """
    goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
    d = defer.gatherResults([self.db.requestAvatarId(c) for c in goodCreds])
    d.addCallback(self.assertEqual, [u for u, p in self.users])
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def testGoodCredentials_login(self):
    """
    Logging in through the portal with each user's correct password yields
    an avatar whose wrapped object is named after that user.
    """
    goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
    d = defer.gatherResults([self.port.login(c, None, ITestable)
                             for c in goodCreds])
    # Each login result is an (interface, avatar, logout) triple.
    d.addCallback(lambda x: [a.original.name for i, a, l in x])
    d.addCallback(self.assertEqual, [u for u, p in self.users])
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def testBadCredentials(self):
    """
    Logging in with a wrong password fails with
    L{error.UnauthorizedLogin} for every user.
    """
    badCreds = [credentials.UsernamePassword(u, 'wrong password')
                for u, p in self.users]
    d = defer.DeferredList([self.port.login(c, None, ITestable)
                            for c in badCreds], consumeErrors=True)
    d.addCallback(self._assertFailures, error.UnauthorizedLogin)
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def testHashedCredentials(self):
    """
    Pre-hashed credentials are rejected with L{error.UnhandledCredentials}
    by a portal whose only checker was configured with a hash function.
    """
    hashedCreds = [credentials.UsernameHashedPassword(u, crypt(p, u[:2]))
                   for u, p in self.users]
    d = defer.DeferredList([self.port.login(c, None, ITestable)
                            for c in hashedCreds], consumeErrors=True)
    d.addCallback(self._assertFailures, error.UnhandledCredentials)
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def _assertFailures(self, failures, *expectedFailures):
    """
    Check a C{DeferredList} result: every entry must be a failure wrapping
    one of C{expectedFailures}.

    @param failures: The C{(flag, result)} pairs produced by
        L{defer.DeferredList}.
    @param expectedFailures: Exception types each failure may wrap; any
        other type is re-raised by C{trap}.
    """
    for succeeded, result in failures:
        self.assertEqual(succeeded, defer.FAILURE)
        result.trap(*expectedFailures)
291 skip = "crypt module not available"
293 class PluggableAuthenticationModulesTest(unittest.TestCase):
297 Replace L{pamauth.callIntoPAM} with a dummy implementation with
298 easily-controlled behavior.
300 self._oldCallIntoPAM = pamauth.callIntoPAM
301 pamauth.callIntoPAM = self.callIntoPAM
306 Restore the original value of L{pamauth.callIntoPAM}.
308 pamauth.callIntoPAM = self._oldCallIntoPAM
311 def callIntoPAM(self, service, user, conv):
312 if service != 'Twisted':
313 raise error.UnauthorizedLogin('bad service: %s' % service)
314 if user != 'testuser':
315 raise error.UnauthorizedLogin('bad username: %s' % user)
318 (2, "Message w/ Input"),
319 (3, "Message w/o Input"),
321 replies = conv(questions)
327 raise error.UnauthorizedLogin('bad conversion: %s' % repr(replies))
330 def _makeConv(self, d):
332 return defer.succeed([(d[t], 0) for t, q in questions])
def testRequestAvatarId(self):
    """
    The PAM checker accepts C{testuser} when the conversation answers each
    prompt as the stubbed PAM call expects, and returns the username as
    the avatar ID.
    """
    db = checkers.PluggableAuthenticationModulesChecker()
    conv = self._makeConv({1:'password', 2:'entry', 3:''})
    creds = credentials.PluggableAuthenticationModules('testuser', conv)
    d = db.requestAvatarId(creds)
    d.addCallback(self.assertEqual, 'testuser')
    # Return the Deferred so trial waits for it and reports its failures.
    return d
def testBadCredentials(self):
    """
    The PAM checker rejects C{testuser} with L{error.UnauthorizedLogin}
    when the conversation answers every prompt with an empty string.
    """
    db = checkers.PluggableAuthenticationModulesChecker()
    conv = self._makeConv({1:'', 2:'', 3:''})
    creds = credentials.PluggableAuthenticationModules('testuser', conv)
    d = db.requestAvatarId(creds)
    # assertFailure returns the Deferred; returning it makes trial wait.
    return self.assertFailure(d, error.UnauthorizedLogin)
def testBadUsername(self):
    """
    The PAM checker rejects an unknown username with
    L{error.UnauthorizedLogin}, even when the conversation answers
    correctly.
    """
    db = checkers.PluggableAuthenticationModulesChecker()
    conv = self._makeConv({1:'password', 2:'entry', 3:''})
    creds = credentials.PluggableAuthenticationModules('baduser', conv)
    d = db.requestAvatarId(creds)
    # assertFailure returns the Deferred; returning it makes trial wait.
    return self.assertFailure(d, error.UnauthorizedLogin)
363 skip = "Can't run without PyPAM"
def testPositive(self):
    """
    Every checker returned by C{getCheckers} accepts every credential from
    C{getGoodCredentials} and reports the expected avatar ID.
    """
    for chk in self.getCheckers():
        for (cred, avatarId) in self.getGoodCredentials():
            r = wFD(chk.requestAvatarId(cred))
            # deferredGenerator protocol: yield the waitForDeferred so the
            # Deferred fires before getResult() is called.  This yield also
            # makes the function a generator, which dG requires.
            yield r
            self.assertEqual(r.getResult(), avatarId)
testPositive = dG(testPositive)
def testNegative(self):
    """
    Every checker returned by C{getCheckers} rejects every credential from
    C{getBadCredentials} with L{error.UnauthorizedLogin}.
    """
    for chk in self.getCheckers():
        for cred in self.getBadCredentials():
            r = wFD(chk.requestAvatarId(cred))
            # deferredGenerator protocol: yield the waitForDeferred so the
            # Deferred fires before getResult() is called.  This yield also
            # makes the function a generator, which dG requires.
            yield r
            self.assertRaises(error.UnauthorizedLogin, r.getResult)
testNegative = dG(testNegative)
382 class HashlessFilePasswordDBMixin:
383 credClass = credentials.UsernamePassword
385 networkHash = staticmethod(lambda x: x)
387 _validCredentials = [
388 ('user1', 'password1'),
389 ('user2', 'password2'),
390 ('user3', 'password3')]
def getGoodCredentials(self):
    """
    Generate C{(credentials, expectedAvatarId)} pairs, one per entry of
    C{_validCredentials}.  Each password is passed through C{networkHash}
    before C{credClass} is instantiated; the expected avatar ID is the
    username.
    """
    for username, password in self._validCredentials:
        creds = self.credClass(username, self.networkHash(password))
        yield (creds, username)
396 def getBadCredentials(self):
397 for u, p in [('user1', 'password3'),
398 ('user2', 'password1'),
400 yield self.credClass(u, self.networkHash(p))
402 def getCheckers(self):
403 diskHash = self.diskHash or (lambda x: x)
404 hashCheck = self.diskHash and (lambda username, password, stored: self.diskHash(password))
406 for cache in True, False:
409 for u, p in self._validCredentials:
410 fObj.write('%s:%s\n' % (u, diskHash(p)))
412 yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck)
416 for u, p in self._validCredentials:
417 fObj.write('%s dingle dongle %s\n' % (diskHash(p), u))
419 yield checkers.FilePasswordDB(fn, ' ', 3, 0, cache=cache, hash=hashCheck)
423 for u, p in self._validCredentials:
424 fObj.write('zip,zap,%s,zup,%s\n' % (u.title(), diskHash(p)))
426 yield checkers.FilePasswordDB(fn, ',', 2, 4, False, cache=cache, hash=hashCheck)
class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant of L{HashlessFilePasswordDBMixin} in which passwords are
    hex-encoded before being written to the password files.
    """

    @staticmethod
    def diskHash(x):
        return x.encode('hex')
class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
    """
    Variant of L{HashlessFilePasswordDBMixin} in which credentials carry a
    hex-encoded password as L{credentials.UsernameHashedPassword}.
    """

    @staticmethod
    def networkHash(x):
        return x.encode('hex')

    class credClass(credentials.UsernameHashedPassword):
        """
        Hashed-password credentials that check a plaintext password by
        hex-decoding the value they carry.
        """

        def checkPassword(self, password):
            plaintext = self.hashed.decode('hex')
            return plaintext == password
437 class HashlessFilePasswordDBCheckerTestCase(HashlessFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
440 class LocallyHashedFilePasswordDBCheckerTestCase(LocallyHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):
443 class NetworkHashedFilePasswordDBCheckerTestCase(NetworkHashedFilePasswordDBMixin, CheckersMixin, unittest.TestCase):