Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / conch / test / test_userauth.py
1 # -*- test-case-name: twisted.conch.test.test_userauth -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Tests for the implementation of the ssh-userauth service.
7
8 Maintainer: Paul Swartz
9 """
10
11 from zope.interface import implements
12
13 from twisted.cred.checkers import ICredentialsChecker
14 from twisted.cred.credentials import IUsernamePassword, ISSHPrivateKey
15 from twisted.cred.credentials import IPluggableAuthenticationModules
16 from twisted.cred.credentials import IAnonymous
17 from twisted.cred.error import UnauthorizedLogin
18 from twisted.cred.portal import IRealm, Portal
19 from twisted.conch.error import ConchError, ValidPublicKey
20 from twisted.internet import defer, task
21 from twisted.protocols import loopback
22 from twisted.trial import unittest
23
# Probe for the optional PyCrypto and pyasn1 dependencies.  If either is
# missing, ``keys`` is set to None -- the test cases below check for that and
# skip themselves -- and the twisted.conch SSH modules are not imported.
try:
    import Crypto.Cipher.DES3, Crypto.Cipher.XOR
    import pyasn1
except ImportError:
    keys = None


    # Stand-in namespaces so that the module-level classes below, which
    # subclass transport.SSHTransportBase and userauth.SSHUserAuthClient,
    # can still be defined when the real modules were not imported.
    class transport:
        class SSHTransportBase:
            """
            A stub class so that later class definitions won't die.
            """

    class userauth:
        class SSHUserAuthClient:
            """
            A stub class so that later class definitions won't die.
            """
else:
    from twisted.conch.ssh.common import NS
    from twisted.conch.checkers import SSHProtocolChecker
    from twisted.conch.ssh import keys, userauth, transport
    from twisted.conch.test import keydata
47
48
49
class ClientUserAuth(userauth.SSHUserAuthClient):
    """
    A mock user auth client which answers every request with canned data
    taken from L{keydata}.
    """


    def getPublicKey(self):
        """
        Return the public key to offer next.

        While C{self.lastPublicKey} is false (no key offered yet), return a
        Deferred firing with the DSA public key.  After that, return the RSA
        public key directly as a L{keys.Key}, without a Deferred.

        @return: a L{keys.Key}, or a Deferred firing with one.
        """
        # NOTE(review): the mixed Deferred / plain return types look
        # deliberate, presumably so both result paths of the base client are
        # exercised -- confirm before "normalising" them.
        if not self.lastPublicKey:
            dsaKey = keys.Key.fromString(keydata.publicDSA_openssh)
            return defer.succeed(dsaKey)
        return keys.Key.fromString(keydata.publicRSA_openssh)


    def getPrivateKey(self):
        """
        Return a Deferred firing with the RSA private key, the counterpart of
        the second key offered by L{getPublicKey}.
        """
        rsaKey = keys.Key.fromString(keydata.privateRSA_openssh)
        return defer.succeed(rsaKey)


    def getPassword(self, prompt=None):
        """
        Return a Deferred firing with the fixed password C{'foo'}.

        @param prompt: ignored.
        """
        password = 'foo'
        return defer.succeed(password)


    def getGenericAnswers(self, name, information, answers):
        """
        Return a Deferred firing with C{('foo', 'foo')}: 'foo' as the answer
        to each of two questions, whatever is actually asked.
        """
        return defer.succeed(('foo',) * 2)
87
88
89
class OldClientAuth(userauth.SSHUserAuthClient):
    """
    A client written against the old SSHUserAuthClient API, in which
    getPrivateKey() produced a PyCrypto key object and getPublicKey()
    returned a public key blob string.
    """


    def getPrivateKey(self):
        """
        Return a Deferred firing with the PyCrypto object underlying the RSA
        private key (the deprecated return type).
        """
        rsaKey = keys.Key.fromString(keydata.privateRSA_openssh)
        return defer.succeed(rsaKey.keyObject)


    def getPublicKey(self):
        """
        Return the RSA public key as a blob string (the deprecated return
        type).
        """
        rsaKey = keys.Key.fromString(keydata.publicRSA_openssh)
        return rsaKey.blob()
104
class ClientAuthWithoutPrivateKey(userauth.SSHUserAuthClient):
    """
    A client that can offer a public key but has no matching private key.
    """


    def getPrivateKey(self):
        """
        Report that no private key is available by returning C{None}.
        """
        return None


    def getPublicKey(self):
        """
        Return the RSA public key from L{keydata}.
        """
        rsaKey = keys.Key.fromString(keydata.publicRSA_openssh)
        return rsaKey
117
118
119
class FakeTransport(transport.SSHTransportBase):
    """
    A stand-in for the SSH transport underneath the user auth service being
    tested.

    L{userauth.SSHUserAuthServer} reads C{transport.factory.portal}, so this
    transport carries a mock factory whose portal is the one given to the
    constructor; that lets each test decide which credential checkers are
    available.  Packets sent over the transport are recorded, not written.

    @ivar packets: every packet sent, in order, as C{(messageType, data)}
        2-tuples.
    @type packets: C{list}
    @ivar lostConnection: C{True} once L{loseConnection} has been called.
    @type lostConnection: C{bool}
    """


    class Service(object):
        """
        A mock service, standing in for the other service the server offers.
        """
        name = 'nancy'


        def serviceStarted(self):
            """
            Do nothing; the mock service needs no start-up work.
            """



    class Factory(object):
        """
        A mock factory, standing in for the factory that spawned this user
        auth service.
        """


        def getService(self, transport, service):
            """
            Look up a service by name.

            @return: the L{FakeTransport.Service} class when C{service} is
                C{'none'}, otherwise C{None}.
            """
            if service != 'none':
                return None
            return FakeTransport.Service



    def __init__(self, portal):
        """
        @param portal: the portal to expose as C{self.factory.portal}.
        """
        self.packets = []
        self.lostConnection = False
        self.factory = self.Factory()
        self.factory.portal = portal
        # NOTE(review): presumably SSHTransportBase code reaches for
        # self.transport (e.g. to call loseConnection); alias it to this
        # object so such calls land on the fake.
        self.transport = self



    def sendPacket(self, messageType, message):
        """
        Record an outgoing packet instead of sending it.
        """
        packet = (messageType, message)
        self.packets.append(packet)


    def isEncrypted(self, direction):
        """
        Claim that traffic is encrypted in every direction, because
        SSHUserAuthServer stops offering password authentication over an
        unencrypted transport.
        """
        return True


    def loseConnection(self):
        """
        Note that the connection was dropped.
        """
        self.lostConnection = True
193
194
195
class Realm(object):
    """
    A mock realm for testing L{userauth.SSHUserAuthServer}.

    No test here uses the avatars it produces, so it hands back the simplest
    result that could possibly work.
    """
    implements(IRealm)


    def requestAvatar(self, avatarId, mind, *interfaces):
        """
        Return a Deferred firing with C{(interface, None, logout)}.  Here
        C{interface} is the first interface requested, the avatar is C{None},
        and C{logout} does nothing.
        """
        def logout():
            return None
        firstInterface = interfaces[0]
        return defer.succeed((firstInterface, None, logout))
208
209
210
class PasswordChecker(object):
    """
    A trivial username/password checker: a login succeeds exactly when the
    password equals the username.
    """
    credentialInterfaces = (IUsernamePassword,)
    implements(ICredentialsChecker)


    def requestAvatarId(self, creds):
        """
        Return a Deferred firing with the username when the password matches
        it, or failing with L{UnauthorizedLogin} when it does not.
        """
        if creds.username != creds.password:
            return defer.fail(UnauthorizedLogin("Invalid username/password pair"))
        return defer.succeed(creds.username)
224
225
226
class PrivateKeyChecker(object):
    """
    A trivial public key checker which only accepts the RSA key pair
    keydata.publicRSA_openssh / keydata.privateRSA_openssh.
    """
    credentialInterfaces = (ISSHPrivateKey,)
    implements(ICredentialsChecker)



    def requestAvatarId(self, creds):
        """
        Check public key credentials.

        @return: the username, when the blob is the accepted RSA key and the
            signature verifies over C{creds.sigData}.
        @raise ValidPublicKey: when the blob is the accepted key but no
            signature was supplied.
        @raise UnauthorizedLogin: for any other key, or a bad signature.
        """
        acceptedBlob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        if creds.blob != acceptedBlob:
            raise UnauthorizedLogin()
        if creds.signature is None:
            raise ValidPublicKey()
        offeredKey = keys.Key.fromString(creds.blob)
        if not offeredKey.verify(creds.signature, creds.sigData):
            raise UnauthorizedLogin()
        return creds.username
246
247
248
class PAMChecker(object):
    """
    A trivial PAM checker which asks for a name and a password and accepts
    the user when both answers equal the username.
    """
    credentialInterfaces = (IPluggableAuthenticationModules,)
    implements(ICredentialsChecker)


    def requestAvatarId(self, creds):
        """
        Run a two-prompt PAM conversation.

        @return: a Deferred firing with the username when both answers are
            C{(username, 0)}, or failing with L{UnauthorizedLogin}.
        """
        # Message type 2 is sent as an echoed prompt and 1 as a hidden one.
        conversation = creds.pamConversion([('Name: ', 2), ("Password: ", 1)])

        def checkAnswers(answers):
            expected = [(creds.username, 0)] * 2
            if answers != expected:
                raise UnauthorizedLogin()
            return creds.username

        return conversation.addCallback(checkAnswers)
265
266
267
class AnonymousChecker(object):
    """
    A checker for anonymous credentials, for which L{SSHUserAuthServer} has
    no authentication method; registering it must not change what the
    server advertises.
    """
    implements(ICredentialsChecker)
    credentialInterfaces = (IAnonymous,)
274
275
276
class SSHUserAuthServerTestCase(unittest.TestCase):
    """
    Tests for SSHUserAuthServer.
    """


    if keys is None:
        skip = "cannot run w/o PyCrypto"


    def setUp(self):
        """
        Start an L{userauth.SSHUserAuthServer} on a L{FakeTransport} whose
        portal has password, public-key and PAM checkers registered.
        """
        self.realm = Realm()
        self.portal = Portal(self.realm)
        self.portal.registerChecker(PasswordChecker())
        self.portal.registerChecker(PrivateKeyChecker())
        self.portal.registerChecker(PAMChecker())
        self.authServer = userauth.SSHUserAuthServer()
        self.authServer.transport = FakeTransport(self.portal)
        self.authServer.serviceStarted()
        # Give the advertised methods a consistent order so that failure
        # packets can be compared byte for byte (see _checkFailed).
        self.authServer.supportedAuthentications.sort()


    def tearDown(self):
        """
        Stop the server under test and drop the reference to it.
        """
        self.authServer.serviceStopped()
        self.authServer = None


    def _checkFailed(self, ignored):
        """
        Check that the authentication has failed.

        The last packet sent must be a MSG_USERAUTH_FAILURE listing the
        methods registered in L{setUp}, in sorted order, followed by a zero
        byte for the "partial success" flag.

        @param ignored: the result of the Deferred this is chained onto.
        """
        self.assertEqual(self.authServer.transport.packets[-1],
                (userauth.MSG_USERAUTH_FAILURE,
                NS('keyboard-interactive,password,publickey') + '\x00'))


    def test_noneAuthentication(self):
        """
        A client may request a list of authentication 'method name' values
        that may continue by using the "none" authentication 'method name'.

        See RFC 4252 Section 5.2.
        """
        d = self.authServer.ssh_USERAUTH_REQUEST(NS('foo') + NS('service') +
                                                 NS('none'))
        return d.addCallback(self._checkFailed)


    def test_successfulPasswordAuthentication(self):
        """
        When provided with correct password authentication information, the
        server should respond by sending a MSG_USERAUTH_SUCCESS message with
        no other data.

        See RFC 4252, Section 5.1.
        """
        # packet = username, next_service, authentication type, FALSE, password
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEqual(
                self.authServer.transport.packets,
                [(userauth.MSG_USERAUTH_SUCCESS, '')])
        return d.addCallback(check)


    def test_failedPasswordAuthentication(self):
        """
        When provided with invalid authentication details, the server should
        respond by sending a MSG_USERAUTH_FAILURE message which states whether
        the authentication was partially successful, and provides other, open
        options for authentication.

        See RFC 4252, Section 5.1.
        """
        # packet = username, next_service, authentication type, FALSE, password
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('bar')
        self.authServer.clock = task.Clock()
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        # The failure reply is delayed: nothing is sent until time passes.
        self.assertEqual(self.authServer.transport.packets, [])
        self.authServer.clock.advance(2)
        return d.addCallback(self._checkFailed)


    def test_successfulPrivateKeyAuthentication(self):
        """
        Test that private key authentication completes successfully when the
        request is signed with the accepted private key.
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        obj = keys.Key.fromString(keydata.privateRSA_openssh)
        packet = (NS('foo') + NS('none') + NS('publickey') + '\xff'
                + NS(obj.sshType()) + NS(blob))
        self.authServer.transport.sessionID = 'test'
        # Sign the session ID, the message number and the request so far.
        signature = obj.sign(NS('test') + chr(userauth.MSG_USERAUTH_REQUEST)
                + packet)
        packet += NS(signature)
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEqual(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_SUCCESS, '')])
        return d.addCallback(check)


    def test_requestRaisesConchError(self):
        """
        ssh_USERAUTH_REQUEST should raise a ConchError if tryAuth returns
        None. Added to catch a bug noticed by pyflakes.
        """
        d = defer.Deferred()

        # Patched onto the instance, so it is called without a bound self;
        # `self` below is the test case, taken from the enclosing scope.
        def mockCbFinishedAuth(ignored):
            self.fail('request should have raised ConchError')

        def mockTryAuth(kind, user, data):
            return None

        def mockEbBadAuth(reason):
            d.errback(reason.value)

        self.patch(self.authServer, 'tryAuth', mockTryAuth)
        self.patch(self.authServer, '_cbFinishedAuth', mockCbFinishedAuth)
        self.patch(self.authServer, '_ebBadAuth', mockEbBadAuth)

        packet = NS('user') + NS('none') + NS('public-key') + NS('data')
        # If an error other than ConchError is raised, this will trigger an
        # exception.
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        return self.assertFailure(d, ConchError)


    def test_verifyValidPrivateKey(self):
        """
        An unsigned publickey request for a key the checker accepts is
        answered with MSG_USERAUTH_PK_OK echoing the algorithm name and blob.
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        packet = (NS('foo') + NS('none') + NS('publickey') + '\x00'
                + NS('ssh-rsa') + NS(blob))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        def check(ignored):
            self.assertEqual(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_PK_OK, NS('ssh-rsa') + NS(blob))])
        return d.addCallback(check)


    def test_failedPrivateKeyAuthenticationWithoutSignature(self):
        """
        Test that unsigned private key authentication fails when the public
        key offered is not one the checker accepts.
        """
        blob = keys.Key.fromString(keydata.publicDSA_openssh).blob()
        packet = (NS('foo') + NS('none') + NS('publickey') + '\x00'
                + NS('ssh-dss') + NS(blob))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test_failedPrivateKeyAuthenticationWithSignature(self):
        """
        Test that private key authentication fails when the public key is
        accepted but the signature does not cover the expected data (here it
        signs only the blob).
        """
        blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
        obj = keys.Key.fromString(keydata.privateRSA_openssh)
        packet = (NS('foo') + NS('none') + NS('publickey') + '\xff'
                + NS('ssh-rsa') + NS(blob) + NS(obj.sign(blob)))
        self.authServer.transport.sessionID = 'test'
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test_successfulPAMAuthentication(self):
        """
        Test that keyboard-interactive authentication succeeds.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        # Two answers, both equal to the username.
        response = '\x00\x00\x00\x02' + NS('foo') + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        def check(ignored):
            self.assertEqual(self.authServer.transport.packets,
                    [(userauth.MSG_USERAUTH_INFO_REQUEST, (NS('') + NS('')
                        + NS('') + '\x00\x00\x00\x02' + NS('Name: ') + '\x01'
                        + NS('Password: ') + '\x00')),
                     (userauth.MSG_USERAUTH_SUCCESS, '')])

        return d.addCallback(check)


    def test_failedPAMAuthentication(self):
        """
        Test that keyboard-interactive authentication fails.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        response = '\x00\x00\x00\x02' + NS('bar') + NS('bar')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        def check(ignored):
            self.assertEqual(self.authServer.transport.packets[0],
                    (userauth.MSG_USERAUTH_INFO_REQUEST, (NS('') + NS('')
                        + NS('') + '\x00\x00\x00\x02' + NS('Name: ') + '\x01'
                        + NS('Password: ') + '\x00')))
        return d.addCallback(check).addCallback(self._checkFailed)


    def test_invalid_USERAUTH_INFO_RESPONSE_not_enough_data(self):
        """
        If ssh_USERAUTH_INFO_RESPONSE gets an invalid packet,
        the user authentication should fail.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(NS('\x00\x00\x00\x00' +
            NS('hi')))
        return d.addCallback(self._checkFailed)


    def test_invalid_USERAUTH_INFO_RESPONSE_too_much_data(self):
        """
        If ssh_USERAUTH_INFO_RESPONSE gets too much data, the user
        authentication should fail.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        # Claims two answers but carries three.
        response = '\x00\x00\x00\x02' + NS('foo') + NS('foo') + NS('foo')
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_INFO_RESPONSE(response)
        return d.addCallback(self._checkFailed)


    def test_onlyOnePAMAuthentication(self):
        """
        Because it requires an intermediate message, one can't send a second
        keyboard-interactive request while the first is still pending.
        """
        packet = (NS('foo') + NS('none') + NS('keyboard-interactive')
                + NS('') + NS(''))
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.authServer.ssh_USERAUTH_REQUEST(packet)
        self.assertEqual(self.authServer.transport.packets[-1][0],
                transport.MSG_DISCONNECT)
        # Byte 3 of the disconnect payload is the low byte of the reason code.
        self.assertEqual(self.authServer.transport.packets[-1][1][3],
                chr(transport.DISCONNECT_PROTOCOL_ERROR))


    def test_ignoreUnknownCredInterfaces(self):
        """
        L{SSHUserAuthServer} sets up
        C{SSHUserAuthServer.supportedAuthentications} by checking the portal's
        credentials interfaces and mapping them to SSH authentication method
        strings.  If the Portal advertises an interface that
        L{SSHUserAuthServer} can't map, it should be ignored.  This is a white
        box test.
        """
        server = userauth.SSHUserAuthServer()
        server.transport = FakeTransport(self.portal)
        self.portal.registerChecker(AnonymousChecker())
        server.serviceStarted()
        server.serviceStopped()
        server.supportedAuthentications.sort() # give a consistent order
        self.assertEqual(server.supportedAuthentications,
                          ['keyboard-interactive', 'password', 'publickey'])


    def test_removePasswordIfUnencrypted(self):
        """
        Test that the userauth service does not advertise password
        authentication if the password would be sent in cleartext.
        """
        self.assertIn('password', self.authServer.supportedAuthentications)
        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(self.portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.failIfIn('password', clearAuthServer.supportedAuthentications)
        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(self.portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        self.assertIn('password', halfAuthServer.supportedAuthentications)


    def test_removeKeyboardInteractiveIfUnencrypted(self):
        """
        Test that the userauth service does not advertise keyboard-interactive
        authentication if the password would be sent in cleartext.
        """
        self.assertIn('keyboard-interactive',
                self.authServer.supportedAuthentications)
        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(self.portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.failIfIn('keyboard-interactive',
                clearAuthServer.supportedAuthentications)
        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(self.portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        self.assertIn('keyboard-interactive',
                halfAuthServer.supportedAuthentications)


    def test_unencryptedConnectionWithoutPasswords(self):
        """
        If the L{SSHUserAuthServer} is not advertising passwords, then an
        unencrypted connection should not cause any warnings or exceptions.
        This is a white box test.
        """
        # create a Portal without password authentication
        portal = Portal(self.realm)
        portal.registerChecker(PrivateKeyChecker())

        # no encryption
        clearAuthServer = userauth.SSHUserAuthServer()
        clearAuthServer.transport = FakeTransport(portal)
        clearAuthServer.transport.isEncrypted = lambda x: False
        clearAuthServer.serviceStarted()
        clearAuthServer.serviceStopped()
        self.assertEqual(clearAuthServer.supportedAuthentications,
                          ['publickey'])

        # only encrypt incoming (the direction the password is sent)
        halfAuthServer = userauth.SSHUserAuthServer()
        halfAuthServer.transport = FakeTransport(portal)
        halfAuthServer.transport.isEncrypted = lambda x: x == 'in'
        halfAuthServer.serviceStarted()
        halfAuthServer.serviceStopped()
        # Check the half-encrypted server itself; asserting on clearAuthServer
        # again would leave this case untested.
        self.assertEqual(halfAuthServer.supportedAuthentications,
                          ['publickey'])


    def test_loginTimeout(self):
        """
        Test that the login times out.
        """
        timeoutAuthServer = userauth.SSHUserAuthServer()
        timeoutAuthServer.clock = task.Clock()
        timeoutAuthServer.transport = FakeTransport(self.portal)
        timeoutAuthServer.serviceStarted()
        timeoutAuthServer.clock.advance(11 * 60 * 60)
        timeoutAuthServer.serviceStopped()
        self.assertEqual(timeoutAuthServer.transport.packets,
                [(transport.MSG_DISCONNECT,
                '\x00' * 3 +
                chr(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) +
                NS("you took too long") + NS(''))])
        self.assertTrue(timeoutAuthServer.transport.lostConnection)


    def test_cancelLoginTimeout(self):
        """
        Test that stopping the service also stops the login timeout.
        """
        timeoutAuthServer = userauth.SSHUserAuthServer()
        timeoutAuthServer.clock = task.Clock()
        timeoutAuthServer.transport = FakeTransport(self.portal)
        timeoutAuthServer.serviceStarted()
        timeoutAuthServer.serviceStopped()
        timeoutAuthServer.clock.advance(11 * 60 * 60)
        self.assertEqual(timeoutAuthServer.transport.packets, [])
        self.assertFalse(timeoutAuthServer.transport.lostConnection)


    def test_tooManyAttempts(self):
        """
        Test that the server disconnects if the client fails authentication
        too many times.
        """
        packet = NS('foo') + NS('none') + NS('password') + chr(0) + NS('bar')
        self.authServer.clock = task.Clock()
        for _ in range(21):
            d = self.authServer.ssh_USERAUTH_REQUEST(packet)
            self.authServer.clock.advance(2)
        # Only the Deferred of the final attempt is checked.
        def check(ignored):
            self.assertEqual(self.authServer.transport.packets[-1],
                (transport.MSG_DISCONNECT,
                '\x00' * 3 +
                chr(transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) +
                NS("too many bad auths") + NS('')))
        return d.addCallback(check)


    def test_failIfUnknownService(self):
        """
        If the user requests a service that we don't support, the
        authentication should fail.
        """
        packet = NS('foo') + NS('') + NS('password') + chr(0) + NS('foo')
        self.authServer.clock = task.Clock()
        d = self.authServer.ssh_USERAUTH_REQUEST(packet)
        return d.addCallback(self._checkFailed)


    def test__pamConvErrors(self):
        """
        _pamConv should fail if it gets a message that's not 1 or 2.
        """
        def secondTest(ignored):
            d2 = self.authServer._pamConv([('', 90)])
            return self.assertFailure(d2, ConchError)

        d = self.authServer._pamConv([('', 3)])
        return self.assertFailure(d, ConchError).addCallback(secondTest)


    def test_tryAuthEdgeCases(self):
        """
        tryAuth() has two edge cases that are difficult to reach.

        1) an authentication method auth_* returns None instead of a Deferred.
        2) an authentication type that is defined does not have a matching
           auth_* method.

        Both these cases should return a Deferred which fails with a
        ConchError.
        """
        def mockAuth(packet):
            return None

        self.patch(self.authServer, 'auth_publickey', mockAuth) # first case
        self.patch(self.authServer, 'auth_password', None) # second case

        def secondTest(ignored):
            d2 = self.authServer.tryAuth('password', None, None)
            return self.assertFailure(d2, ConchError)

        d1 = self.authServer.tryAuth('publickey', None, None)
        return self.assertFailure(d1, ConchError).addCallback(secondTest)
716
717
718
719
720 class SSHUserAuthClientTestCase(unittest.TestCase):
721     """
722     Tests for SSHUserAuthClient.
723     """
724
725
726     if keys is None:
727         skip = "cannot run w/o PyCrypto"
728
729
730     def setUp(self):
731         self.authClient = ClientUserAuth('foo', FakeTransport.Service())
732         self.authClient.transport = FakeTransport(None)
733         self.authClient.transport.sessionID = 'test'
734         self.authClient.serviceStarted()
735
736
737     def tearDown(self):
738         self.authClient.serviceStopped()
739         self.authClient = None
740
741
742     def test_init(self):
743         """
744         Test that client is initialized properly.
745         """
746         self.assertEqual(self.authClient.user, 'foo')
747         self.assertEqual(self.authClient.instance.name, 'nancy')
748         self.assertEqual(self.authClient.transport.packets,
749                 [(userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
750                     + NS('none'))])
751
752
753     def test_USERAUTH_SUCCESS(self):
754         """
755         Test that the client succeeds properly.
756         """
757         instance = [None]
758         def stubSetService(service):
759             instance[0] = service
760         self.authClient.transport.setService = stubSetService
761         self.authClient.ssh_USERAUTH_SUCCESS('')
762         self.assertEqual(instance[0], self.authClient.instance)
763
764
765     def test_publickey(self):
766         """
767         Test that the client can authenticate with a public key.
768         """
769         self.authClient.ssh_USERAUTH_FAILURE(NS('publickey') + '\x00')
770         self.assertEqual(self.authClient.transport.packets[-1],
771                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
772                     + NS('publickey') + '\x00' + NS('ssh-dss')
773                     + NS(keys.Key.fromString(
774                         keydata.publicDSA_openssh).blob())))
775        # that key isn't good
776         self.authClient.ssh_USERAUTH_FAILURE(NS('publickey') + '\x00')
777         blob = NS(keys.Key.fromString(keydata.publicRSA_openssh).blob())
778         self.assertEqual(self.authClient.transport.packets[-1],
779                 (userauth.MSG_USERAUTH_REQUEST, (NS('foo') + NS('nancy')
780                     + NS('publickey') + '\x00'+ NS('ssh-rsa') + blob)))
781         self.authClient.ssh_USERAUTH_PK_OK(NS('ssh-rsa')
782             + NS(keys.Key.fromString(keydata.publicRSA_openssh).blob()))
783         sigData = (NS(self.authClient.transport.sessionID)
784                 + chr(userauth.MSG_USERAUTH_REQUEST) + NS('foo')
785                 + NS('nancy') + NS('publickey') + '\x01' + NS('ssh-rsa')
786                 + blob)
787         obj = keys.Key.fromString(keydata.privateRSA_openssh)
788         self.assertEqual(self.authClient.transport.packets[-1],
789                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
790                     + NS('publickey') + '\x01' + NS('ssh-rsa') + blob
791                     + NS(obj.sign(sigData))))
792
793
794     def test_publickey_without_privatekey(self):
795         """
796         If the SSHUserAuthClient doesn't return anything from signData,
797         the client should start the authentication over again by requesting
798         'none' authentication.
799         """
800         authClient = ClientAuthWithoutPrivateKey('foo',
801                                                  FakeTransport.Service())
802
803         authClient.transport = FakeTransport(None)
804         authClient.transport.sessionID = 'test'
805         authClient.serviceStarted()
806         authClient.tryAuth('publickey')
807         authClient.transport.packets = []
808         self.assertIdentical(authClient.ssh_USERAUTH_PK_OK(''), None)
809         self.assertEqual(authClient.transport.packets, [
810                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy') +
811                  NS('none'))])
812
813
814     def test_old_publickey_getPublicKey(self):
815         """
816         Old SSHUserAuthClients returned strings of public key blobs from
817         getPublicKey().  Test that a Deprecation warning is raised but the key is
818         verified correctly.
819         """
820         oldAuth = OldClientAuth('foo', FakeTransport.Service())
821         oldAuth.transport = FakeTransport(None)
822         oldAuth.transport.sessionID = 'test'
823         oldAuth.serviceStarted()
824         oldAuth.transport.packets = []
825         self.assertWarns(DeprecationWarning, "Returning a string from "
826                          "SSHUserAuthClient.getPublicKey() is deprecated since "
827                          "Twisted 9.0.  Return a keys.Key() instead.",
828                          userauth.__file__, oldAuth.tryAuth, 'publickey')
829         self.assertEqual(oldAuth.transport.packets, [
830                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy') +
831                  NS('publickey') + '\x00' + NS('ssh-rsa') +
832                  NS(keys.Key.fromString(keydata.publicRSA_openssh).blob()))])
833
834
835     def test_old_publickey_getPrivateKey(self):
836         """
837         Old SSHUserAuthClients returned a PyCrypto key object from
838         getPrivateKey().  Test that _cbSignData signs the data warns the
839         user about the deprecation, but signs the data correctly.
840         """
841         oldAuth = OldClientAuth('foo', FakeTransport.Service())
842         d = self.assertWarns(DeprecationWarning, "Returning a PyCrypto key "
843                              "object from SSHUserAuthClient.getPrivateKey() is "
844                              "deprecated since Twisted 9.0.  "
845                              "Return a keys.Key() instead.", userauth.__file__,
846                              oldAuth.signData, None, 'data')
847         def _checkSignedData(sig):
848             self.assertEqual(sig,
849                 keys.Key.fromString(keydata.privateRSA_openssh).sign(
850                     'data'))
851         d.addCallback(_checkSignedData)
852         return d
853
854
855     def test_no_publickey(self):
856         """
857         If there's no public key, auth_publickey should return a Deferred
858         called back with a False value.
859         """
860         self.authClient.getPublicKey = lambda x: None
861         d = self.authClient.tryAuth('publickey')
862         def check(result):
863             self.assertFalse(result)
864         return d.addCallback(check)
865
866     def test_password(self):
867         """
868         Test that the client can authentication with a password.  This
869         includes changing the password.
870         """
871         self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\x00')
872         self.assertEqual(self.authClient.transport.packets[-1],
873                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
874                     + NS('password') + '\x00' + NS('foo')))
875         self.authClient.ssh_USERAUTH_PK_OK(NS('') + NS(''))
876         self.assertEqual(self.authClient.transport.packets[-1],
877                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
878                     + NS('password') + '\xff' + NS('foo') * 2))
879
880
881     def test_no_password(self):
882         """
883         If getPassword returns None, tryAuth should return False.
884         """
885         self.authClient.getPassword = lambda: None
886         self.assertFalse(self.authClient.tryAuth('password'))
887
888
889     def test_keyboardInteractive(self):
890         """
891         Test that the client can authenticate using keyboard-interactive
892         authentication.
893         """
894         self.authClient.ssh_USERAUTH_FAILURE(NS('keyboard-interactive')
895                + '\x00')
896         self.assertEqual(self.authClient.transport.packets[-1],
897                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
898                     + NS('keyboard-interactive') + NS('')*2))
899         self.authClient.ssh_USERAUTH_PK_OK(NS('')*3 + '\x00\x00\x00\x02'
900                 + NS('Name: ') + '\xff' + NS('Password: ') + '\x00')
901         self.assertEqual(self.authClient.transport.packets[-1],
902                 (userauth.MSG_USERAUTH_INFO_RESPONSE, '\x00\x00\x00\x02'
903                     + NS('foo')*2))
904
905
906     def test_USERAUTH_PK_OK_unknown_method(self):
907         """
908         If C{SSHUserAuthClient} gets a MSG_USERAUTH_PK_OK packet when it's not
909         expecting it, it should fail the current authentication and move on to
910         the next type.
911         """
912         self.authClient.lastAuth = 'unknown'
913         self.authClient.transport.packets = []
914         self.authClient.ssh_USERAUTH_PK_OK('')
915         self.assertEqual(self.authClient.transport.packets,
916                           [(userauth.MSG_USERAUTH_REQUEST, NS('foo') +
917                             NS('nancy') + NS('none'))])
918
919
920     def test_USERAUTH_FAILURE_sorting(self):
921         """
922         ssh_USERAUTH_FAILURE should sort the methods by their position
923         in SSHUserAuthClient.preferredOrder.  Methods that are not in
924         preferredOrder should be sorted at the end of that list.
925         """
926         def auth_firstmethod():
927             self.authClient.transport.sendPacket(255, 'here is data')
928         def auth_anothermethod():
929             self.authClient.transport.sendPacket(254, 'other data')
930             return True
931         self.authClient.auth_firstmethod = auth_firstmethod
932         self.authClient.auth_anothermethod = auth_anothermethod
933
934         # although they shouldn't get called, method callbacks auth_* MUST
935         # exist in order for the test to work properly.
936         self.authClient.ssh_USERAUTH_FAILURE(NS('anothermethod,password') +
937                                              '\x00')
938         # should send password packet
939         self.assertEqual(self.authClient.transport.packets[-1],
940                 (userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
941                     + NS('password') + '\x00' + NS('foo')))
942         self.authClient.ssh_USERAUTH_FAILURE(
943             NS('firstmethod,anothermethod,password') + '\xff')
944         self.assertEqual(self.authClient.transport.packets[-2:],
945                           [(255, 'here is data'), (254, 'other data')])
946
947
948     def test_disconnectIfNoMoreAuthentication(self):
949         """
950         If there are no more available user authentication messages,
951         the SSHUserAuthClient should disconnect with code
952         DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE.
953         """
954         self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\x00')
955         self.authClient.ssh_USERAUTH_FAILURE(NS('password') + '\xff')
956         self.assertEqual(self.authClient.transport.packets[-1],
957                           (transport.MSG_DISCONNECT, '\x00\x00\x00\x0e' +
958                            NS('no more authentication methods available') +
959                            '\x00\x00\x00\x00'))
960
961
962     def test_ebAuth(self):
963         """
964         _ebAuth (the generic authentication error handler) should send
965         a request for the 'none' authentication method.
966         """
967         self.authClient.transport.packets = []
968         self.authClient._ebAuth(None)
969         self.assertEqual(self.authClient.transport.packets,
970                 [(userauth.MSG_USERAUTH_REQUEST, NS('foo') + NS('nancy')
971                     + NS('none'))])
972
973
974     def test_defaults(self):
975         """
976         getPublicKey() should return None.  getPrivateKey() should return a
977         failed Deferred.  getPassword() should return a failed Deferred.
978         getGenericAnswers() should return a failed Deferred.
979         """
980         authClient = userauth.SSHUserAuthClient('foo', FakeTransport.Service())
981         self.assertIdentical(authClient.getPublicKey(), None)
982         def check(result):
983             result.trap(NotImplementedError)
984             d = authClient.getPassword()
985             return d.addCallback(self.fail).addErrback(check2)
986         def check2(result):
987             result.trap(NotImplementedError)
988             d = authClient.getGenericAnswers(None, None, None)
989             return d.addCallback(self.fail).addErrback(check3)
990         def check3(result):
991             result.trap(NotImplementedError)
992         d = authClient.getPrivateKey()
993         return d.addCallback(self.fail).addErrback(check)
994
995
996
class LoopbackTestCase(unittest.TestCase):
    """
    Connect a real SSHUserAuthServer to a ClientUserAuth over an in-memory
    loopback connection and check that authentication completes and the
    server switches to the factory's service.
    """


    # keys is set to None at import time when PyCrypto/PyASN1 are missing.
    if keys is None:
        skip = "cannot run w/o PyCrypto or PyASN1"


    class Factory:
        """
        Minimal stand-in for the server transport's factory.  getService()
        always hands back the Service class below; presumably the server
        asks for it once authentication succeeds (confirm in userauth).
        """
        class Service:
            """
            Trivial service that hangs up as soon as it is started.
            """
            name = 'TestService'


            def serviceStarted(self):
                # Dropping the connection here is presumably what lets the
                # loopbackAsync Deferred fire and the test finish.
                self.transport.loseConnection()


            def serviceStopped(self):
                pass


        def getService(self, avatar, name):
            return self.Service


    def test_loopback(self):
        """
        Test that the userauth server and client play nicely with each other.
        """
        server = userauth.SSHUserAuthServer()
        client = ClientUserAuth('foo', self.Factory.Service())

        # set up transports
        server.transport = transport.SSHTransportBase()
        server.transport.service = server
        # Report the transport as encrypted in every direction; presumably
        # the server restricts some methods on unencrypted transports.
        server.transport.isEncrypted = lambda x: True
        client.transport = transport.SSHTransportBase()
        client.transport.service = client
        # Both ends must agree on the session ID because it is part of the
        # data signed for public key authentication.
        server.transport.sessionID = client.transport.sessionID = ''
        # don't send key exchange packet
        server.transport.sendKexInit = client.transport.sendKexInit = \
                lambda: None

        # set up server authentication
        server.transport.factory = self.Factory()
        server.passwordDelay = 0 # remove bad password delay
        realm = Realm()
        portal = Portal(realm)
        checker = SSHProtocolChecker()
        checker.registerChecker(PasswordChecker())
        checker.registerChecker(PrivateKeyChecker())
        checker.registerChecker(PAMChecker())
        # Login only completes after all three registered checkers have
        # accepted a credential for the avatar.
        checker.areDone = lambda aId: (
            len(checker.successfulCredentials[aId]) == 3)
        portal.registerChecker(checker)
        server.transport.factory.portal = portal

        # The inner .transport attributes are only touched after
        # loopbackAsync has wired the two SSH transports together.
        d = loopback.loopbackAsync(server.transport, client.transport)
        server.transport.transport.logPrefix = lambda: '_ServerLoopback'
        client.transport.transport.logPrefix = lambda: '_ClientLoopback'

        server.serviceStarted()
        client.serviceStarted()

        def check(ignored):
            # After successful authentication the server's transport runs
            # the factory's TestService instead of the userauth service.
            self.assertEqual(server.transport.service.name, 'TestService')
        return d.addCallback(check)