Imported Upstream version 12.1.0
[contrib/python-twisted.git] / twisted / conch / test / test_transport.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for ssh/transport.py and the classes therein.
6 """
7
# These tests need both PyASN1 and PyCrypto; probe for each one.
try:
    import pyasn1
except ImportError:
    pyasn1 = None

try:
    import Crypto.Cipher.DES3
except ImportError:
    Crypto = None

if pyasn1 is not None and Crypto is not None:
    dependencySkip = None
    from twisted.conch.ssh import transport, keys, factory
    from twisted.conch.test import keydata
else:
    # Human-readable reason the dependencies are missing.  Presumably used
    # as a skip message by code outside this view (NOTE(review): confirm).
    if pyasn1 is None:
        dependencySkip = "Cannot run without PyASN1"
    elif Crypto is None:
        dependencySkip = "can't run w/o PyCrypto"

    # Stand-ins for the real modules so that the Mock* classes below, which
    # subclass these names, can still be defined at import time.
    class transport: # fictional modules to make classes work
        class SSHTransportBase: pass
        class SSHServerTransport: pass
        class SSHClientTransport: pass
    class factory:
        class SSHFactory:
            pass
35
36 from twisted.trial import unittest
37 from twisted.internet import defer
38 from twisted.protocols import loopback
39 from twisted.python import randbytes
40 from twisted.python.reflect import qual
41 from twisted.python.hashlib import md5, sha1
42 from twisted.conch.ssh import service, common
43 from twisted.test import proto_helpers
44
45 from twisted.conch.error import ConchError
46
47
48
class MockTransportBase(transport.SSHTransportBase):
    """
    Shared base for the client and server test protocols.  Instead of
    discarding the informational messages it receives, it keeps them so
    tests can make assertions about them.

    @ivar errors: list of C{(reasonCode, description)} tuples.
    @ivar unimplementeds: list of integer sequence numbers.
    @ivar debugs: list of C{(alwaysDisplay, message, lang)} tuples.
    @ivar ignoreds: list of ignored data strings.
    @ivar gotUnsupportedVersion: the most recent remote version string
        passed to C{_unsupportedVersionReceived}, else C{None}.
    """

    def connectionMade(self):
        """
        Let the base class start the connection, then create the empty
        recording attributes.
        """
        transport.SSHTransportBase.connectionMade(self)
        self.errors, self.unimplementeds = [], []
        self.debugs, self.ignoreds = [], []
        self.gotUnsupportedVersion = None


    def _unsupportedVersionReceived(self, remoteVersion):
        """
        Note the unsupported version before handing off to the base class.

        @type remoteVersion: C{str}
        """
        self.gotUnsupportedVersion = remoteVersion
        base = transport.SSHTransportBase
        return base._unsupportedVersionReceived(self, remoteVersion)


    def receiveError(self, reasonCode, description):
        """
        Keep a received error.

        @type reasonCode: C{int}
        @type description: C{str}
        """
        received = (reasonCode, description)
        self.errors.append(received)


    def receiveUnimplemented(self, seqnum):
        """
        Keep the sequence number carried by an unimplemented-packet message.

        @type seqnum: C{int}
        """
        self.unimplementeds.append(seqnum)


    def receiveDebug(self, alwaysDisplay, message, lang):
        """
        Keep a received debug message.

        @type alwaysDisplay: C{bool}
        @type message: C{str}
        @type lang: C{str}
        """
        received = (alwaysDisplay, message, lang)
        self.debugs.append(received)


    def ssh_IGNORE(self, packet):
        """
        Keep the raw payload of an ignore message.

        @type packet: C{str}
        """
        self.ignoreds.append(packet)
119
120
class MockCipher(object):
    """
    A mocked-up version of twisted.conch.ssh.transport.SSHCiphers.

    Encryption and decryption are identity transforms.  They only record
    that they were used and check that the data length is a multiple of the
    relevant block size.  The MAC is the single character whose ordinal is
    the packet sequence number.
    """
    outCipType = 'test'
    encBlockSize = 6
    inCipType = 'test'
    decBlockSize = 6
    inMACType = 'test'
    outMACType = 'test'
    verifyDigestSize = 1
    usedEncrypt = False
    usedDecrypt = False
    outMAC = (None, '', '', 1)
    inMAC = (None, '', '', 1)
    keys = ()


    def _checkBlockSize(self, data, blockSize):
        """
        Raise L{RuntimeError} unless C{len(data)} is a multiple of
        C{blockSize}.
        """
        remainder = len(data) % blockSize
        if remainder != 0:
            raise RuntimeError("length %i modulo blocksize %i is not 0: %i" %
                    (len(data), blockSize, remainder))


    def encrypt(self, x):
        """
        Called to encrypt the packet.  Simply record that encryption was used
        and return the data unchanged.

        @raise RuntimeError: if C{len(x)} is not a multiple of C{encBlockSize}.
        """
        self.usedEncrypt = True
        self._checkBlockSize(x, self.encBlockSize)
        return x


    def decrypt(self, x):
        """
        Called to decrypt the packet.  Simply record that decryption was used
        and return the data unchanged.

        @raise RuntimeError: if C{len(x)} is not a multiple of C{decBlockSize}.
        """
        self.usedDecrypt = True
        # Incoming data must be validated against the *decryption* block size.
        # This used to test encBlockSize while reporting decBlockSize, which
        # was only correct when the two happened to be equal.
        self._checkBlockSize(x, self.decBlockSize)
        return x


    def makeMAC(self, outgoingPacketSequence, payload):
        """
        Make a Message Authentication Code by sending the character value of
        the outgoing packet.
        """
        return chr(outgoingPacketSequence)


    def verify(self, incomingPacketSequence, packet, macData):
        """
        Verify the Message Authentication Code by checking that the packet
        sequence number is the same.
        """
        return chr(incomingPacketSequence) == macData


    def setKeys(self, ivOut, keyOut, ivIn, keyIn, macIn, macOut):
        """
        Record the keys.
        """
        self.keys = (ivOut, keyOut, ivIn, keyIn, macIn, macOut)
184
185
186
class MockCompression:
    """
    A fake compressor exposing the zlib-style compress/decompress/flush
    interface.  "Compressing" reverses the payload, and C{flush} supplies a
    trailing 0x66 byte.  Decompressing drops that trailing byte and reverses
    the rest back into the original order.
    """


    def compress(self, payload):
        # Reverse the data instead of compressing it.
        return payload[::-1]


    def decompress(self, payload):
        # Walking backwards from just before the final (flush) byte both
        # discards that byte and undoes the reversal in a single slice.
        return payload[-2::-1]


    def flush(self, kind):
        return '\x66'
204
205
206
class MockService(service.SSHService):
    """
    A stand-in service built on twisted.conch.ssh.service.SSHService.

    @ivar started: becomes True once C{serviceStarted} has been called.
    @ivar stopped: becomes True once C{serviceStopped} has been called.
    """
    name = "MockService"
    protocolMessages = {0xff: "MSG_TEST", 71: "MSG_fiction"}
    started = False
    stopped = False


    def logPrefix(self):
        return "MockService"


    def serviceStarted(self):
        """
        Flag that the service has been started.
        """
        self.started = True


    def serviceStopped(self):
        """
        Flag that the service has been stopped.
        """
        self.stopped = True


    def ssh_TEST(self, packet):
        """
        Echo the received packet back through the transport as message
        type 0xff.
        """
        self.transport.sendPacket(0xff, packet)
243
244
class MockFactory(factory.SSHFactory):
    """
    A stand-in factory derived from twisted.conch.ssh.factory.SSHFactory
    that serves the keys found in twisted.conch.test.keydata.
    """
    services = {'ssh-userauth': MockService}


    def getPublicKeys(self):
        """
        Map key type names to the public keys that identify this server.
        """
        fromString = keys.Key.fromString
        return {'ssh-rsa': fromString(keydata.publicRSA_openssh),
                'ssh-dsa': fromString(keydata.publicDSA_openssh)}


    def getPrivateKeys(self):
        """
        Map key type names to the private keys that identify this server.
        """
        fromString = keys.Key.fromString
        return {'ssh-rsa': fromString(keydata.privateRSA_openssh),
                'ssh-dsa': fromString(keydata.privateDSA_openssh)}


    def getPrimes(self):
        """
        Supply the Diffie-Hellman primes usable by the
        diffie-hellman-group-exchange-sha1 key exchange, keyed by size.
        Each value is a tuple of 2-tuples (presumably (generator, prime);
        confirm against SSHFactory.getPrimes).
        """
        prime = transport.DH_PRIME
        return {1024: ((2, prime),),
                2048: ((3, prime),),
                4096: ((5, 7),)}
280
281
282
class MockOldFactoryPublicKeys(MockFactory):
    """
    The old SSHFactory returned mappings from key names to strings from
    getPublicKeys().  We return those here for testing.
    """


    def getPublicKeys(self):
        """
        We used to map key types to public key blobs as strings.

        @return: a new C{dict} mapping each key type name returned by
            L{MockFactory.getPublicKeys} to that key's C{blob()}.
        """
        # Build a fresh mapping instead of rewriting the dict while walking
        # it.  This drops the Python-2-only ``items()[:]`` slice and no
        # longer shadows the module-level ``keys`` import with a local.
        publicKeys = MockFactory.getPublicKeys(self)
        return dict((name, key.blob()) for name, key in publicKeys.items())
298
299
300
class MockOldFactoryPrivateKeys(MockFactory):
    """
    The old SSHFactory returned mappings from key names to PyCrypto key
    objects from getPrivateKeys().  We return those here for testing.
    """


    def getPrivateKeys(self):
        """
        We used to map key types to PyCrypto key objects.

        @return: a new C{dict} mapping each key type name returned by
            L{MockFactory.getPrivateKeys} to that key's C{keyObject}.
        """
        # Build a fresh mapping instead of rewriting the dict while walking
        # it.  This drops the Python-2-only ``items()[:]`` slice and no
        # longer shadows the module-level ``keys`` import with a local.
        privateKeys = MockFactory.getPrivateKeys(self)
        return dict((name, key.keyObject)
                    for name, key in privateKeys.items())
316
317
318
class TransportTestCase(unittest.TestCase):
    """
    Base class for transport test cases.

    Subclasses set C{klass} to the protocol class under test.  C{setUp}
    does three things:
      - connects a fresh instance (C{self.proto}) to a
        L{proto_helpers.StringTransport} (C{self.transport});
      - makes L{randbytes.secureRandom} deterministic;
      - replaces the protocol's C{sendPacket} with a stub that appends
        C{(messageType, payload)} tuples to C{self.packets}.
    """
    klass = None

    if Crypto is None:
        skip = "cannot run w/o PyCrypto"

    if pyasn1 is None:
        skip = "Cannot run without PyASN1"


    def setUp(self):
        self.transport = proto_helpers.StringTransport()
        self.proto = self.klass()
        self.packets = []
        def secureRandom(n):
            """
            Return a consistent entropy value
            """
            return '\x99' * n
        self.oldSecureRandom = randbytes.secureRandom
        randbytes.secureRandom = secureRandom
        # tearDown is not run if setUp fails, but registered cleanups are
        # (see trial's TestCase.addCleanup).  Registering the restore here
        # keeps the global patch from leaking into other tests if
        # makeConnection below raises.
        self.addCleanup(setattr, randbytes, 'secureRandom',
                        self.oldSecureRandom)
        def stubSendPacket(messageType, payload):
            self.packets.append((messageType, payload))
        self.proto.makeConnection(self.transport)
        # we just let the kex packet go into the transport
        self.proto.sendPacket = stubSendPacket


    def finishKeyExchange(self, proto):
        """
        Deliver enough additional messages to C{proto} so that the key exchange
        which is started in L{SSHTransportBase.connectionMade} completes and
        non-key exchange messages can be sent and received.
        """
        proto.dataReceived("SSH-2.0-BogoClient-1.2i\r\n")
        proto.dispatchMessage(
            transport.MSG_KEXINIT, self._A_KEXINIT_MESSAGE)
        proto._keySetup("foo", "bar")
        # SSHTransportBase can't handle MSG_NEWKEYS, or it would be the right
        # thing to deliver next.  _newKeys won't work either, because
        # sendKexInit (probably) hasn't been called.  sendKexInit is responsible
        # for setting up certain state _newKeys relies on.  So, just change the
        # key exchange state to what it would be when key exchange is finished.
        proto._keyExchangeState = proto._KEY_EXCHANGE_NONE


    def tearDown(self):
        randbytes.secureRandom = self.oldSecureRandom
        self.oldSecureRandom = None


    def simulateKeyExchange(self, sharedSecret, exchangeHash):
        """
        Finish a key exchange by calling C{_keySetup} with the given arguments.
        Also do extra whitebox stuff to satisfy that method's assumption that
        some kind of key exchange has actually taken place.
        """
        self.proto._keyExchangeState = self.proto._KEY_EXCHANGE_REQUESTED
        self.proto._blockedByKeyExchange = []
        self.proto._keySetup(sharedSecret, exchangeHash)
382
383
384
385 class BaseSSHTransportTestCase(TransportTestCase):
386     """
387     Test TransportBase.  It implements the non-server/client specific
388     parts of the SSH transport protocol.
389     """
390
    klass = MockTransportBase

    # A KEXINIT payload (without the message type byte) that tests feed to
    # dispatchMessage.  Its fields, in order:
    #   - a 16-byte cookie;
    #   - ten name-lists: key exchange, public key, outgoing and incoming
    #     ciphers, outgoing and incoming MACs, outgoing and incoming
    #     compressions, and two empty language lists;
    #   - a zero "first packet follows" byte;
    #   - a zero uint32.
    _A_KEXINIT_MESSAGE = (
        "\xAA" * 16 +
        common.NS('diffie-hellman-group1-sha1') +
        common.NS('ssh-rsa') +
        common.NS('aes256-ctr') +
        common.NS('aes256-ctr') +
        common.NS('hmac-sha1') +
        common.NS('hmac-sha1') +
        common.NS('none') +
        common.NS('none') +
        common.NS('') +
        common.NS('') +
        '\x00' + '\x00\x00\x00\x00')
406
407     def test_sendVersion(self):
408         """
409         Test that the first thing sent over the connection is the version
410         string.
411         """
412         # the other setup was done in the setup method
413         self.assertEqual(self.transport.value().split('\r\n', 1)[0],
414                           "SSH-2.0-Twisted")
415
416
417     def test_sendPacketPlain(self):
418         """
419         Test that plain (unencrypted, uncompressed) packets are sent
420         correctly.  The format is::
421             uint32 length (including type and padding length)
422             byte padding length
423             byte type
424             bytes[length-padding length-2] data
425             bytes[padding length] padding
426         """
427         proto = MockTransportBase()
428         proto.makeConnection(self.transport)
429         self.finishKeyExchange(proto)
430         self.transport.clear()
431         message = ord('A')
432         payload = 'BCDEFG'
433         proto.sendPacket(message, payload)
434         value = self.transport.value()
435         self.assertEqual(value, '\x00\x00\x00\x0c\x04ABCDEFG\x99\x99\x99\x99')
436
437
438     def test_sendPacketEncrypted(self):
439         """
440         Test that packets sent while encryption is enabled are sent
441         correctly.  The whole packet should be encrypted.
442         """
443         proto = MockTransportBase()
444         proto.makeConnection(self.transport)
445         self.finishKeyExchange(proto)
446         proto.currentEncryptions = testCipher = MockCipher()
447         message = ord('A')
448         payload = 'BC'
449         self.transport.clear()
450         proto.sendPacket(message, payload)
451         self.assertTrue(testCipher.usedEncrypt)
452         value = self.transport.value()
453         self.assertEqual(
454             value,
455             # Four byte length prefix
456             '\x00\x00\x00\x08'
457             # One byte padding length
458             '\x04'
459             # The actual application data
460             'ABC'
461             # "Random" padding - see the secureRandom monkeypatch in setUp
462             '\x99\x99\x99\x99'
463             # The MAC
464             '\x02')
465
466
467     def test_sendPacketCompressed(self):
468         """
469         Test that packets sent while compression is enabled are sent
470         correctly.  The packet type and data should be encrypted.
471         """
472         proto = MockTransportBase()
473         proto.makeConnection(self.transport)
474         self.finishKeyExchange(proto)
475         proto.outgoingCompression = MockCompression()
476         self.transport.clear()
477         proto.sendPacket(ord('A'), 'B')
478         value = self.transport.value()
479         self.assertEqual(
480             value,
481             '\x00\x00\x00\x0c\x08BA\x66\x99\x99\x99\x99\x99\x99\x99\x99')
482
483
484     def test_sendPacketBoth(self):
485         """
486         Test that packets sent while compression and encryption are
487         enabled are sent correctly.  The packet type and data should be
488         compressed and then the whole packet should be encrypted.
489         """
490         proto = MockTransportBase()
491         proto.makeConnection(self.transport)
492         self.finishKeyExchange(proto)
493         proto.currentEncryptions = testCipher = MockCipher()
494         proto.outgoingCompression = MockCompression()
495         message = ord('A')
496         payload = 'BC'
497         self.transport.clear()
498         proto.sendPacket(message, payload)
499         self.assertTrue(testCipher.usedEncrypt)
500         value = self.transport.value()
501         self.assertEqual(
502             value,
503             # Four byte length prefix
504             '\x00\x00\x00\x0e'
505             # One byte padding length
506             '\x09'
507             # Compressed application data
508             'CBA\x66'
509             # "Random" padding - see the secureRandom monkeypatch in setUp
510             '\x99\x99\x99\x99\x99\x99\x99\x99\x99'
511             # The MAC
512             '\x02')
513
514
515     def test_getPacketPlain(self):
516         """
517         Test that packets are retrieved correctly out of the buffer when
518         no encryption is enabled.
519         """
520         proto = MockTransportBase()
521         proto.makeConnection(self.transport)
522         self.finishKeyExchange(proto)
523         self.transport.clear()
524         proto.sendPacket(ord('A'), 'BC')
525         proto.buf = self.transport.value() + 'extra'
526         self.assertEqual(proto.getPacket(), 'ABC')
527         self.assertEqual(proto.buf, 'extra')
528
529
530     def test_getPacketEncrypted(self):
531         """
532         Test that encrypted packets are retrieved correctly.
533         See test_sendPacketEncrypted.
534         """
535         proto = MockTransportBase()
536         proto.sendKexInit = lambda: None # don't send packets
537         proto.makeConnection(self.transport)
538         self.transport.clear()
539         proto.currentEncryptions = testCipher = MockCipher()
540         proto.sendPacket(ord('A'), 'BCD')
541         value = self.transport.value()
542         proto.buf = value[:MockCipher.decBlockSize]
543         self.assertEqual(proto.getPacket(), None)
544         self.assertTrue(testCipher.usedDecrypt)
545         self.assertEqual(proto.first, '\x00\x00\x00\x0e\x09A')
546         proto.buf += value[MockCipher.decBlockSize:]
547         self.assertEqual(proto.getPacket(), 'ABCD')
548         self.assertEqual(proto.buf, '')
549
550
551     def test_getPacketCompressed(self):
552         """
553         Test that compressed packets are retrieved correctly.  See
554         test_sendPacketCompressed.
555         """
556         proto = MockTransportBase()
557         proto.makeConnection(self.transport)
558         self.finishKeyExchange(proto)
559         self.transport.clear()
560         proto.outgoingCompression = MockCompression()
561         proto.incomingCompression = proto.outgoingCompression
562         proto.sendPacket(ord('A'), 'BCD')
563         proto.buf = self.transport.value()
564         self.assertEqual(proto.getPacket(), 'ABCD')
565
566
567     def test_getPacketBoth(self):
568         """
569         Test that compressed and encrypted packets are retrieved correctly.
570         See test_sendPacketBoth.
571         """
572         proto = MockTransportBase()
573         proto.sendKexInit = lambda: None
574         proto.makeConnection(self.transport)
575         self.transport.clear()
576         proto.currentEncryptions = testCipher = MockCipher()
577         proto.outgoingCompression = MockCompression()
578         proto.incomingCompression = proto.outgoingCompression
579         proto.sendPacket(ord('A'), 'BCDEFG')
580         proto.buf = self.transport.value()
581         self.assertEqual(proto.getPacket(), 'ABCDEFG')
582
583
584     def test_ciphersAreValid(self):
585         """
586         Test that all the supportedCiphers are valid.
587         """
588         ciphers = transport.SSHCiphers('A', 'B', 'C', 'D')
589         iv = key = '\x00' * 16
590         for cipName in self.proto.supportedCiphers:
591             self.assertTrue(ciphers._getCipher(cipName, iv, key))
592
593
594     def test_sendKexInit(self):
595         """
596         Test that the KEXINIT (key exchange initiation) message is sent
597         correctly.  Payload::
598             bytes[16] cookie
599             string key exchange algorithms
600             string public key algorithms
601             string outgoing ciphers
602             string incoming ciphers
603             string outgoing MACs
604             string incoming MACs
605             string outgoing compressions
606             string incoming compressions
607             bool first packet follows
608             uint32 0
609         """
610         value = self.transport.value().split('\r\n', 1)[1]
611         self.proto.buf = value
612         packet = self.proto.getPacket()
613         self.assertEqual(packet[0], chr(transport.MSG_KEXINIT))
614         self.assertEqual(packet[1:17], '\x99' * 16)
615         (kex, pubkeys, ciphers1, ciphers2, macs1, macs2, compressions1,
616          compressions2, languages1, languages2,
617          buf) = common.getNS(packet[17:], 10)
618
619         self.assertEqual(kex, ','.join(self.proto.supportedKeyExchanges))
620         self.assertEqual(pubkeys, ','.join(self.proto.supportedPublicKeys))
621         self.assertEqual(ciphers1, ','.join(self.proto.supportedCiphers))
622         self.assertEqual(ciphers2, ','.join(self.proto.supportedCiphers))
623         self.assertEqual(macs1, ','.join(self.proto.supportedMACs))
624         self.assertEqual(macs2, ','.join(self.proto.supportedMACs))
625         self.assertEqual(compressions1,
626                           ','.join(self.proto.supportedCompressions))
627         self.assertEqual(compressions2,
628                           ','.join(self.proto.supportedCompressions))
629         self.assertEqual(languages1, ','.join(self.proto.supportedLanguages))
630         self.assertEqual(languages2, ','.join(self.proto.supportedLanguages))
631         self.assertEqual(buf, '\x00' * 5)
632
633
634     def test_receiveKEXINITReply(self):
635         """
636         Immediately after connecting, the transport expects a KEXINIT message
637         and does not reply to it.
638         """
639         self.transport.clear()
640         self.proto.dispatchMessage(
641             transport.MSG_KEXINIT, self._A_KEXINIT_MESSAGE)
642         self.assertEqual(self.packets, [])
643
644
645     def test_sendKEXINITReply(self):
646         """
647         When a KEXINIT message is received which is not a reply to an earlier
648         KEXINIT message which was sent, a KEXINIT reply is sent.
649         """
650         self.finishKeyExchange(self.proto)
651         del self.packets[:]
652
653         self.proto.dispatchMessage(
654             transport.MSG_KEXINIT, self._A_KEXINIT_MESSAGE)
655         self.assertEqual(len(self.packets), 1)
656         self.assertEqual(self.packets[0][0], transport.MSG_KEXINIT)
657
658
659     def test_sendKexInitTwiceFails(self):
660         """
661         A new key exchange cannot be started while a key exchange is already in
662         progress.  If an attempt is made to send a I{KEXINIT} message using
663         L{SSHTransportBase.sendKexInit} while a key exchange is in progress
664         causes that method to raise a L{RuntimeError}.
665         """
666         self.assertRaises(RuntimeError, self.proto.sendKexInit)
667
668
669     def test_sendKexInitBlocksOthers(self):
670         """
671         After L{SSHTransportBase.sendKexInit} has been called, messages types
672         other than the following are queued and not sent until after I{NEWKEYS}
673         is sent by L{SSHTransportBase._keySetup}.
674
675         RFC 4253, section 7.1.
676         """
677         # sendKexInit is called by connectionMade, which is called in setUp.  So
678         # we're in the state already.
679         disallowedMessageTypes = [
680             transport.MSG_SERVICE_REQUEST,
681             transport.MSG_KEXINIT,
682             ]
683
684         # Drop all the bytes sent by setUp, they're not relevant to this test.
685         self.transport.clear()
686
687         # Get rid of the sendPacket monkey patch, we are testing the behavior of
688         # sendPacket.
689         del self.proto.sendPacket
690
691         for messageType in disallowedMessageTypes:
692             self.proto.sendPacket(messageType, 'foo')
693             self.assertEqual(self.transport.value(), "")
694
695         self.finishKeyExchange(self.proto)
696         # Make the bytes written to the transport cleartext so it's easier to
697         # make an assertion about them.
698         self.proto.nextEncryptions = MockCipher()
699
700         # Pseudo-deliver the peer's NEWKEYS message, which should flush the
701         # messages which were queued above.
702         self.proto._newKeys()
703         self.assertEqual(self.transport.value().count("foo"), 2)
704
705
706     def test_sendDebug(self):
707         """
708         Test that debug messages are sent correctly.  Payload::
709             bool always display
710             string debug message
711             string language
712         """
713         self.proto.sendDebug("test", True, 'en')
714         self.assertEqual(
715             self.packets,
716             [(transport.MSG_DEBUG,
717               "\x01\x00\x00\x00\x04test\x00\x00\x00\x02en")])
718
719
720     def test_receiveDebug(self):
721         """
722         Test that debug messages are received correctly.  See test_sendDebug.
723         """
724         self.proto.dispatchMessage(
725             transport.MSG_DEBUG,
726             '\x01\x00\x00\x00\x04test\x00\x00\x00\x02en')
727         self.assertEqual(self.proto.debugs, [(True, 'test', 'en')])
728
729
730     def test_sendIgnore(self):
731         """
732         Test that ignored messages are sent correctly.  Payload::
733             string ignored data
734         """
735         self.proto.sendIgnore("test")
736         self.assertEqual(
737             self.packets, [(transport.MSG_IGNORE,
738                             '\x00\x00\x00\x04test')])
739
740
741     def test_receiveIgnore(self):
742         """
743         Test that ignored messages are received correctly.  See
744         test_sendIgnore.
745         """
746         self.proto.dispatchMessage(transport.MSG_IGNORE, 'test')
747         self.assertEqual(self.proto.ignoreds, ['test'])
748
749
750     def test_sendUnimplemented(self):
751         """
752         Test that unimplemented messages are sent correctly.  Payload::
753             uint32 sequence number
754         """
755         self.proto.sendUnimplemented()
756         self.assertEqual(
757             self.packets, [(transport.MSG_UNIMPLEMENTED,
758                             '\x00\x00\x00\x00')])
759
760
761     def test_receiveUnimplemented(self):
762         """
763         Test that unimplemented messages are received correctly.  See
764         test_sendUnimplemented.
765         """
766         self.proto.dispatchMessage(transport.MSG_UNIMPLEMENTED,
767                                    '\x00\x00\x00\xff')
768         self.assertEqual(self.proto.unimplementeds, [255])
769
770
771     def test_sendDisconnect(self):
772         """
773         Test that disconnection messages are sent correctly.  Payload::
774             uint32 reason code
775             string reason description
776             string language
777         """
778         disconnected = [False]
779         def stubLoseConnection():
780             disconnected[0] = True
781         self.transport.loseConnection = stubLoseConnection
782         self.proto.sendDisconnect(0xff, "test")
783         self.assertEqual(
784             self.packets,
785             [(transport.MSG_DISCONNECT,
786               "\x00\x00\x00\xff\x00\x00\x00\x04test\x00\x00\x00\x00")])
787         self.assertTrue(disconnected[0])
788
789
790     def test_receiveDisconnect(self):
791         """
792         Test that disconnection messages are received correctly.  See
793         test_sendDisconnect.
794         """
795         disconnected = [False]
796         def stubLoseConnection():
797             disconnected[0] = True
798         self.transport.loseConnection = stubLoseConnection
799         self.proto.dispatchMessage(transport.MSG_DISCONNECT,
800                                    '\x00\x00\x00\xff\x00\x00\x00\x04test')
801         self.assertEqual(self.proto.errors, [(255, 'test')])
802         self.assertTrue(disconnected[0])
803
804
805     def test_dataReceived(self):
806         """
807         Test that dataReceived parses packets and dispatches them to
808         ssh_* methods.
809         """
810         kexInit = [False]
811         def stubKEXINIT(packet):
812             kexInit[0] = True
813         self.proto.ssh_KEXINIT = stubKEXINIT
814         self.proto.dataReceived(self.transport.value())
815         self.assertTrue(self.proto.gotVersion)
816         self.assertEqual(self.proto.ourVersionString,
817                           self.proto.otherVersionString)
818         self.assertTrue(kexInit[0])
819
820
821     def test_service(self):
822         """
823         Test that the transport can set the running service and dispatches
824         packets to the service's packetReceived method.
825         """
826         service = MockService()
827         self.proto.setService(service)
828         self.assertEqual(self.proto.service, service)
829         self.assertTrue(service.started)
830         self.proto.dispatchMessage(0xff, "test")
831         self.assertEqual(self.packets, [(0xff, "test")])
832
833         service2 = MockService()
834         self.proto.setService(service2)
835         self.assertTrue(service2.started)
836         self.assertTrue(service.stopped)
837
838         self.proto.connectionLost(None)
839         self.assertTrue(service2.stopped)
840
841
842     def test_avatar(self):
843         """
844         Test that the transport notifies the avatar of disconnections.
845         """
846         disconnected = [False]
847         def logout():
848             disconnected[0] = True
849         self.proto.logoutFunction = logout
850         self.proto.avatar = True
851
852         self.proto.connectionLost(None)
853         self.assertTrue(disconnected[0])
854
855
856     def test_isEncrypted(self):
857         """
858         Test that the transport accurately reflects its encrypted status.
859         """
860         self.assertFalse(self.proto.isEncrypted('in'))
861         self.assertFalse(self.proto.isEncrypted('out'))
862         self.assertFalse(self.proto.isEncrypted('both'))
863         self.proto.currentEncryptions = MockCipher()
864         self.assertTrue(self.proto.isEncrypted('in'))
865         self.assertTrue(self.proto.isEncrypted('out'))
866         self.assertTrue(self.proto.isEncrypted('both'))
867         self.proto.currentEncryptions = transport.SSHCiphers('none', 'none',
868                                                              'none', 'none')
869         self.assertFalse(self.proto.isEncrypted('in'))
870         self.assertFalse(self.proto.isEncrypted('out'))
871         self.assertFalse(self.proto.isEncrypted('both'))
872
873         self.assertRaises(TypeError, self.proto.isEncrypted, 'bad')
874
875
876     def test_isVerified(self):
877         """
878         Test that the transport accurately reflects its verified status.
879         """
880         self.assertFalse(self.proto.isVerified('in'))
881         self.assertFalse(self.proto.isVerified('out'))
882         self.assertFalse(self.proto.isVerified('both'))
883         self.proto.currentEncryptions = MockCipher()
884         self.assertTrue(self.proto.isVerified('in'))
885         self.assertTrue(self.proto.isVerified('out'))
886         self.assertTrue(self.proto.isVerified('both'))
887         self.proto.currentEncryptions = transport.SSHCiphers('none', 'none',
888                                                              'none', 'none')
889         self.assertFalse(self.proto.isVerified('in'))
890         self.assertFalse(self.proto.isVerified('out'))
891         self.assertFalse(self.proto.isVerified('both'))
892
893         self.assertRaises(TypeError, self.proto.isVerified, 'bad')
894
895
896     def test_loseConnection(self):
897         """
898         Test that loseConnection sends a disconnect message and closes the
899         connection.
900         """
901         disconnected = [False]
902         def stubLoseConnection():
903             disconnected[0] = True
904         self.transport.loseConnection = stubLoseConnection
905         self.proto.loseConnection()
906         self.assertEqual(self.packets[0][0], transport.MSG_DISCONNECT)
907         self.assertEqual(self.packets[0][1][3],
908                           chr(transport.DISCONNECT_CONNECTION_LOST))
909
910
911     def test_badVersion(self):
912         """
913         Test that the transport disconnects when it receives a bad version.
914         """
915         def testBad(version):
916             self.packets = []
917             self.proto.gotVersion = False
918             disconnected = [False]
919             def stubLoseConnection():
920                 disconnected[0] = True
921             self.transport.loseConnection = stubLoseConnection
922             for c in version + '\r\n':
923                 self.proto.dataReceived(c)
924             self.assertTrue(disconnected[0])
925             self.assertEqual(self.packets[0][0], transport.MSG_DISCONNECT)
926             self.assertEqual(
927                 self.packets[0][1][3],
928                 chr(transport.DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED))
929         testBad('SSH-1.5-OpenSSH')
930         testBad('SSH-3.0-Twisted')
931         testBad('GET / HTTP/1.1')
932
933
934     def test_dataBeforeVersion(self):
935         """
936         Test that the transport ignores data sent before the version string.
937         """
938         proto = MockTransportBase()
939         proto.makeConnection(proto_helpers.StringTransport())
940         data = ("""here's some stuff beforehand
941 here's some other stuff
942 """ + proto.ourVersionString + "\r\n")
943         [proto.dataReceived(c) for c in data]
944         self.assertTrue(proto.gotVersion)
945         self.assertEqual(proto.otherVersionString, proto.ourVersionString)
946
947
948     def test_compatabilityVersion(self):
949         """
950         Test that the transport treats the compatbility version (1.99)
951         as equivalent to version 2.0.
952         """
953         proto = MockTransportBase()
954         proto.makeConnection(proto_helpers.StringTransport())
955         proto.dataReceived("SSH-1.99-OpenSSH\n")
956         self.assertTrue(proto.gotVersion)
957         self.assertEqual(proto.otherVersionString, "SSH-1.99-OpenSSH")
958
959
960     def test_supportedVersionsAreAllowed(self):
961         """
962         If an unusual SSH version is received and is included in
963         C{supportedVersions}, an unsupported version error is not emitted.
964         """
965         proto = MockTransportBase()
966         proto.supportedVersions = ("9.99", )
967         proto.makeConnection(proto_helpers.StringTransport())
968         proto.dataReceived("SSH-9.99-OpenSSH\n")
969         self.assertFalse(proto.gotUnsupportedVersion)
970
971
972     def test_unsupportedVersionsCallUnsupportedVersionReceived(self):
973         """
974         If an unusual SSH version is received and is not included in
975         C{supportedVersions}, an unsupported version error is emitted.
976         """
977         proto = MockTransportBase()
978         proto.supportedVersions = ("2.0", )
979         proto.makeConnection(proto_helpers.StringTransport())
980         proto.dataReceived("SSH-9.99-OpenSSH\n")
981         self.assertEqual("9.99", proto.gotUnsupportedVersion)
982
983
984     def test_badPackets(self):
985         """
986         Test that the transport disconnects with an error when it receives
987         bad packets.
988         """
989         def testBad(packet, error=transport.DISCONNECT_PROTOCOL_ERROR):
990             self.packets = []
991             self.proto.buf = packet
992             self.assertEqual(self.proto.getPacket(), None)
993             self.assertEqual(len(self.packets), 1)
994             self.assertEqual(self.packets[0][0], transport.MSG_DISCONNECT)
995             self.assertEqual(self.packets[0][1][3], chr(error))
996
997         testBad('\xff' * 8) # big packet
998         testBad('\x00\x00\x00\x05\x00BCDE') # length not modulo blocksize
999         oldEncryptions = self.proto.currentEncryptions
1000         self.proto.currentEncryptions = MockCipher()
1001         testBad('\x00\x00\x00\x08\x06AB123456', # bad MAC
1002                 transport.DISCONNECT_MAC_ERROR)
1003         self.proto.currentEncryptions.decrypt = lambda x: x[:-1]
1004         testBad('\x00\x00\x00\x08\x06BCDEFGHIJK') # bad decryption
1005         self.proto.currentEncryptions = oldEncryptions
1006         self.proto.incomingCompression = MockCompression()
1007         def stubDecompress(payload):
1008             raise Exception('bad compression')
1009         self.proto.incomingCompression.decompress = stubDecompress
1010         testBad('\x00\x00\x00\x04\x00BCDE', # bad decompression
1011                 transport.DISCONNECT_COMPRESSION_ERROR)
1012         self.flushLoggedErrors()
1013
1014
1015     def test_unimplementedPackets(self):
1016         """
1017         Test that unimplemented packet types cause MSG_UNIMPLEMENTED packets
1018         to be sent.
1019         """
1020         seqnum = self.proto.incomingPacketSequence
1021         def checkUnimplemented(seqnum=seqnum):
1022             self.assertEqual(self.packets[0][0],
1023                               transport.MSG_UNIMPLEMENTED)
1024             self.assertEqual(self.packets[0][1][3], chr(seqnum))
1025             self.proto.packets = []
1026             seqnum += 1
1027
1028         self.proto.dispatchMessage(40, '')
1029         checkUnimplemented()
1030         transport.messages[41] = 'MSG_fiction'
1031         self.proto.dispatchMessage(41, '')
1032         checkUnimplemented()
1033         self.proto.dispatchMessage(60, '')
1034         checkUnimplemented()
1035         self.proto.setService(MockService())
1036         self.proto.dispatchMessage(70, '')
1037         checkUnimplemented()
1038         self.proto.dispatchMessage(71, '')
1039         checkUnimplemented()
1040
1041
1042     def test_getKey(self):
1043         """
1044         Test that _getKey generates the correct keys.
1045         """
1046         self.proto.sessionID = 'EF'
1047
1048         k1 = sha1('AB' + 'CD' + 'K' + self.proto.sessionID).digest()
1049         k2 = sha1('ABCD' + k1).digest()
1050         self.assertEqual(self.proto._getKey('K', 'AB', 'CD'), k1 + k2)
1051
1052
1053     def test_multipleClasses(self):
1054         """
1055         Test that multiple instances have distinct states.
1056         """
1057         proto = self.proto
1058         proto.dataReceived(self.transport.value())
1059         proto.currentEncryptions = MockCipher()
1060         proto.outgoingCompression = MockCompression()
1061         proto.incomingCompression = MockCompression()
1062         proto.setService(MockService())
1063         proto2 = MockTransportBase()
1064         proto2.makeConnection(proto_helpers.StringTransport())
1065         proto2.sendIgnore('')
1066         self.failIfEquals(proto.gotVersion, proto2.gotVersion)
1067         self.failIfEquals(proto.transport, proto2.transport)
1068         self.failIfEquals(proto.outgoingPacketSequence,
1069                           proto2.outgoingPacketSequence)
1070         self.failIfEquals(proto.incomingPacketSequence,
1071                           proto2.incomingPacketSequence)
1072         self.failIfEquals(proto.currentEncryptions,
1073                           proto2.currentEncryptions)
1074         self.failIfEquals(proto.service, proto2.service)
1075
1076
1077
class ServerAndClientSSHTransportBaseCase:
    """
    Mixin of tests that apply to both the server and the client transport.
    It relies on C{klass}, C{proto} and C{packets} being provided by the
    concrete test case it is mixed into.
    """


    def checkDisconnected(self, kind=None):
        """
        Assert that the most recently sent packet is a MSG_DISCONNECT whose
        reason code is C{kind}.  C{kind} defaults to
        DISCONNECT_PROTOCOL_ERROR.
        """
        if kind is None:
            expectedReason = transport.DISCONNECT_PROTOCOL_ERROR
        else:
            expectedReason = kind
        lastType, lastPayload = self.packets[-1]
        self.assertEqual(lastType, transport.MSG_DISCONNECT)
        self.assertEqual(lastPayload[3], chr(expectedReason))


    def connectModifiedProtocol(self, protoModification,
            kind=None):
        """
        Build a peer of type C{klass} and pass it to C{protoModification}
        before it connects.  Then feed the peer's initial output into
        C{self.proto}.

        Unless C{kind} is false, also check that C{self.proto} disconnected
        with reason C{kind}.  C{kind} defaults to
        DISCONNECT_KEY_EXCHANGE_FAILED.

        @return: the modified peer protocol.
        """
        if kind is None:
            kind = transport.DISCONNECT_KEY_EXCHANGE_FAILED
        peer = self.klass()
        protoModification(peer)
        peer.makeConnection(proto_helpers.StringTransport())
        self.proto.dataReceived(peer.transport.value())
        if kind:
            self.checkDisconnected(kind)
        return peer


    def _connectWithEmpty(self, attributeName):
        """
        Connect a peer whose C{attributeName} algorithm list is empty.  Check
        that C{self.proto} disconnects with DISCONNECT_KEY_EXCHANGE_FAILED.
        """
        def emptyList(peer):
            setattr(peer, attributeName, [])
        self.connectModifiedProtocol(emptyList)


    def test_disconnectIfCantMatchKex(self):
        """
        The transport disconnects when no key exchange algorithm can be
        agreed on.
        """
        self._connectWithEmpty('supportedKeyExchanges')


    def test_disconnectIfCantMatchKeyAlg(self):
        """
        Like test_disconnectIfCantMatchKex, but for the public key algorithm.
        """
        self._connectWithEmpty('supportedPublicKeys')


    def test_disconnectIfCantMatchCompression(self):
        """
        Like test_disconnectIfCantMatchKex, but for compression.
        """
        self._connectWithEmpty('supportedCompressions')


    def test_disconnectIfCantMatchCipher(self):
        """
        Like test_disconnectIfCantMatchKex, but for the encryption cipher.
        """
        self._connectWithEmpty('supportedCiphers')


    def test_disconnectIfCantMatchMAC(self):
        """
        Like test_disconnectIfCantMatchKex, but for the MAC.
        """
        self._connectWithEmpty('supportedMACs')
1155
1156
1157
class ServerSSHTransportTestCase(ServerAndClientSSHTransportBaseCase,
        TransportTestCase):
    """
    Tests for the SSHServerTransport.
    """

    klass = transport.SSHServerTransport


    def setUp(self):
        """
        Attach a started L{MockFactory} to the server transport.  The tests
        below read the server's host keys from its C{publicKeys} and
        C{privateKeys}.
        """
        TransportTestCase.setUp(self)
        self.proto.factory = MockFactory()
        self.proto.factory.startFactory()


    def tearDown(self):
        """
        Stop and detach the factory installed by L{setUp}.
        """
        TransportTestCase.tearDown(self)
        self.proto.factory.stopFactory()
        del self.proto.factory


    def _guessKexInitPayload(self, keyExchanges, publicKeys):
        """
        Build a KEXINIT payload (without the message number byte).

        The payload advertises C{keyExchanges} and C{publicKeys} together
        with this transport's own cipher, MAC, compression and language
        lists.  It also announces that a guessed key exchange packet
        follows.

        Layout (RFC 4253, section 7.1):
          - a 16-byte cookie;
          - ten name-lists;
          - a non-zero C{first_kex_packet_follows} byte;
          - four reserved zero bytes.
        """
        nameLists = [keyExchanges,
                     publicKeys,
                     self.proto.supportedCiphers,
                     self.proto.supportedCiphers,
                     self.proto.supportedMACs,
                     self.proto.supportedMACs,
                     self.proto.supportedCompressions,
                     self.proto.supportedCompressions,
                     self.proto.supportedLanguages,
                     self.proto.supportedLanguages]
        return ('\x00' * 16 +
                ''.join([common.NS(','.join(names)) for names in nameLists]) +
                '\xff\x00\x00\x00\x00')


    def _assertGuessedPacketIgnored(self, kexInitPacket):
        """
        Deliver C{kexInitPacket} and check the pending-ignore state:
          - a DEBUG message leaves C{ignoreNextPacket} set;
          - each group-exchange request received while the flag is set
            clears it and produces no reply.
        """
        self.proto.ssh_KEXINIT(kexInitPacket)
        self.assertTrue(self.proto.ignoreNextPacket)
        # A DEBUG message must not consume the pending ignore.
        self.proto.ssh_DEBUG("\x01\x00\x00\x00\x04test\x00\x00\x00\x00")
        self.assertTrue(self.proto.ignoreNextPacket)

        self.proto.ssh_KEX_DH_GEX_REQUEST_OLD('\x00\x00\x08\x00')
        self.assertFalse(self.proto.ignoreNextPacket)
        self.assertEqual(self.packets, [])
        self.proto.ignoreNextPacket = True

        self.proto.ssh_KEX_DH_GEX_REQUEST('\x00\x00\x08\x00' * 3)
        self.assertFalse(self.proto.ignoreNextPacket)
        self.assertEqual(self.packets, [])


    def test_KEXINIT(self):
        """
        Test that receiving a KEXINIT packet sets up the correct values on the
        server.  In every category the first algorithm in the client's list
        is the one chosen.
        """
        self.proto.dataReceived('SSH-2.0-Twisted\r\n\x00\x00\x01\xd4\t\x14'
                '\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99'
                '\x99\x00\x00\x00=diffie-hellman-group1-sha1,diffie-hellman-g'
                'roup-exchange-sha1\x00\x00\x00\x0fssh-dss,ssh-rsa\x00\x00\x00'
                '\x85aes128-ctr,aes128-cbc,aes192-ctr,aes192-cbc,aes256-ctr,ae'
                's256-cbc,cast128-ctr,cast128-cbc,blowfish-ctr,blowfish-cbc,3d'
                'es-ctr,3des-cbc\x00\x00\x00\x85aes128-ctr,aes128-cbc,aes192-c'
                'tr,aes192-cbc,aes256-ctr,aes256-cbc,cast128-ctr,cast128-cbc,b'
                'lowfish-ctr,blowfish-cbc,3des-ctr,3des-cbc\x00\x00\x00\x12hma'
                'c-md5,hmac-sha1\x00\x00\x00\x12hmac-md5,hmac-sha1\x00\x00\x00'
                '\tnone,zlib\x00\x00\x00\tnone,zlib\x00\x00\x00\x00\x00\x00'
                '\x00\x00\x00\x00\x00\x00\x00\x99\x99\x99\x99\x99\x99\x99\x99'
                '\x99')
        self.assertEqual(self.proto.kexAlg,
                          'diffie-hellman-group1-sha1')
        self.assertEqual(self.proto.keyAlg,
                          'ssh-dss')
        self.assertEqual(self.proto.outgoingCompressionType,
                          'none')
        self.assertEqual(self.proto.incomingCompressionType,
                          'none')
        ne = self.proto.nextEncryptions
        self.assertEqual(ne.outCipType, 'aes128-ctr')
        self.assertEqual(ne.inCipType, 'aes128-ctr')
        self.assertEqual(ne.outMACType, 'hmac-md5')
        self.assertEqual(ne.inMACType, 'hmac-md5')


    def test_ignoreGuessPacketKex(self):
        """
        The client is allowed to send a guessed key exchange packet
        after it sends the KEXINIT packet.  However, if the key exchanges
        do not match, that guess packet must be ignored.  This tests that
        the packet is ignored in the case of the key exchange method not
        matching.
        """
        # Reversing the list makes the client's preferred kex differ from
        # the server's, so the guess is wrong.
        self._assertGuessedPacketIgnored(self._guessKexInitPayload(
            self.proto.supportedKeyExchanges[::-1],
            self.proto.supportedPublicKeys))


    def test_ignoreGuessPacketKey(self):
        """
        Like test_ignoreGuessPacketKex, but for an incorrectly guessed
        public key format.
        """
        self._assertGuessedPacketIgnored(self._guessKexInitPayload(
            self.proto.supportedKeyExchanges,
            self.proto.supportedPublicKeys[::-1]))


    def test_KEXDH_INIT(self):
        """
        Test that the KEXDH_INIT packet causes the server to send a
        KEXDH_REPLY with the server's public key and a signature, followed by
        NEWKEYS.
        """
        self.proto.supportedKeyExchanges = ['diffie-hellman-group1-sha1']
        self.proto.supportedPublicKeys = ['ssh-rsa']
        self.proto.dataReceived(self.transport.value())
        e = pow(transport.DH_GENERATOR, 5000,
                transport.DH_PRIME)

        self.proto.ssh_KEX_DH_GEX_REQUEST_OLD(common.MP(e))
        # The server's private exponent, built from 64 '\x99' bytes (the
        # test's deterministic random source).
        y = common.getMP('\x00\x00\x00\x40' + '\x99' * 64)[0]
        f = common._MPpow(transport.DH_GENERATOR, y, transport.DH_PRIME)
        sharedSecret = common._MPpow(e, y, transport.DH_PRIME)

        h = sha1()
        h.update(common.NS(self.proto.ourVersionString) * 2)
        h.update(common.NS(self.proto.ourKexInitPayload) * 2)
        h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()))
        h.update(common.MP(e))
        h.update(f)
        h.update(sharedSecret)
        exchangeHash = h.digest()

        signature = self.proto.factory.privateKeys['ssh-rsa'].sign(
                exchangeHash)

        self.assertEqual(
            self.packets,
            [(transport.MSG_KEXDH_REPLY,
              common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob())
              + f + common.NS(signature)),
             (transport.MSG_NEWKEYS, '')])


    def test_KEX_DH_GEX_REQUEST_OLD(self):
        """
        Test that the KEX_DH_GEX_REQUEST_OLD message causes the server
        to reply with a KEX_DH_GEX_GROUP message with the correct
        Diffie-Hellman group.
        """
        self.proto.supportedKeyExchanges = [
                'diffie-hellman-group-exchange-sha1']
        self.proto.supportedPublicKeys = ['ssh-rsa']
        self.proto.dataReceived(self.transport.value())
        # Preferred group size: 1024 bits.
        self.proto.ssh_KEX_DH_GEX_REQUEST_OLD('\x00\x00\x04\x00')
        self.assertEqual(
            self.packets,
            [(transport.MSG_KEX_DH_GEX_GROUP,
              common.MP(transport.DH_PRIME) + '\x00\x00\x00\x01\x02')])
        self.assertEqual(self.proto.g, 2)
        self.assertEqual(self.proto.p, transport.DH_PRIME)


    def test_KEX_DH_GEX_REQUEST_OLD_badKexAlg(self):
        """
        Test that if the server receives a KEX_DH_GEX_REQUEST_OLD message
        and the key exchange algorithm is not 'diffie-hellman-group1-sha1' or
        'diffie-hellman-group-exchange-sha1', we raise a ConchError.
        """
        self.proto.kexAlg = None
        self.assertRaises(ConchError, self.proto.ssh_KEX_DH_GEX_REQUEST_OLD,
                None)


    def test_KEX_DH_GEX_REQUEST(self):
        """
        Test that the KEX_DH_GEX_REQUEST message causes the server to reply
        with a KEX_DH_GEX_GROUP message with the correct Diffie-Hellman
        group.
        """
        self.proto.supportedKeyExchanges = [
            'diffie-hellman-group-exchange-sha1']
        self.proto.supportedPublicKeys = ['ssh-rsa']
        self.proto.dataReceived(self.transport.value())
        # Group sizes in bits: min 1024, preferred 2048, max 3072.
        self.proto.ssh_KEX_DH_GEX_REQUEST('\x00\x00\x04\x00\x00\x00\x08\x00' +
                                          '\x00\x00\x0c\x00')
        self.assertEqual(
            self.packets,
            [(transport.MSG_KEX_DH_GEX_GROUP,
              common.MP(transport.DH_PRIME) + '\x00\x00\x00\x01\x03')])
        self.assertEqual(self.proto.g, 3)
        self.assertEqual(self.proto.p, transport.DH_PRIME)


    def test_KEX_DH_GEX_INIT_after_REQUEST(self):
        """
        Test that the KEX_DH_GEX_INIT message after the client sends
        KEX_DH_GEX_REQUEST causes the server to send a KEX_DH_GEX_REPLY
        message with a public key and signature, followed by NEWKEYS.
        """
        self.test_KEX_DH_GEX_REQUEST()
        e = pow(self.proto.g, 3, self.proto.p)
        y = common.getMP('\x00\x00\x00\x80' + '\x99' * 128)[0]
        f = common._MPpow(self.proto.g, y, self.proto.p)
        sharedSecret = common._MPpow(e, y, self.proto.p)
        h = sha1()
        h.update(common.NS(self.proto.ourVersionString) * 2)
        h.update(common.NS(self.proto.ourKexInitPayload) * 2)
        h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()))
        # The raw KEX_DH_GEX_REQUEST body is part of the exchange hash.
        h.update('\x00\x00\x04\x00\x00\x00\x08\x00\x00\x00\x0c\x00')
        h.update(common.MP(self.proto.p))
        h.update(common.MP(self.proto.g))
        h.update(common.MP(e))
        h.update(f)
        h.update(sharedSecret)
        exchangeHash = h.digest()
        self.proto.ssh_KEX_DH_GEX_INIT(common.MP(e))
        # packets[0] is the KEX_DH_GEX_GROUP sent by the request above.
        self.assertEqual(
            self.packets[1:],
            [(transport.MSG_KEX_DH_GEX_REPLY,
              common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()) +
              f + common.NS(self.proto.factory.privateKeys['ssh-rsa'].sign(
                            exchangeHash))),
             (transport.MSG_NEWKEYS, '')])


    def test_KEX_DH_GEX_INIT_after_REQUEST_OLD(self):
        """
        Test that the KEX_DH_GEX_INIT message after the client sends
        KEX_DH_GEX_REQUEST_OLD causes the server to send a KEX_DH_GEX_REPLY
        message with a public key and signature, followed by NEWKEYS.
        """
        self.test_KEX_DH_GEX_REQUEST_OLD()
        e = pow(self.proto.g, 3, self.proto.p)
        y = common.getMP('\x00\x00\x00\x80' + '\x99' * 128)[0]
        f = common._MPpow(self.proto.g, y, self.proto.p)
        sharedSecret = common._MPpow(e, y, self.proto.p)
        h = sha1()
        h.update(common.NS(self.proto.ourVersionString) * 2)
        h.update(common.NS(self.proto.ourKexInitPayload) * 2)
        h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()))
        # The raw KEX_DH_GEX_REQUEST_OLD body is part of the exchange hash.
        h.update('\x00\x00\x04\x00')
        h.update(common.MP(self.proto.p))
        h.update(common.MP(self.proto.g))
        h.update(common.MP(e))
        h.update(f)
        h.update(sharedSecret)
        exchangeHash = h.digest()
        self.proto.ssh_KEX_DH_GEX_INIT(common.MP(e))
        self.assertEqual(
            self.packets[1:],
            [(transport.MSG_KEX_DH_GEX_REPLY,
              common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()) +
              f + common.NS(self.proto.factory.privateKeys['ssh-rsa'].sign(
                            exchangeHash))),
             (transport.MSG_NEWKEYS, '')])


    def test_keySetup(self):
        """
        Test that _keySetup sets up the next encryption keys.  The session
        ID is fixed by the first exchange and not replaced by later ones.
        """
        self.proto.nextEncryptions = MockCipher()
        self.simulateKeyExchange('AB', 'CD')
        self.assertEqual(self.proto.sessionID, 'CD')
        self.simulateKeyExchange('AB', 'EF')
        self.assertEqual(self.proto.sessionID, 'CD')
        self.assertEqual(self.packets[-1], (transport.MSG_NEWKEYS, ''))
        newKeys = [self.proto._getKey(c, 'AB', 'EF') for c in 'ABCDEF']
        self.assertEqual(
            self.proto.nextEncryptions.keys,
            (newKeys[1], newKeys[3], newKeys[0], newKeys[2], newKeys[5],
             newKeys[4]))


    def test_NEWKEYS(self):
        """
        Test that NEWKEYS transitions the keys in nextEncryptions to
        currentEncryptions, and creates compression objects only for the
        directions negotiated as 'zlib'.
        """
        self.test_KEXINIT()

        self.proto.nextEncryptions = transport.SSHCiphers('none', 'none',
                                                          'none', 'none')
        self.proto.ssh_NEWKEYS('')
        self.assertIdentical(self.proto.currentEncryptions,
                             self.proto.nextEncryptions)
        self.assertIdentical(self.proto.outgoingCompression, None)
        self.assertIdentical(self.proto.incomingCompression, None)
        self.proto.outgoingCompressionType = 'zlib'
        self.simulateKeyExchange('AB', 'CD')
        self.proto.ssh_NEWKEYS('')
        self.assertNotIdentical(self.proto.outgoingCompression, None)
        self.proto.incomingCompressionType = 'zlib'
        self.simulateKeyExchange('AB', 'EF')
        self.proto.ssh_NEWKEYS('')
        self.assertNotIdentical(self.proto.incomingCompression, None)


    def test_SERVICE_REQUEST(self):
        """
        Test that the SERVICE_REQUEST message requests and starts a
        service.
        """
        self.proto.ssh_SERVICE_REQUEST(common.NS('ssh-userauth'))
        self.assertEqual(self.packets, [(transport.MSG_SERVICE_ACCEPT,
                                          common.NS('ssh-userauth'))])
        self.assertEqual(self.proto.service.name, 'MockService')


    def test_disconnectNEWKEYSData(self):
        """
        Test that NEWKEYS disconnects if it receives data.
        """
        self.proto.ssh_NEWKEYS("bad packet")
        self.checkDisconnected()


    def test_disconnectSERVICE_REQUESTBadService(self):
        """
        Test that SERVICE_REQUEST disconnects if an unknown service is
        requested.
        """
        self.proto.ssh_SERVICE_REQUEST(common.NS('no service'))
        self.checkDisconnected(transport.DISCONNECT_SERVICE_NOT_AVAILABLE)
1500
1501
1502
1503 class ClientSSHTransportTestCase(ServerAndClientSSHTransportBaseCase,
1504         TransportTestCase):
1505     """
1506     Tests for SSHClientTransport.
1507     """
1508
1509     klass = transport.SSHClientTransport
1510
1511
1512     def test_KEXINIT(self):
1513         """
1514         Test that receiving a KEXINIT packet sets up the correct values on the
1515         client.  The way algorithms are picks is that the first item in the
1516         client's list that is also in the server's list is chosen.
1517         """
1518         self.proto.dataReceived( 'SSH-2.0-Twisted\r\n\x00\x00\x01\xd4\t\x14'
1519                 '\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99'
1520                 '\x99\x00\x00\x00=diffie-hellman-group1-sha1,diffie-hellman-g'
1521                 'roup-exchange-sha1\x00\x00\x00\x0fssh-dss,ssh-rsa\x00\x00\x00'
1522                 '\x85aes128-ctr,aes128-cbc,aes192-ctr,aes192-cbc,aes256-ctr,ae'
1523                 's256-cbc,cast128-ctr,cast128-cbc,blowfish-ctr,blowfish-cbc,3d'
1524                 'es-ctr,3des-cbc\x00\x00\x00\x85aes128-ctr,aes128-cbc,aes192-c'
1525                 'tr,aes192-cbc,aes256-ctr,aes256-cbc,cast128-ctr,cast128-cbc,b'
1526                 'lowfish-ctr,blowfish-cbc,3des-ctr,3des-cbc\x00\x00\x00\x12hma'
1527                 'c-md5,hmac-sha1\x00\x00\x00\x12hmac-md5,hmac-sha1\x00\x00\x00'
1528                 '\tzlib,none\x00\x00\x00\tzlib,none\x00\x00\x00\x00\x00\x00'
1529                 '\x00\x00\x00\x00\x00\x00\x00\x99\x99\x99\x99\x99\x99\x99\x99'
1530                 '\x99')
1531         self.assertEqual(self.proto.kexAlg,
1532                           'diffie-hellman-group-exchange-sha1')
1533         self.assertEqual(self.proto.keyAlg,
1534                           'ssh-rsa')
1535         self.assertEqual(self.proto.outgoingCompressionType,
1536                           'none')
1537         self.assertEqual(self.proto.incomingCompressionType,
1538                           'none')
1539         ne = self.proto.nextEncryptions
1540         self.assertEqual(ne.outCipType, 'aes256-ctr')
1541         self.assertEqual(ne.inCipType, 'aes256-ctr')
1542         self.assertEqual(ne.outMACType, 'hmac-sha1')
1543         self.assertEqual(ne.inMACType, 'hmac-sha1')
1544
1545
1546     def verifyHostKey(self, pubKey, fingerprint):
1547         """
1548         Mock version of SSHClientTransport.verifyHostKey.
1549         """
1550         self.calledVerifyHostKey = True
1551         self.assertEqual(pubKey, self.blob)
1552         self.assertEqual(fingerprint.replace(':', ''),
1553                           md5(pubKey).hexdigest())
1554         return defer.succeed(True)
1555
1556
1557     def setUp(self):
1558         TransportTestCase.setUp(self)
1559         self.blob = keys.Key.fromString(keydata.publicRSA_openssh).blob()
1560         self.privObj = keys.Key.fromString(keydata.privateRSA_openssh)
1561         self.calledVerifyHostKey = False
1562         self.proto.verifyHostKey = self.verifyHostKey
1563
1564
1565     def test_notImplementedClientMethods(self):
1566         """
1567         verifyHostKey() should return a Deferred which fails with a
1568         NotImplementedError exception.  connectionSecure() should raise
1569         NotImplementedError().
1570         """
1571         self.assertRaises(NotImplementedError, self.klass().connectionSecure)
1572         def _checkRaises(f):
1573             f.trap(NotImplementedError)
1574         d = self.klass().verifyHostKey(None, None)
1575         return d.addCallback(self.fail).addErrback(_checkRaises)
1576
1577
1578     def test_KEXINIT_groupexchange(self):
1579         """
1580         Test that a KEXINIT packet with a group-exchange key exchange results
1581         in a KEX_DH_GEX_REQUEST_OLD message..
1582         """
1583         self.proto.supportedKeyExchanges = [
1584             'diffie-hellman-group-exchange-sha1']
1585         self.proto.dataReceived(self.transport.value())
1586         self.assertEqual(self.packets, [(transport.MSG_KEX_DH_GEX_REQUEST_OLD,
1587                                           '\x00\x00\x08\x00')])
1588
1589
1590     def test_KEXINIT_group1(self):
1591         """
1592         Like test_KEXINIT_groupexchange, but for the group-1 key exchange.
1593         """
1594         self.proto.supportedKeyExchanges = ['diffie-hellman-group1-sha1']
1595         self.proto.dataReceived(self.transport.value())
1596         self.assertEqual(common.MP(self.proto.x)[5:], '\x99' * 64)
1597         self.assertEqual(self.packets,
1598                           [(transport.MSG_KEXDH_INIT, self.proto.e)])
1599
1600
1601     def test_KEXINIT_badKexAlg(self):
1602         """
1603         Test that the client raises a ConchError if it receives a
1604         KEXINIT message bug doesn't have a key exchange algorithm that we
1605         understand.
1606         """
1607         self.proto.supportedKeyExchanges = ['diffie-hellman-group2-sha1']
1608         data = self.transport.value().replace('group1', 'group2')
1609         self.assertRaises(ConchError, self.proto.dataReceived, data)
1610
1611
1612     def test_KEXDH_REPLY(self):
1613         """
1614         Test that the KEXDH_REPLY message verifies the server.
1615         """
1616         self.test_KEXINIT_group1()
1617
1618         sharedSecret = common._MPpow(transport.DH_GENERATOR,
1619                                      self.proto.x, transport.DH_PRIME)
1620         h = sha1()
1621         h.update(common.NS(self.proto.ourVersionString) * 2)
1622         h.update(common.NS(self.proto.ourKexInitPayload) * 2)
1623         h.update(common.NS(self.blob))
1624         h.update(self.proto.e)
1625         h.update('\x00\x00\x00\x01\x02') # f
1626         h.update(sharedSecret)
1627         exchangeHash = h.digest()
1628
1629         def _cbTestKEXDH_REPLY(value):
1630             self.assertIdentical(value, None)
1631             self.assertEqual(self.calledVerifyHostKey, True)
1632             self.assertEqual(self.proto.sessionID, exchangeHash)
1633
1634         signature = self.privObj.sign(exchangeHash)
1635
1636         d = self.proto.ssh_KEX_DH_GEX_GROUP(
1637             (common.NS(self.blob) + '\x00\x00\x00\x01\x02' +
1638              common.NS(signature)))
1639         d.addCallback(_cbTestKEXDH_REPLY)
1640
1641         return d
1642
1643
1644     def test_KEX_DH_GEX_GROUP(self):
1645         """
1646         Test that the KEX_DH_GEX_GROUP message results in a
1647         KEX_DH_GEX_INIT message with the client's Diffie-Hellman public key.
1648         """
1649         self.test_KEXINIT_groupexchange()
1650         self.proto.ssh_KEX_DH_GEX_GROUP(
1651             '\x00\x00\x00\x01\x0f\x00\x00\x00\x01\x02')
1652         self.assertEqual(self.proto.p, 15)
1653         self.assertEqual(self.proto.g, 2)
1654         self.assertEqual(common.MP(self.proto.x)[5:], '\x99' * 40)
1655         self.assertEqual(self.proto.e,
1656                           common.MP(pow(2, self.proto.x, 15)))
1657         self.assertEqual(self.packets[1:], [(transport.MSG_KEX_DH_GEX_INIT,
1658                                               self.proto.e)])
1659
1660
1661     def test_KEX_DH_GEX_REPLY(self):
1662         """
1663         Test that the KEX_DH_GEX_REPLY message results in a verified
1664         server.
1665         """
1666
1667         self.test_KEX_DH_GEX_GROUP()
1668         sharedSecret = common._MPpow(3, self.proto.x, self.proto.p)
1669         h = sha1()
1670         h.update(common.NS(self.proto.ourVersionString) * 2)
1671         h.update(common.NS(self.proto.ourKexInitPayload) * 2)
1672         h.update(common.NS(self.blob))
1673         h.update('\x00\x00\x08\x00\x00\x00\x00\x01\x0f\x00\x00\x00\x01\x02')
1674         h.update(self.proto.e)
1675         h.update('\x00\x00\x00\x01\x03') # f
1676         h.update(sharedSecret)
1677         exchangeHash = h.digest()
1678
1679         def _cbTestKEX_DH_GEX_REPLY(value):
1680             self.assertIdentical(value, None)
1681             self.assertEqual(self.calledVerifyHostKey, True)
1682             self.assertEqual(self.proto.sessionID, exchangeHash)
1683
1684         signature = self.privObj.sign(exchangeHash)
1685
1686         d = self.proto.ssh_KEX_DH_GEX_REPLY(
1687             common.NS(self.blob) +
1688             '\x00\x00\x00\x01\x03' +
1689             common.NS(signature))
1690         d.addCallback(_cbTestKEX_DH_GEX_REPLY)
1691         return d
1692
1693
1694     def test_keySetup(self):
1695         """
1696         Test that _keySetup sets up the next encryption keys.
1697         """
1698         self.proto.nextEncryptions = MockCipher()
1699         self.simulateKeyExchange('AB', 'CD')
1700         self.assertEqual(self.proto.sessionID, 'CD')
1701         self.simulateKeyExchange('AB', 'EF')
1702         self.assertEqual(self.proto.sessionID, 'CD')
1703         self.assertEqual(self.packets[-1], (transport.MSG_NEWKEYS, ''))
1704         newKeys = [self.proto._getKey(c, 'AB', 'EF') for c in 'ABCDEF']
1705         self.assertEqual(self.proto.nextEncryptions.keys,
1706                           (newKeys[0], newKeys[2], newKeys[1], newKeys[3],
1707                            newKeys[4], newKeys[5]))
1708
1709
1710     def test_NEWKEYS(self):
1711         """
1712         Test that NEWKEYS transitions the keys from nextEncryptions to
1713         currentEncryptions.
1714         """
1715         self.test_KEXINIT()
1716         secure = [False]
1717         def stubConnectionSecure():
1718             secure[0] = True
1719         self.proto.connectionSecure = stubConnectionSecure
1720
1721         self.proto.nextEncryptions = transport.SSHCiphers(
1722             'none', 'none', 'none', 'none')
1723         self.simulateKeyExchange('AB', 'CD')
1724         self.assertNotIdentical(
1725             self.proto.currentEncryptions, self.proto.nextEncryptions)
1726
1727         self.proto.nextEncryptions = MockCipher()
1728         self.proto.ssh_NEWKEYS('')
1729         self.assertIdentical(self.proto.outgoingCompression, None)
1730         self.assertIdentical(self.proto.incomingCompression, None)
1731         self.assertIdentical(self.proto.currentEncryptions,
1732                              self.proto.nextEncryptions)
1733         self.assertTrue(secure[0])
1734         self.proto.outgoingCompressionType = 'zlib'
1735         self.simulateKeyExchange('AB', 'GH')
1736         self.proto.ssh_NEWKEYS('')
1737         self.failIfIdentical(self.proto.outgoingCompression, None)
1738         self.proto.incomingCompressionType = 'zlib'
1739         self.simulateKeyExchange('AB', 'IJ')
1740         self.proto.ssh_NEWKEYS('')
1741         self.failIfIdentical(self.proto.incomingCompression, None)
1742
1743
1744     def test_SERVICE_ACCEPT(self):
1745         """
1746         Test that the SERVICE_ACCEPT packet starts the requested service.
1747         """
1748         self.proto.instance = MockService()
1749         self.proto.ssh_SERVICE_ACCEPT('\x00\x00\x00\x0bMockService')
1750         self.assertTrue(self.proto.instance.started)
1751
1752
1753     def test_requestService(self):
1754         """
1755         Test that requesting a service sends a SERVICE_REQUEST packet.
1756         """
1757         self.proto.requestService(MockService())
1758         self.assertEqual(self.packets, [(transport.MSG_SERVICE_REQUEST,
1759                                           '\x00\x00\x00\x0bMockService')])
1760
1761
1762     def test_disconnectKEXDH_REPLYBadSignature(self):
1763         """
1764         Test that KEXDH_REPLY disconnects if the signature is bad.
1765         """
1766         self.test_KEXDH_REPLY()
1767         self.proto._continueKEXDH_REPLY(None, self.blob, 3, "bad signature")
1768         self.checkDisconnected(transport.DISCONNECT_KEY_EXCHANGE_FAILED)
1769
1770
1771     def test_disconnectGEX_REPLYBadSignature(self):
1772         """
1773         Like test_disconnectKEXDH_REPLYBadSignature, but for DH_GEX_REPLY.
1774         """
1775         self.test_KEX_DH_GEX_REPLY()
1776         self.proto._continueGEX_REPLY(None, self.blob, 3, "bad signature")
1777         self.checkDisconnected(transport.DISCONNECT_KEY_EXCHANGE_FAILED)
1778
1779
1780     def test_disconnectNEWKEYSData(self):
1781         """
1782         Test that NEWKEYS disconnects if it receives data.
1783         """
1784         self.proto.ssh_NEWKEYS("bad packet")
1785         self.checkDisconnected()
1786
1787
1788     def test_disconnectSERVICE_ACCEPT(self):
1789         """
1790         Test that SERVICE_ACCEPT disconnects if the accepted protocol is
1791         differet from the asked-for protocol.
1792         """
1793         self.proto.instance = MockService()
1794         self.proto.ssh_SERVICE_ACCEPT('\x00\x00\x00\x03bad')
1795         self.checkDisconnected()
1796
1797
1798
class SSHCiphersTestCase(unittest.TestCase):
    """
    Tests for the SSHCiphers helper class.
    """
    if Crypto is None:
        skip = "cannot run w/o PyCrypto"

    if pyasn1 is None:
        skip = "Cannot run without PyASN1"


    def test_init(self):
        """
        Test that the initializer sets up the SSHCiphers object.
        """
        ciphers = transport.SSHCiphers('A', 'B', 'C', 'D')
        self.assertEqual(ciphers.outCipType, 'A')
        self.assertEqual(ciphers.inCipType, 'B')
        self.assertEqual(ciphers.outMACType, 'C')
        self.assertEqual(ciphers.inMACType, 'D')


    def test_getCipher(self):
        """
        Test that the _getCipher method returns the correct cipher.
        """
        ciphers = transport.SSHCiphers('A', 'B', 'C', 'D')
        iv = key = '\x00' * 16
        for cipName, (modName, keySize, counter) in ciphers.cipherMap.items():
            cip = ciphers._getCipher(cipName, iv, key)
            if cipName == 'none':
                self.assertIsInstance(cip, transport._DummyCipher)
            else:
                self.assertTrue(str(cip).startswith('<' + modName))


    def test_getMAC(self):
        """
        Test that the _getMAC method returns the correct MAC.

        For real MACs the result is a tuple of (hash constructor, key XORed
        with 0x36, key XORed with 0x5c, digest length): the HMAC inner and
        outer pads.
        """
        ciphers = transport.SSHCiphers('A', 'B', 'C', 'D')
        key = '\x00' * 64
        for macName, mac in ciphers.macMap.items():
            mod = ciphers._getMAC(macName, key)
            if macName == 'none':
                self.assertIdentical(mac, None)
            else:
                self.assertEqual(mod[0], mac)
                self.assertEqual(mod[1],
                                  Crypto.Cipher.XOR.new('\x36').encrypt(key))
                self.assertEqual(mod[2],
                                  Crypto.Cipher.XOR.new('\x5c').encrypt(key))
                self.assertEqual(mod[3], len(mod[0]().digest()))


    def test_setKeysCiphers(self):
        """
        Test that setKeys sets up the ciphers.
        """
        key = '\x00' * 64
        cipherItems = transport.SSHCiphers.cipherMap.items()
        for cipName, (modName, keySize, counter) in cipherItems:
            encCipher = transport.SSHCiphers(cipName, 'none', 'none', 'none')
            decCipher = transport.SSHCiphers('none', cipName, 'none', 'none')
            cip = encCipher._getCipher(cipName, key, key)
            bs = cip.block_size
            encCipher.setKeys(key, key, '', '', '', '')
            decCipher.setKeys('', '', key, key, '', '')
            self.assertEqual(encCipher.encBlockSize, bs)
            self.assertEqual(decCipher.decBlockSize, bs)
            enc = cip.encrypt(key[:bs])
            enc2 = cip.encrypt(key[:bs])
            if counter:
                # Counter-based ciphers keep state between calls, so the same
                # plaintext block must encrypt differently the second time.
                self.assertNotEqual(enc, enc2)
            self.assertEqual(encCipher.encrypt(key[:bs]), enc)
            self.assertEqual(encCipher.encrypt(key[:bs]), enc2)
            self.assertEqual(decCipher.decrypt(enc), key[:bs])
            self.assertEqual(decCipher.decrypt(enc2), key[:bs])


    def test_setKeysMACs(self):
        """
        Test that setKeys sets up the MACs.
        """
        key = '\x00' * 64
        for macName, mod in transport.SSHCiphers.macMap.items():
            outMac = transport.SSHCiphers('none', 'none', macName, 'none')
            inMac = transport.SSHCiphers('none', 'none', 'none', macName)
            outMac.setKeys('', '', '', '', key, '')
            inMac.setKeys('', '', '', '', '', key)
            if mod:
                ds = mod().digest_size
            else:
                ds = 0
            self.assertEqual(inMac.verifyDigestSize, ds)
            if mod:
                mod, i, o, ds = outMac._getMAC(macName, key)
            seqid = 0
            data = key
            # The MAC input is expected to be the 4-byte sequence number
            # (zero here) followed by the data.
            packet = '\x00' * 4 + key
            if mod:
                mac = mod(o + mod(i + packet).digest()).digest()
            else:
                mac = ''
            self.assertEqual(outMac.makeMAC(seqid, data), mac)
            self.assertTrue(inMac.verify(seqid, data, mac))
1905
1906
1907
class CounterTestCase(unittest.TestCase):
    """
    Tests for the _Counter helper class.
    """
    if Crypto is None:
        skip = "cannot run w/o PyCrypto"

    if pyasn1 is None:
        skip = "Cannot run without PyASN1"


    def test_init(self):
        """
        Test that the counter is initialized correctly: it keeps only the
        first C{blockSize} bytes of the initial value.
        """
        counter = transport._Counter('\x00' * 8 + '\xff' * 8, 8)
        self.assertEqual(counter.blockSize, 8)
        self.assertEqual(counter.count.tostring(), '\x00' * 8)


    def test_count(self):
        """
        Test that the counter counts incrementally and wraps at the top.
        """
        counter = transport._Counter('\x00', 1)
        self.assertEqual(counter(), '\x01')
        self.assertEqual(counter(), '\x02')
        # Advance from 0x02 to 0xfe; the next two calls cross the wrap.
        for _ in range(252):
            counter()
        self.assertEqual(counter(), '\xff')
        self.assertEqual(counter(), '\x00')
1936         self.assertEqual(counter(), '\xff')
1937         self.assertEqual(counter(), '\x00')
1938
1939
1940
class TransportLoopbackTestCase(unittest.TestCase):
    """
    Test the server transport and client transport against each other.
    """
    if Crypto is None:
        skip = "cannot run w/o PyCrypto"

    if pyasn1 is None:
        skip = "Cannot run without PyASN1"


    def _runClientServer(self, mod):
        """
        Run an async client and server, modifying each using the mod function
        provided.  Returns a Deferred called back when both Protocols have
        disconnected.

        @type mod: C{func}
        @rtype: C{defer.Deferred}
        """
        # Named sshFactory so it does not shadow the module-level 'factory'.
        sshFactory = MockFactory()
        server = transport.SSHServerTransport()
        server.factory = sshFactory
        sshFactory.startFactory()
        server.errors = []
        server.receiveError = lambda code, desc: server.errors.append((
                code, desc))
        client = transport.SSHClientTransport()
        client.verifyHostKey = lambda x, y: defer.succeed(None)
        client.errors = []
        client.receiveError = lambda code, desc: client.errors.append((
                code, desc))
        # The client hangs up as soon as keys are in place, which ends the
        # loopback run.
        client.connectionSecure = lambda: client.loseConnection()
        server = mod(server)
        client = mod(client)
        def check(ignored, server, client):
            name = repr([server.supportedCiphers[0],
                         server.supportedMACs[0],
                         server.supportedKeyExchanges[0],
                         server.supportedCompressions[0]])
            self.assertEqual(client.errors, [])
            self.assertEqual(server.errors, [(
                        transport.DISCONNECT_CONNECTION_LOST,
                        "user closed connection")])
            if server.supportedCiphers[0] == 'none':
                self.assertFalse(server.isEncrypted(), name)
                self.assertFalse(client.isEncrypted(), name)
            else:
                self.assertTrue(server.isEncrypted(), name)
                self.assertTrue(client.isEncrypted(), name)
            if server.supportedMACs[0] == 'none':
                self.assertFalse(server.isVerified(), name)
                self.assertFalse(client.isVerified(), name)
            else:
                self.assertTrue(server.isVerified(), name)
                self.assertTrue(client.isVerified(), name)

        d = loopback.loopbackAsync(server, client)
        d.addCallback(check, server, client)
        return d


    def _runWithEach(self, attribute, values):
        """
        Run one loopback client/server pair per value, restricting both sides'
        C{attribute} list to that single value.

        @param attribute: name of the C{supported*} list to override.
        @type attribute: C{str}
        @param values: algorithm names to try, one loopback run each.
        @return: a L{defer.DeferredList} that errbacks on the first failure.
        """
        deferreds = []
        for value in values:
            # Bind value as a default argument so each modifier keeps its own
            # value regardless of when it is called.
            def setAlgorithm(proto, value=value):
                setattr(proto, attribute, [value])
                return proto
            deferreds.append(self._runClientServer(setAlgorithm))
        return defer.DeferredList(deferreds, fireOnOneErrback=True)


    def test_ciphers(self):
        """
        Test that the client and server play nicely together, in all
        the various combinations of ciphers.
        """
        return self._runWithEach(
            'supportedCiphers',
            transport.SSHTransportBase.supportedCiphers + ['none'])


    def test_macs(self):
        """
        Like test_ciphers, but for the various MACs.
        """
        return self._runWithEach(
            'supportedMACs',
            transport.SSHTransportBase.supportedMACs + ['none'])


    def test_keyexchanges(self):
        """
        Like test_ciphers, but for the various key exchanges.
        """
        return self._runWithEach(
            'supportedKeyExchanges',
            transport.SSHTransportBase.supportedKeyExchanges)


    def test_compressions(self):
        """
        Like test_ciphers, but for the various compressions.
        """
        return self._runWithEach(
            'supportedCompressions',
            transport.SSHTransportBase.supportedCompressions)
2053         return defer.DeferredList(deferreds, fireOnOneErrback=True)
2054
2055
class RandomNumberTestCase(unittest.TestCase):
    """
    Tests for the random number generator L{_getRandomNumber} and private
    key generator L{_generateX}.
    """
    skip = dependencySkip

    def test_usesSuppliedRandomFunction(self):
        """
        L{_getRandomNumber} returns an integer constructed directly from the
        bytes returned by the random byte generator passed to it.
        """
        def fakeRandom(count):
            # The number of bytes requested will be the value of each byte
            # we return.  (Parameter not named 'bytes' to avoid shadowing
            # the builtin.)
            return chr(count) * count
        # 32 bits -> 4 bytes requested -> four bytes of value 4.
        self.assertEqual(
            transport._getRandomNumber(fakeRandom, 32),
            4 << 24 | 4 << 16 | 4 << 8 | 4)


    def test_rejectsNonByteMultiples(self):
        """
        L{_getRandomNumber} raises L{ValueError} if the number of bits
        passed to L{_getRandomNumber} is not a multiple of 8.
        """
        self.assertRaises(
            ValueError,
            transport._getRandomNumber, None, 9)


    def test_excludesSmall(self):
        """
        If the random byte generator passed to L{_generateX} produces bytes
        which would result in 0 or 1 being returned, these bytes are
        discarded and another attempt is made to produce a larger value.
        """
        results = [chr(0), chr(1), chr(127)]
        def fakeRandom(count):
            return results.pop(0) * count
        self.assertEqual(
            transport._generateX(fakeRandom, 8),
            127)


    def test_excludesLarge(self):
        """
        If the random byte generator passed to L{_generateX} produces bytes
        which would result in C{(2 ** bits) - 1} being returned, these bytes
        are discarded and another attempt is made to produce a smaller
        value.
        """
        results = [chr(255), chr(64)]
        def fakeRandom(count):
            return results.pop(0) * count
        self.assertEqual(
            transport._generateX(fakeRandom, 8),
            64)
2112             transport._generateX(random, 8),
2113             64)
2114
2115
2116
class OldFactoryTestCase(unittest.TestCase):
    """
    The old C{SSHFactory.getPublicKeys}() returned mappings of key names to
    strings of key blobs and mappings of key names to PyCrypto key objects from
    C{SSHFactory.getPrivateKeys}() (they could also be specified with the
    C{publicKeys} and C{privateKeys} attributes).  This is no longer supported
    by the C{SSHServerTransport}, so we warn the user if they create an old
    factory.
    """

    if Crypto is None:
        skip = "cannot run w/o PyCrypto"

    if pyasn1 is None:
        skip = "Cannot run without PyASN1"


    def _assertPublicKeysDeprecated(self, sshFactory):
        """
        Assert that starting C{sshFactory} emits the string-public-keys
        deprecation warning and leaves C{publicKeys} equal to what
        L{MockFactory.getPublicKeys} returns.
        """
        self.assertWarns(DeprecationWarning,
                "Returning a mapping from strings to strings from"
                " getPublicKeys()/publicKeys (in %s) is deprecated.  Return "
                "a mapping from strings to Key objects instead." %
                (qual(MockOldFactoryPublicKeys),),
                factory.__file__, sshFactory.startFactory)
        self.assertEqual(sshFactory.publicKeys, MockFactory().getPublicKeys())


    def _assertPrivateKeysDeprecated(self, sshFactory):
        """
        Assert that starting C{sshFactory} emits the PyCrypto-private-keys
        deprecation warning and leaves C{privateKeys} equal to what
        L{MockFactory.getPrivateKeys} returns.
        """
        self.assertWarns(DeprecationWarning,
                "Returning a mapping from strings to PyCrypto key objects from"
                " getPrivateKeys()/privateKeys (in %s) is deprecated.  Return"
                " a mapping from strings to Key objects instead." %
                (qual(MockOldFactoryPrivateKeys),),
                factory.__file__, sshFactory.startFactory)
        self.assertEqual(sshFactory.privateKeys,
                          MockFactory().getPrivateKeys())


    def test_getPublicKeysWarning(self):
        """
        If the return value of C{getPublicKeys}() isn't a mapping from key
        names to C{Key} objects, then warn the user and convert the mapping.
        """
        self._assertPublicKeysDeprecated(MockOldFactoryPublicKeys())


    def test_getPrivateKeysWarning(self):
        """
        If the return value of C{getPrivateKeys}() isn't a mapping from key
        names to C{Key} objects, then warn the user and convert the mapping.
        """
        self._assertPrivateKeysDeprecated(MockOldFactoryPrivateKeys())


    def test_publicKeysWarning(self):
        """
        If the value of the C{publicKeys} attribute isn't a mapping from key
        names to C{Key} objects, then warn the user and convert the mapping.
        """
        sshFactory = MockOldFactoryPublicKeys()
        sshFactory.publicKeys = sshFactory.getPublicKeys()
        self._assertPublicKeysDeprecated(sshFactory)


    def test_privateKeysWarning(self):
        """
        If the value of the C{privateKeys} attribute isn't a mapping from
        key names to C{Key} objects, then warn the user and convert the
        mapping.
        """
        sshFactory = MockOldFactoryPrivateKeys()
        sshFactory.privateKeys = sshFactory.getPrivateKeys()
        self._assertPrivateKeysDeprecated(sshFactory)
2195         self.assertEqual(sshFactory.privateKeys,
2196                           MockFactory().getPrivateKeys())