Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / conch / test / test_ssh.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for L{twisted.conch.ssh}.
6 """
7
8 import struct
9
10 try:
11     import Crypto.Cipher.DES3
12 except ImportError:
13     Crypto = None
14
15 try:
16     import pyasn1
17 except ImportError:
18     pyasn1 = None
19
20 from twisted.conch.ssh import common, session, forwarding
21 from twisted.conch import avatar, error
22 from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
23 from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
24 from twisted.cred import portal
25 from twisted.cred.error import UnauthorizedLogin
26 from twisted.internet import defer, protocol, reactor
27 from twisted.internet.error import ProcessTerminated
28 from twisted.python import failure, log
29 from twisted.trial import unittest
30
31 from twisted.conch.test.test_recvline import LoopbackRelay
32
33
34
class ConchTestRealm(object):
    """
    A realm that hands out at most one L{ConchTestAvatar}, and only to a
    single, pre-configured avatar ID.

    @ivar expectedAvatarID: The only avatar ID this realm will log in.

    @ivar avatar: The L{ConchTestAvatar} created by the first successful
        request, or C{None} until then.
    """
    avatar = None

    def __init__(self, expectedAvatarID):
        self.expectedAvatarID = expectedAvatarID


    def requestAvatar(self, avatarID, mind, *interfaces):
        """
        Create and return a L{ConchTestAvatar} for the expected avatar ID.

        @return: A C{(interface, avatar, logout)} tuple using the first
            requested interface and the avatar's C{logout} method.
        @raise UnauthorizedLogin: If C{avatarID} is not the expected one, or
            if an avatar has already been handed out.
        """
        if avatarID != self.expectedAvatarID:
            raise UnauthorizedLogin(
                "Only %r may log in, not %r" % (self.expectedAvatarID, avatarID))
        if self.avatar is not None:
            raise UnauthorizedLogin("Only one login allowed")
        newAvatar = ConchTestAvatar()
        self.avatar = newAvatar
        return interfaces[0], newAvatar, newAvatar.logout
63
64
65
class ConchTestAvatar(avatar.ConchUser):
    """
    An avatar against which various SSH features can be tested.

    @ivar loggedOut: A flag indicating whether the avatar logout method has been
        called.

    @ivar listeners: Maps C{(host, port)} to the listening port created by a
        C{tcpip-forward} global request.

    @ivar globalRequests: Records the data received by the C{foo} and
        C{foo-2} global request handlers, keyed by C{'foo'} and C{'foo_2'}.
    """
    loggedOut = False

    def __init__(self):
        avatar.ConchUser.__init__(self)
        self.listeners = {}
        self.globalRequests = {}
        self.channelLookup.update({
            'session': session.SSHSession,
            'direct-tcpip': forwarding.openConnectForwardingClient})
        self.subsystemLookup.update({'crazy': CrazySubsystem})


    def global_foo(self, data):
        """
        Record C{data} under C{'foo'} and report success.
        """
        self.globalRequests['foo'] = data
        return 1


    def global_foo_2(self, data):
        """
        Record C{data} under C{'foo_2'} and report success with reply data.
        """
        self.globalRequests['foo_2'] = data
        return 1, 'data'


    def global_tcpip_forward(self, data):
        """
        Start listening on the requested host and port and forward
        connections over this avatar's SSH connection.

        @return: C{1} on success, C{0} if the listener could not be created.
        """
        host, port = forwarding.unpackGlobal_tcpip_forward(data)
        try:
            listener = reactor.listenTCP(
                port, forwarding.SSHListenForwardingFactory(
                    self.conn, (host, port),
                    forwarding.SSHListenServerForwardingChannel),
                interface=host)
        except Exception:
            # Catch ordinary errors only; a bare except would also swallow
            # KeyboardInterrupt/SystemExit.
            log.err(None, "something went wrong with remote->local forwarding")
            return 0
        self.listeners[(host, port)] = listener
        return 1


    def global_cancel_tcpip_forward(self, data):
        """
        Stop a listener previously created by C{global_tcpip_forward}.

        @return: C{1} if a listener was stopped, C{0} if none was registered.
        """
        host, port = forwarding.unpackGlobal_tcpip_forward(data)
        listener = self.listeners.get((host, port), None)
        if not listener:
            return 0
        del self.listeners[(host, port)]
        listener.stopListening()
        return 1


    def logout(self):
        """
        Mark the avatar as logged out and stop every remaining listener.
        """
        self.loggedOut = True
        for listener in self.listeners.values():
            log.msg('stopListening %s' % listener)
            listener.stopListening()
        # Forget the stopped listeners so a repeated logout does not stop
        # them a second time.
        self.listeners.clear()
125
126
127
class ConchSessionForTestAvatar(object):
    """
    An ISession adapter for ConchTestAvatar.

    @ivar avatar: The adapted L{ConchTestAvatar}.
    @ivar cmd: The command passed to C{execCommand}, C{'shell'} after
        C{openShell}, or C{None} before either has been called.
    @ivar proto: The process protocol given to C{openShell} or
        C{execCommand}, or C{None} before either has been called.
    @ivar ptyReq: Whether C{getPty} has been called.
    @ivar eof: Set to C{1} by C{eofReceived}.
    @ivar onClose: A L{Deferred} fired with C{None} by C{closed}.
    """
    def __init__(self, avatar):
        """
        Initialize the session and create a reference to it on the avatar for
        later inspection.
        """
        self.avatar = avatar
        self.avatar._testSession = self
        self.cmd = None
        self.proto = None
        self.ptyReq = False
        self.eof = 0
        self.onClose = defer.Deferred()


    def getPty(self, term, windowSize, attrs):
        """
        Record the requested terminal type and window size.
        """
        log.msg('pty req')
        self._terminalType = term
        self._windowSize = windowSize
        self.ptyReq = True


    def openShell(self, proto):
        """
        Connect C{proto} to an L{EchoTransport}, which passes data written to
        it back to the protocol's stdout.
        """
        log.msg('opening shell')
        self.proto = proto
        EchoTransport(proto)
        self.cmd = 'shell'


    def execCommand(self, proto, cmd):
        """
        Pretend to run C{cmd}. The first word selects a fake transport:
        C{false}, C{echo}, C{secho} (stdout and stderr) or C{eecho} (stderr).

        @raise error.ConchError: If C{cmd} is blank or its first word is not
            one of the supported commands.
        """
        self.cmd = cmd
        self.proto = proto
        words = cmd.split()
        # A blank command is rejected like any unknown one, rather than
        # escaping as an IndexError.
        if words:
            f = words[0]
        else:
            f = ''
        if f == 'false':
            t = FalseTransport(proto)
            # Avoid disconnecting this immediately.  If the channel is closed
            # before execCommand even returns the caller gets confused.
            reactor.callLater(0, t.loseConnection)
        elif f == 'echo':
            t = EchoTransport(proto)
            t.write(cmd[5:])
            t.loseConnection()
        elif f == 'secho':
            t = SuperEchoTransport(proto)
            t.write(cmd[6:])
            t.loseConnection()
        elif f == 'eecho':
            t = ErrEchoTransport(proto)
            t.write(cmd[6:])
            t.loseConnection()
        else:
            raise error.ConchError('bad exec')
        self.avatar.conn.transport.expectedLoseConnection = 1


    def eofReceived(self):
        self.eof = 1


    def closed(self):
        """
        Record the remote window left at close time and fire C{onClose}.
        """
        log.msg('closed cmd "%s"' % self.cmd)
        # proto is only set once a shell or command has started.  A session
        # closed before that (e.g. after only a pty-req) has no window to read.
        if self.proto is None:
            self.remoteWindowLeftAtClose = None
        else:
            self.remoteWindowLeftAtClose = self.proto.session.remoteWindowLeft
        self.onClose.callback(None)
194
from twisted.python import components
# Make ISession(testAvatar) produce a ConchSessionForTestAvatar.
components.registerAdapter(
    ConchSessionForTestAvatar,
    ConchTestAvatar,
    session.ISession)
197
class CrazySubsystem(protocol.Protocol):
    """
    A do-nothing protocol registered as the C{'crazy'} subsystem by
    L{ConchTestAvatar}, so that a C{subsystem crazy} request can succeed.
    """

    def __init__(self, *args, **kwargs):
        """
        Accept and ignore whatever construction arguments are supplied.
        """


    def connectionMade(self):
        """
        Deliberately do nothing once connected.
        """
207
208
209
class _FakeProcessTransport:
    """
    Shared behaviour of the fake process transports below: attach to a
    session process protocol and, on C{loseConnection}, report every pipe
    closed and then the process terminated with C{exitCode}.

    @ivar proto: The protocol associated with this transport, presumably a
        L{twisted.conch.ssh.session.SSHSessionProcessProtocol}.
    @ivar closed: A flag tracking whether C{loseConnection} has been called yet.
    @cvar exitCode: The exit status reported via L{ProcessTerminated}.
    """
    exitCode = 0

    def __init__(self, p):
        self.proto = p
        # Initialise before makeConnection: the protocol's connectionMade may
        # already call write()/loseConnection(), which read this flag.
        self.closed = 0
        p.makeConnection(self)


    def loseConnection(self):
        """
        Disconnect the protocol associated with this transport (only once).
        """
        if self.closed:
            return
        self.closed = 1
        self.proto.inConnectionLost()
        self.proto.outConnectionLost()
        self.proto.errConnectionLost()
        self.proto.processEnded(
            failure.Failure(ProcessTerminated(self.exitCode, None, None)))



class FalseTransport(_FakeProcessTransport):
    """
    False transport should act like a /bin/false execution, i.e. just exit with
    nonzero status, writing nothing to the terminal.
    """
    exitCode = 255



class EchoTransport(_FakeProcessTransport):
    """
    Echo every write to the protocol's stdout followed by CRLF; writing data
    containing a NUL byte disconnects, mimicking C{exit} for the shell test.
    """

    def write(self, data):
        log.msg(repr(data))
        self.proto.outReceived(data)
        self.proto.outReceived('\r\n')
        if '\x00' in data: # mimic 'exit' for the shell test
            self.loseConnection()



class ErrEchoTransport(_FakeProcessTransport):
    """
    Echo every write to the protocol's stderr followed by CRLF.
    """

    def write(self, data):
        self.proto.errReceived(data)
        self.proto.errReceived('\r\n')



class SuperEchoTransport(_FakeProcessTransport):
    """
    Echo every write to both stdout and stderr, each followed by CRLF.
    """

    def write(self, data):
        self.proto.outReceived(data)
        self.proto.outReceived('\r\n')
        self.proto.errReceived(data)
        self.proto.errReceived('\r\n')
303
304
if Crypto is not None and pyasn1 is not None:
    # Only define the SSH helper classes below when both PyCrypto and PyASN1
    # were importable (see the try/except imports at the top of the file).
    # The test cases that use these helpers set C{skip} when either is
    # missing, so the names are never needed in that case.
    from twisted.conch import checkers
    from twisted.conch.ssh import channel, connection, factory, keys
    from twisted.conch.ssh import transport, userauth
309
310     class UtilityTestCase(unittest.TestCase):
311         def testCounter(self):
312             c = transport._Counter('\x00\x00', 2)
313             for i in xrange(256 * 256):
314                 self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
315             # It should wrap around, too.
316             for i in xrange(256 * 256):
317                 self.assertEqual(c(), struct.pack('!H', (i + 1) % (2 ** 16)))
318
319
320     class ConchTestPublicKeyChecker(checkers.SSHPublicKeyDatabase):
321         def checkKey(self, credentials):
322             blob = keys.Key.fromString(publicDSA_openssh).blob()
323             if credentials.username == 'testuser' and credentials.blob == blob:
324                 return True
325             return False
326
327
328     class ConchTestPasswordChecker:
329         credentialInterfaces = checkers.IUsernamePassword,
330
331         def requestAvatarId(self, credentials):
332             if credentials.username == 'testuser' and credentials.password == 'testpass':
333                 return defer.succeed(credentials.username)
334             return defer.fail(Exception("Bad credentials"))
335
336
337     class ConchTestSSHChecker(checkers.SSHProtocolChecker):
338
339         def areDone(self, avatarId):
340             if avatarId != 'testuser' or len(self.successfulCredentials[avatarId]) < 2:
341                 return False
342             return True
343
344     class ConchTestServerFactory(factory.SSHFactory):
345         noisy = 0
346
347         services = {
348             'ssh-userauth':userauth.SSHUserAuthServer,
349             'ssh-connection':connection.SSHConnection
350         }
351
352         def buildProtocol(self, addr):
353             proto = ConchTestServer()
354             proto.supportedPublicKeys = self.privateKeys.keys()
355             proto.factory = self
356
357             if hasattr(self, 'expectedLoseConnection'):
358                 proto.expectedLoseConnection = self.expectedLoseConnection
359
360             self.proto = proto
361             return proto
362
363         def getPublicKeys(self):
364             return {
365                 'ssh-rsa': keys.Key.fromString(publicRSA_openssh),
366                 'ssh-dss': keys.Key.fromString(publicDSA_openssh)
367             }
368
369         def getPrivateKeys(self):
370             return {
371                 'ssh-rsa': keys.Key.fromString(privateRSA_openssh),
372                 'ssh-dss': keys.Key.fromString(privateDSA_openssh)
373             }
374
375         def getPrimes(self):
376             return {
377                 2048:[(transport.DH_GENERATOR, transport.DH_PRIME)]
378             }
379
380         def getService(self, trans, name):
381             return factory.SSHFactory.getService(self, trans, name)
382
383     class ConchTestBase:
384
385         done = 0
386
387         def connectionLost(self, reason):
388             if self.done:
389                 return
390             if not hasattr(self,'expectedLoseConnection'):
391                 unittest.fail('unexpectedly lost connection %s\n%s' % (self, reason))
392             self.done = 1
393
394         def receiveError(self, reasonCode, desc):
395             self.expectedLoseConnection = 1
396             # Some versions of OpenSSH (for example, OpenSSH_5.3p1) will
397             # send a DISCONNECT_BY_APPLICATION error before closing the
398             # connection.  Other, older versions (for example,
399             # OpenSSH_5.1p1), won't.  So accept this particular error here,
400             # but no others.
401             if reasonCode != transport.DISCONNECT_BY_APPLICATION:
402                 log.err(
403                     Exception(
404                         'got disconnect for %s: reason %s, desc: %s' % (
405                             self, reasonCode, desc)))
406             self.loseConnection()
407
408         def receiveUnimplemented(self, seqID):
409             unittest.fail('got unimplemented: seqid %s'  % seqID)
410             self.expectedLoseConnection = 1
411             self.loseConnection()
412
413     class ConchTestServer(ConchTestBase, transport.SSHServerTransport):
414
415         def connectionLost(self, reason):
416             ConchTestBase.connectionLost(self, reason)
417             transport.SSHServerTransport.connectionLost(self, reason)
418
419
420     class ConchTestClient(ConchTestBase, transport.SSHClientTransport):
421         """
422         @ivar _channelFactory: A callable which accepts an SSH connection and
423             returns a channel which will be attached to a new channel on that
424             connection.
425         """
426         def __init__(self, channelFactory):
427             self._channelFactory = channelFactory
428
429         def connectionLost(self, reason):
430             ConchTestBase.connectionLost(self, reason)
431             transport.SSHClientTransport.connectionLost(self, reason)
432
433         def verifyHostKey(self, key, fp):
434             keyMatch = key == keys.Key.fromString(publicRSA_openssh).blob()
435             fingerprintMatch = (
436                 fp == '3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af')
437             if keyMatch and fingerprintMatch:
438                 return defer.succeed(1)
439             return defer.fail(Exception("Key or fingerprint mismatch"))
440
441         def connectionSecure(self):
442             self.requestService(ConchTestClientAuth('testuser',
443                 ConchTestClientConnection(self._channelFactory)))
444
445
446     class ConchTestClientAuth(userauth.SSHUserAuthClient):
447
448         hasTriedNone = 0 # have we tried the 'none' auth yet?
449         canSucceedPublicKey = 0 # can we succed with this yet?
450         canSucceedPassword = 0
451
452         def ssh_USERAUTH_SUCCESS(self, packet):
453             if not self.canSucceedPassword and self.canSucceedPublicKey:
454                 unittest.fail('got USERAUTH_SUCESS before password and publickey')
455             userauth.SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
456
457         def getPassword(self):
458             self.canSucceedPassword = 1
459             return defer.succeed('testpass')
460
461         def getPrivateKey(self):
462             self.canSucceedPublicKey = 1
463             return defer.succeed(keys.Key.fromString(privateDSA_openssh))
464
465         def getPublicKey(self):
466             return keys.Key.fromString(publicDSA_openssh)
467
468
469     class ConchTestClientConnection(connection.SSHConnection):
470         """
471         @ivar _completed: A L{Deferred} which will be fired when the number of
472             results collected reaches C{totalResults}.
473         """
474         name = 'ssh-connection'
475         results = 0
476         totalResults = 8
477
478         def __init__(self, channelFactory):
479             connection.SSHConnection.__init__(self)
480             self._channelFactory = channelFactory
481
482         def serviceStarted(self):
483             self.openChannel(self._channelFactory(conn=self))
484
485
486     class SSHTestChannel(channel.SSHChannel):
487
488         def __init__(self, name, opened, *args, **kwargs):
489             self.name = name
490             self._opened = opened
491             self.received = []
492             self.receivedExt = []
493             self.onClose = defer.Deferred()
494             channel.SSHChannel.__init__(self, *args, **kwargs)
495
496
497         def openFailed(self, reason):
498             self._opened.errback(reason)
499
500
501         def channelOpen(self, ignore):
502             self._opened.callback(self)
503
504
505         def dataReceived(self, data):
506             self.received.append(data)
507
508
509         def extReceived(self, dataType, data):
510             if dataType == connection.EXTENDED_DATA_STDERR:
511                 self.receivedExt.append(data)
512             else:
513                 log.msg("Unrecognized extended data: %r" % (dataType,))
514
515
516         def request_exit_status(self, status):
517             [self.status] = struct.unpack('>L', status)
518
519
520         def eofReceived(self):
521             self.eofCalled = True
522
523
524         def closed(self):
525             self.onClose.callback(None)
526
527
528
class SSHProtocolTestCase(unittest.TestCase):
    """
    Tests for communication between L{SSHServerTransport} and
    L{SSHClientTransport}.
    """

    if not Crypto:
        skip = "can't run w/o PyCrypto"

    if not pyasn1:
        skip = "Cannot run without PyASN1"

    def _ourServerOurClientTest(self, name='session', **kwargs):
        """
        Create a connected SSH client and server protocol pair and return a
        L{Deferred} which fires with an L{SSHTestChannel} instance connected to
        a channel on that SSH connection.

        @param name: The channel type to open (C{'session'} by default).
        @param kwargs: Extra keyword arguments for L{SSHTestChannel}, such as
            C{localWindow} and C{localMaxPacket}.
        """
        result = defer.Deferred()
        self.realm = ConchTestRealm('testuser')
        p = portal.Portal(self.realm)
        sshpc = ConchTestSSHChecker()
        sshpc.registerChecker(ConchTestPasswordChecker())
        sshpc.registerChecker(ConchTestPublicKeyChecker())
        p.registerChecker(sshpc)
        fac = ConchTestServerFactory()
        fac.portal = p
        fac.startFactory()
        self.server = fac.buildProtocol(None)
        self.clientTransport = LoopbackRelay(self.server)
        self.client = ConchTestClient(
            lambda conn: SSHTestChannel(name, result, conn=conn, **kwargs))

        self.serverTransport = LoopbackRelay(self.client)

        self.server.makeConnection(self.serverTransport)
        self.client.makeConnection(self.clientTransport)
        return result


    def test_subsystemsAndGlobalRequests(self):
        """
        Run the Conch server against the Conch client.  Set up several different
        channels which exercise different behaviors and wait for them to
        complete.  Verify that the channels with errors log them.
        """
        channel = self._ourServerOurClientTest()

        def cbSubsystem(channel):
            self.channel = channel
            return self.assertFailure(
                channel.conn.sendRequest(
                    channel, 'subsystem', common.NS('not-crazy'), 1),
                Exception)
        channel.addCallback(cbSubsystem)

        def cbNotCrazyFailed(ignored):
            channel = self.channel
            return channel.conn.sendRequest(
                channel, 'subsystem', common.NS('crazy'), 1)
        channel.addCallback(cbNotCrazyFailed)

        def cbGlobalRequests(ignored):
            channel = self.channel
            d1 = channel.conn.sendGlobalRequest('foo', 'bar', 1)

            d2 = channel.conn.sendGlobalRequest('foo-2', 'bar2', 1)
            d2.addCallback(self.assertEqual, 'data')

            d3 = self.assertFailure(
                channel.conn.sendGlobalRequest('bar', 'foo', 1),
                Exception)

            return defer.gatherResults([d1, d2, d3])
        channel.addCallback(cbGlobalRequests)

        def disconnect(ignored):
            self.assertEqual(
                self.realm.avatar.globalRequests,
                {"foo": "bar", "foo_2": "bar2"})
            channel = self.channel
            channel.conn.transport.expectedLoseConnection = True
            channel.conn.serviceStopped()
            channel.loseConnection()
        channel.addCallback(disconnect)

        return channel


    def test_shell(self):
        """
        L{SSHChannel.sendRequest} can open a shell with a I{pty-req} request,
        specifying a terminal type and window size.
        """
        channel = self._ourServerOurClientTest()

        data = session.packRequest_pty_req('conch-test-term', (24, 80, 0, 0), '')
        def cbChannel(channel):
            self.channel = channel
            return channel.conn.sendRequest(channel, 'pty-req', data, 1)
        channel.addCallback(cbChannel)

        def cbPty(ignored):
            # The server-side object corresponding to our client side channel.
            # (Named so it does not shadow the session module.)
            serverSession = self.realm.avatar.conn.channels[0].session
            self.assertIdentical(serverSession.avatar, self.realm.avatar)
            self.assertEqual(serverSession._terminalType, 'conch-test-term')
            self.assertEqual(serverSession._windowSize, (24, 80, 0, 0))
            self.assertTrue(serverSession.ptyReq)
            channel = self.channel
            return channel.conn.sendRequest(channel, 'shell', '', 1)
        channel.addCallback(cbPty)

        def cbShell(ignored):
            self.channel.write('testing the shell!\x00')
            self.channel.conn.sendEOF(self.channel)
            return defer.gatherResults([
                    self.channel.onClose,
                    self.realm.avatar._testSession.onClose])
        channel.addCallback(cbShell)

        def cbExited(ignored):
            if self.channel.status != 0:
                log.msg(
                    'shell exit status was not 0: %i' % (self.channel.status,))
            self.assertEqual(
                "".join(self.channel.received),
                'testing the shell!\x00\r\n')
            self.assertTrue(self.channel.eofCalled)
            self.assertTrue(
                self.realm.avatar._testSession.eof)
        channel.addCallback(cbExited)
        return channel


    def test_failedExec(self):
        """
        If L{SSHChannel.sendRequest} issues an exec which the server responds to
        with an error, the L{Deferred} it returns fires its errback.
        """
        channel = self._ourServerOurClientTest()

        def cbChannel(channel):
            self.channel = channel
            return self.assertFailure(
                channel.conn.sendRequest(
                    channel, 'exec', common.NS('jumboliah'), 1),
                Exception)
        channel.addCallback(cbChannel)

        def cbFailed(ignored):
            # The server logs this exception when it cannot perform the
            # requested exec.
            errors = self.flushLoggedErrors(error.ConchError)
            # Fail with a clear message rather than an IndexError if nothing
            # was logged.
            self.assertNotEqual(errors, [])
            self.assertEqual(errors[0].value.args, ('bad exec', None))
        channel.addCallback(cbFailed)
        return channel


    def test_falseChannel(self):
        """
        When the process started by a L{SSHChannel.sendRequest} exec request
        exits, the exit status is reported to the channel.
        """
        channel = self._ourServerOurClientTest()

        def cbChannel(channel):
            self.channel = channel
            return channel.conn.sendRequest(
                channel, 'exec', common.NS('false'), 1)
        channel.addCallback(cbChannel)

        def cbExec(ignored):
            return self.channel.onClose
        channel.addCallback(cbExec)

        def cbClosed(ignored):
            # No data is expected
            self.assertEqual(self.channel.received, [])
            self.assertNotEqual(self.channel.status, 0)
        channel.addCallback(cbClosed)
        return channel


    def test_errorChannel(self):
        """
        Bytes sent over the extended channel for stderr data are delivered to
        the channel's C{extReceived} method.
        """
        channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)

        def cbChannel(channel):
            self.channel = channel
            return channel.conn.sendRequest(
                channel, 'exec', common.NS('eecho hello'), 1)
        channel.addCallback(cbChannel)

        def cbExec(ignored):
            return defer.gatherResults([
                    self.channel.onClose,
                    self.realm.avatar._testSession.onClose])
        channel.addCallback(cbExec)

        def cbClosed(ignored):
            self.assertEqual(self.channel.received, [])
            self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
            self.assertEqual(self.channel.status, 0)
            self.assertTrue(self.channel.eofCalled)
            self.assertEqual(self.channel.localWindowLeft, 4)
            self.assertEqual(
                self.channel.localWindowLeft,
                self.realm.avatar._testSession.remoteWindowLeftAtClose)
        channel.addCallback(cbClosed)
        return channel


    def test_unknownChannel(self):
        """
        When an attempt is made to open an unknown channel type, the L{Deferred}
        returned by L{SSHChannel.sendRequest} fires its errback.
        """
        d = self.assertFailure(
            self._ourServerOurClientTest('crazy-unknown-channel'), Exception)
        def cbFailed(ignored):
            errors = self.flushLoggedErrors(error.ConchError)
            # Check the count first so a missing error fails cleanly instead
            # of raising IndexError on errors[0].
            self.assertEqual(len(errors), 1)
            self.assertEqual(errors[0].value.args, (3, 'unknown channel'))
        d.addCallback(cbFailed)
        return d


    def test_maxPacket(self):
        """
        An L{SSHChannel} can be configured with a maximum packet size to
        receive.
        """
        # localWindow needs to be at least 11 otherwise the assertion about it
        # in cbClosed is invalid.
        channel = self._ourServerOurClientTest(
            localWindow=11, localMaxPacket=1)

        def cbChannel(channel):
            self.channel = channel
            return channel.conn.sendRequest(
                channel, 'exec', common.NS('secho hello'), 1)
        channel.addCallback(cbChannel)

        def cbExec(ignored):
            return self.channel.onClose
        channel.addCallback(cbExec)

        def cbClosed(ignored):
            self.assertEqual(self.channel.status, 0)
            self.assertEqual("".join(self.channel.received), "hello\r\n")
            self.assertEqual("".join(self.channel.receivedExt), "hello\r\n")
            self.assertEqual(self.channel.localWindowLeft, 11)
            self.assertTrue(self.channel.eofCalled)
        channel.addCallback(cbClosed)
        return channel


    def test_echo(self):
        """
        Normal standard out bytes are sent to the channel's C{dataReceived}
        method.
        """
        channel = self._ourServerOurClientTest(localWindow=4, localMaxPacket=5)

        def cbChannel(channel):
            self.channel = channel
            return channel.conn.sendRequest(
                channel, 'exec', common.NS('echo hello'), 1)
        channel.addCallback(cbChannel)

        def cbEcho(ignored):
            return defer.gatherResults([
                    self.channel.onClose,
                    self.realm.avatar._testSession.onClose])
        channel.addCallback(cbEcho)

        def cbClosed(ignored):
            self.assertEqual(self.channel.status, 0)
            self.assertEqual("".join(self.channel.received), "hello\r\n")
            self.assertEqual(self.channel.localWindowLeft, 4)
            self.assertTrue(self.channel.eofCalled)
            self.assertEqual(
                self.channel.localWindowLeft,
                self.realm.avatar._testSession.remoteWindowLeftAtClose)
        channel.addCallback(cbClosed)
        return channel
819
820
821
class TestSSHFactory(unittest.TestCase):
    """
    Tests for L{factory.SSHFactory.buildProtocol}.
    """

    if not Crypto:
        skip = "can't run w/o PyCrypto"

    if not pyasn1:
        skip = "Cannot run without PyASN1"

    def makeSSHFactory(self, primes=None):
        """
        Build and start an L{SSHFactory} whose public and private keys are a
        single C{'ssh-rsa'} entry of C{keys.Key(None)} and whose C{getPrimes}
        returns C{primes}.
        """
        sshFactory = factory.SSHFactory()
        gpk = lambda: {'ssh-rsa' : keys.Key(None)}
        sshFactory.getPrimes = lambda: primes
        sshFactory.getPublicKeys = sshFactory.getPrivateKeys = gpk
        sshFactory.startFactory()
        return sshFactory


    def test_buildProtocol(self):
        """
        By default, buildProtocol() constructs an instance of
        SSHServerTransport.
        """
        sshFactory = self.makeSSHFactory()
        proto = sshFactory.buildProtocol(None)
        self.assertIsInstance(proto, transport.SSHServerTransport)


    def test_buildProtocolRespectsProtocol(self):
        """
        buildProtocol() calls 'self.protocol()' to construct a protocol
        instance.
        """
        calls = []
        def makeProtocol(*args):
            calls.append(args)
            return transport.SSHServerTransport()
        sshFactory = self.makeSSHFactory()
        sshFactory.protocol = makeProtocol
        sshFactory.buildProtocol(None)
        self.assertEqual([()], calls)


    def test_multipleFactories(self):
        """
        A protocol built by a factory whose C{getPrimes} returns C{None} does
        not offer C{diffie-hellman-group-exchange-sha1}.  One built by a
        factory with primes does, even when both factories coexist.
        """
        f1 = self.makeSSHFactory(primes=None)
        f2 = self.makeSSHFactory(primes={1:(2,3)})
        p1 = f1.buildProtocol(None)
        p2 = f2.buildProtocol(None)
        self.assertNotIn('diffie-hellman-group-exchange-sha1',
                         p1.supportedKeyExchanges)
        self.assertIn('diffie-hellman-group-exchange-sha1',
                      p2.supportedKeyExchanges)
873
874
875
class MPTestCase(unittest.TestCase):
    """
    Tests for L{common.getMP}.

    @cvar getMP: a method providing a MP parser.
    @type getMP: C{callable}
    """
    getMP = staticmethod(common.getMP)

    if not Crypto:
        skip = "can't run w/o PyCrypto"

    if not pyasn1:
        skip = "Cannot run without PyASN1"


    def test_getMP(self):
        """
        L{common.getMP} should parse a multiple precision integer from a
        string: a 4-byte length followed by length bytes of the integer.
        """
        self.assertEqual(
            self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'),
            (1, ''))


    def test_getMPBigInteger(self):
        """
        L{common.getMP} should be able to parse an integer too big to fit in
        a single byte.
        """
        self.assertEqual(
            self.getMP('\x00\x00\x00\x04\x01\x02\x03\x04'),
            (16909060, ''))


    def test_multipleGetMP(self):
        """
        L{common.getMP} can parse several integers from the same string when
        given a count.
        """
        self.assertEqual(
            self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'
                       '\x00\x00\x00\x04\x00\x00\x00\x02', 2),
            (1, 2, ''))


    def test_getMPRemainingData(self):
        """
        When more data than needed is sent to L{common.getMP}, it should return
        the remaining data.
        """
        self.assertEqual(
            self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01foo'),
            (1, 'foo'))


    def test_notEnoughData(self):
        """
        When the string passed to L{common.getMP} is too short to hold even the
        4-byte length prefix, it should raise a L{struct.error}.
        """
        self.assertRaises(struct.error, self.getMP, '\x02\x00')
939
940
941
class PyMPTestCase(MPTestCase):
    """
    Run the L{MPTestCase} suite against the pure-Python
    L{common.getMP_py} parser.
    """
    getMP = staticmethod(common.getMP_py)
947
948
949
class GMPYMPTestCase(MPTestCase):
    """
    Run the L{MPTestCase} suite against the gmpy-backed
    L{common._fastgetMP} parser (skipped when gmpy is unavailable).
    """
    getMP = staticmethod(common._fastgetMP)
955
956
class BuiltinPowHackTestCase(unittest.TestCase):
    """
    Tests that the builtin pow method is still correct after
    L{twisted.conch.ssh.common} monkeypatches it to use gmpy.
    """

    def test_floatBase(self):
        """
        pow gives the correct result when passed a base of type float with a
        non-integer value.
        """
        self.assertEqual(6.25, pow(2.5, 2))

    def test_intBase(self):
        """
        pow gives the correct result when passed a base of type int.
        """
        self.assertEqual(81, pow(3, 4))

    def test_longBase(self):
        """
        pow gives the correct result when passed a base of type long.
        """
        # Previously this passed a plain int, duplicating test_intBase.
        self.assertEqual(81, pow(long(3), 4))

    def test_mpzBase(self):
        """
        pow gives the correct result when passed a base of type gmpy.mpz.
        """
        if gmpy is None:
            raise unittest.SkipTest('gmpy not available')
        self.assertEqual(81, pow(gmpy.mpz(3), 4))
989
990
try:
    import gmpy
except ImportError:
    # Without gmpy, BuiltinPowHackTestCase.test_mpzBase skips itself by
    # checking for None, and the gmpy-backed getMP tests are skipped.
    gmpy = None
    GMPYMPTestCase.skip = "gmpy not available"