Imported Upstream version 12.1.0
[contrib/python-twisted.git] / twisted / conch / ssh / connection.py
1 # -*- test-case-name: twisted.conch.test.test_connection -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 This module contains the implementation of the ssh-connection service, which
7 allows access to the shell and port-forwarding.
8
9 Maintainer: Paul Swartz
10 """
11
12 import struct
13
14 from twisted.conch.ssh import service, common
15 from twisted.conch import error
16 from twisted.internet import defer
17 from twisted.python import log
18
class SSHConnection(service.SSHService):
    """
    An implementation of the 'ssh-connection' service.  It is used to
    multiplex multiple channels over the single SSH connection.

    @ivar localChannelID: the next number to use as a local channel ID.
    @type localChannelID: C{int}
    @ivar channels: a C{dict} mapping a local channel ID to C{SSHChannel}
        subclasses.
    @type channels: C{dict}
    @ivar localToRemoteChannel: a C{dict} mapping a local channel ID to a
        remote channel ID.
    @type localToRemoteChannel: C{dict}
    @ivar channelsToRemoteChannel: a C{dict} mapping a C{SSHChannel} subclass
        to remote channel ID.
    @type channelsToRemoteChannel: C{dict}
    @ivar deferreds: a C{dict} mapping a local channel ID to a C{list} of
        C{Deferreds} for outstanding channel requests.  Also, the 'global'
        key stores the C{list} of pending global request C{Deferred}s.  A
        channel's entry is removed when the channel is closed.
    @type deferreds: C{dict}
    @ivar transport: the transport this service runs over.  It is C{None}
        after construction and is assigned from outside this class.
    """
    name = 'ssh-connection'

    def __init__(self):
        """
        Set up empty channel bookkeeping.  No channel exists yet and the
        transport is not known until it is assigned later.
        """
        # The next local channel ID to hand out.
        self.localChannelID = 0
        # local channel ID -> subclass of SSHChannel
        self.channels = {}
        # local channel ID -> remote channel ID
        self.localToRemoteChannel = {}
        # subclass of SSHChannel -> remote channel ID
        self.channelsToRemoteChannel = {}
        # local channel ID -> list of Deferreds for pending channel requests;
        # the 'global' key holds the Deferreds for pending global requests.
        self.deferreds = {"global": []}
        # Gets set later.
        self.transport = None


    def serviceStarted(self):
        """
        If the transport has an C{avatar} attribute (i.e. we are a server),
        store this connection on the avatar as its C{conn} attribute.
        """
        if hasattr(self.transport, 'avatar'):
            self.transport.avatar.conn = self


    def serviceStopped(self):
        """
        Called when the connection is stopped.

        Every channel registered in C{self.channels} is passed to
        L{channelClosed}, then all pending global request C{Deferred}s are
        errbacked.
        """
        # Walk a snapshot: channelClosed() deletes entries from self.channels
        # as it goes.  An explicit loop is used instead of map() because map()
        # is only a side-effect loop where it is evaluated eagerly.
        for channel in list(self.channels.values()):
            self.channelClosed(channel)
        self._cleanupGlobalDeferreds()


    def _cleanupGlobalDeferreds(self):
        """
        All pending requests that have returned a deferred must be errbacked
        when this service is stopped, otherwise they might be left uncalled and
        uncallable.
        """
        pending = self.deferreds["global"]
        for d in pending:
            d.errback(error.ConchError("Connection stopped."))
        # Empty the list in place so that other references to it see the
        # change as well.
        del pending[:]


    # packet methods
    def ssh_GLOBAL_REQUEST(self, packet):
        """
        The other side has made a global request.  Payload::
            string  request type
            bool    want reply
            <request specific data>

        This dispatches to self.gotGlobalRequest.  If a reply is wanted, a
        false result is answered with MSG_REQUEST_FAILURE and a true result
        with MSG_REQUEST_SUCCESS; when the true result is a C{tuple} or
        C{list}, its second element is sent as the reply data.
        """
        requestType, rest = common.getNS(packet)
        wantReply = ord(rest[0])
        ret = self.gotGlobalRequest(requestType, rest[1:])
        if not wantReply:
            return
        if not ret:
            self.transport.sendPacket(MSG_REQUEST_FAILURE, '')
            return
        data = ''
        if isinstance(ret, (tuple, list)):
            data = ret[1]
        self.transport.sendPacket(MSG_REQUEST_SUCCESS, data)

    def ssh_REQUEST_SUCCESS(self, packet):
        """
        Our global request succeeded.  Get the appropriate Deferred and call
        it back with the packet we received.
        """
        log.msg('RS')
        # Replies are matched to requests in the order the requests were
        # queued, so the oldest pending Deferred is the one to fire.
        d = self.deferreds['global'].pop(0)
        d.callback(packet)

    def ssh_REQUEST_FAILURE(self, packet):
        """
        Our global request failed.  Get the appropriate Deferred and errback
        it with the packet we received.
        """
        log.msg('RF')
        d = self.deferreds['global'].pop(0)
        d.errback(error.ConchError('global request failed', packet))

    def ssh_CHANNEL_OPEN(self, packet):
        """
        The other side wants to get a channel.  Payload::
            string  channel name
            uint32  remote channel number
            uint32  remote window size
            uint32  remote maximum packet size
            <channel specific data>

        We get a channel from self.getChannel(), give it a local channel number
        and notify the other side.  Then notify the channel by calling its
        channelOpen method.

        If anything in that sequence raises, the exception is logged and a
        MSG_CHANNEL_OPEN_FAILURE is sent instead.  A C{ConchError} supplies
        the reason code and description; any other exception is reported as
        OPEN_CONNECT_FAILED / "unknown failure".
        """
        channelType, rest = common.getNS(packet)
        senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
        channelData = rest[12:]
        try:
            channel = self.getChannel(channelType, windowSize, maxPacket,
                                      channelData)
            localChannel = self.localChannelID
            self.localChannelID += 1
            channel.id = localChannel
            self.channels[localChannel] = channel
            self.channelsToRemoteChannel[channel] = senderChannel
            self.localToRemoteChannel[localChannel] = senderChannel
            confirmation = struct.pack('>4L', senderChannel, localChannel,
                                       channel.localWindowSize,
                                       channel.localMaxPacket)
            self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
                                      confirmation + channel.specificData)
            log.callWithLogger(channel, channel.channelOpen, channelData)
        except Exception:
            # Fetch the exception through sys.exc_info() rather than binding
            # it in the except clause: the "except X, e" spelling only parses
            # on Python 2, while this form parses on every version.
            import sys
            e = sys.exc_info()[1]
            log.msg('channel open failed')
            log.err(e)
            if isinstance(e, error.ConchError):
                textualInfo, reason = e.args
                if isinstance(textualInfo, (int, long)):
                    # See #3657 and #3071
                    # The arguments were given as (code, text); swap them.
                    textualInfo, reason = reason, textualInfo
            else:
                reason = OPEN_CONNECT_FAILED
                textualInfo = "unknown failure"
            self.transport.sendPacket(
                MSG_CHANNEL_OPEN_FAILURE,
                struct.pack('>2L', senderChannel, reason) +
                common.NS(textualInfo) + common.NS(''))

    def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
        """
        The other side accepted our MSG_CHANNEL_OPEN request.  Payload::
            uint32  local channel number
            uint32  remote channel number
            uint32  remote window size
            uint32  remote maximum packet size
            <channel specific data>

        Find the channel using the local channel number and notify its
        channelOpen method.
        """
        header, specificData = packet[:16], packet[16:]
        localChannel, remoteChannel, windowSize, maxPacket = struct.unpack(
            '>4L', header)
        channel = self.channels[localChannel]
        channel.conn = self
        # Only now do we know the peer's number for this channel.
        self.localToRemoteChannel[localChannel] = remoteChannel
        self.channelsToRemoteChannel[channel] = remoteChannel
        channel.remoteWindowLeft = windowSize
        channel.remoteMaxPacket = maxPacket
        log.callWithLogger(channel, channel.channelOpen, specificData)

    def ssh_CHANNEL_OPEN_FAILURE(self, packet):
        """
        The other side did not accept our MSG_CHANNEL_OPEN request.  Payload::
            uint32  local channel number
            uint32  reason code
            string  reason description

        Find the channel using the local channel number and notify it by
        calling its openFailed() method.  The channel is removed from
        C{self.channels}.
        """
        localChannel, reasonCode = struct.unpack('>2L', packet[:8])
        reasonDesc = common.getNS(packet[8:])[0]
        channel = self.channels.pop(localChannel)
        channel.conn = self
        reason = error.ConchError(reasonDesc, reasonCode)
        log.callWithLogger(channel, channel.openFailed, reason)

    def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
        """
        The other side is adding bytes to its window.  Payload::
            uint32  local channel number
            uint32  bytes to add

        Call the channel's addWindowBytes() method to add new bytes to the
        remote window.
        """
        localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)

    def ssh_CHANNEL_DATA(self, packet):
        """
        The other side is sending us data.  Payload::
            uint32 local channel number
            string data

        Check to make sure the other side hasn't sent too much data (more
        than what's in the window, or more than the maximum packet size).  If
        they have, close the channel.  Otherwise, decrease the available
        window and pass the data to the channel's dataReceived().
        """
        # The second uint32 is the length prefix of the data string.
        localChannel, dataLength = struct.unpack('>2L', packet[:8])
        channel = self.channels[localChannel]
        # XXX should this move to dataReceived to put client in charge?
        if (dataLength > channel.localWindowLeft or
                dataLength > channel.localMaxPacket): # more data than we want
            log.callWithLogger(channel, log.msg, 'too much data')
            self.sendClose(channel)
            return
        data = common.getNS(packet[4:])[0]
        channel.localWindowLeft -= dataLength
        # Top the window back up to its full size once less than half is left.
        if channel.localWindowLeft < channel.localWindowSize // 2:
            self.adjustWindow(channel, channel.localWindowSize -
                                       channel.localWindowLeft)
        log.callWithLogger(channel, channel.dataReceived, data)

    def ssh_CHANNEL_EXTENDED_DATA(self, packet):
        """
        The other side is sending us extended data.  Payload::
            uint32  local channel number
            uint32  type code
            string  data

        Check to make sure the other side hasn't sent too much data (more
        than what's in the window, or more than the maximum packet size).  If
        they have, close the channel.  Otherwise, decrease the available
        window and pass the data and type code to the channel's
        extReceived().
        """
        # The third uint32 is the length prefix of the data string.
        localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
        channel = self.channels[localChannel]
        if (dataLength > channel.localWindowLeft or
                dataLength > channel.localMaxPacket):
            log.callWithLogger(channel, log.msg, 'too much extdata')
            self.sendClose(channel)
            return
        data = common.getNS(packet[8:])[0]
        channel.localWindowLeft -= dataLength
        # Top the window back up to its full size once less than half is left.
        if channel.localWindowLeft < channel.localWindowSize // 2:
            self.adjustWindow(channel, channel.localWindowSize -
                                       channel.localWindowLeft)
        log.callWithLogger(channel, channel.extReceived, typeCode, data)

    def ssh_CHANNEL_EOF(self, packet):
        """
        The other side is not sending any more data.  Payload::
            uint32  local channel number

        Notify the channel by calling its eofReceived() method.
        """
        localChannel, = struct.unpack('>L', packet[:4])
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.eofReceived)

    def ssh_CHANNEL_CLOSE(self, packet):
        """
        The other side is closing its end; it does not want to receive any
        more data.  Payload::
            uint32  local channel number

        Notify the channel by calling its closeReceived() method.  If
        the channel has also sent a close message, call self.channelClosed().
        """
        localChannel, = struct.unpack('>L', packet[:4])
        channel = self.channels[localChannel]
        log.callWithLogger(channel, channel.closeReceived)
        channel.remoteClosed = True
        if channel.localClosed and channel.remoteClosed:
            self.channelClosed(channel)

    def ssh_CHANNEL_REQUEST(self, packet):
        """
        The other side is sending a request to a channel.  Payload::
            uint32  local channel number
            string  request name
            bool    want reply
            <request specific data>

        Pass the message to the channel's requestReceived method.  If the
        other side wants a reply, add callbacks which will send the
        reply.

        @return: the C{Deferred} that sends the reply if one was wanted,
            otherwise C{None}.
        """
        localChannel, = struct.unpack('>L', packet[:4])
        requestType, rest = common.getNS(packet[4:])
        wantReply = ord(rest[0])
        requestData = rest[1:]
        channel = self.channels[localChannel]
        d = defer.maybeDeferred(log.callWithLogger, channel,
                                channel.requestReceived, requestType,
                                requestData)
        if not wantReply:
            return None
        d.addCallback(self._cbChannelRequest, localChannel)
        d.addErrback(self._ebChannelRequest, localChannel)
        return d

    def _cbChannelRequest(self, result, localChannel):
        """
        Called back if the other side wanted a reply to a channel request.  If
        the result is true, send a MSG_CHANNEL_SUCCESS.  Otherwise, raise
        a C{error.ConchError}

        @param result: the value returned from the channel's requestReceived()
            method.  If it's False, the request failed.
        @type result: C{bool}
        @param localChannel: the local channel ID of the channel to which the
            request was made.
        @type localChannel: C{int}
        @raises ConchError: if the result is False.
        """
        if not result:
            raise error.ConchError('failed request')
        remoteChannel = self.localToRemoteChannel[localChannel]
        self.transport.sendPacket(MSG_CHANNEL_SUCCESS,
                                  struct.pack('>L', remoteChannel))

    def _ebChannelRequest(self, result, localChannel):
        """
        Called if the other side wanted a reply to the channel request and
        the channel request failed.  Sends a MSG_CHANNEL_FAILURE.

        @param result: a Failure, but it's not used.
        @param localChannel: the local channel ID of the channel to which the
            request was made.
        @type localChannel: C{int}
        """
        remoteChannel = self.localToRemoteChannel[localChannel]
        self.transport.sendPacket(MSG_CHANNEL_FAILURE,
                                  struct.pack('>L', remoteChannel))

    def ssh_CHANNEL_SUCCESS(self, packet):
        """
        Our channel request to the other side succeeded.  Payload::
            uint32  local channel number

        Get the C{Deferred} out of self.deferreds and call it back.  Nothing
        happens if no request is pending for that channel.
        """
        localChannel, = struct.unpack('>L', packet[:4])
        pending = self.deferreds.get(localChannel)
        if pending:
            d = pending.pop(0)
            log.callWithLogger(self.channels[localChannel],
                               d.callback, '')

    def ssh_CHANNEL_FAILURE(self, packet):
        """
        Our channel request to the other side failed.  Payload::
            uint32  local channel number

        Get the C{Deferred} out of self.deferreds and errback it with a
        C{error.ConchError}.  Nothing happens if no request is pending for
        that channel.
        """
        localChannel, = struct.unpack('>L', packet[:4])
        pending = self.deferreds.get(localChannel)
        if pending:
            d = pending.pop(0)
            log.callWithLogger(self.channels[localChannel],
                               d.errback,
                               error.ConchError('channel request failed'))

    # methods for users of the connection to call

    def sendGlobalRequest(self, request, data, wantReply=0):
        """
        Send a global request for this connection.  Currently this is only
        used for remote->local TCP forwarding.

        @type request:      C{str}
        @type data:         C{str}
        @type wantReply:    C{bool}
        @rtype              C{Deferred}/C{None}
        @return: a C{Deferred} that fires with the reply if C{wantReply} is
            true, otherwise C{None}.
        """
        if wantReply:
            replyFlag = '\xff'
        else:
            replyFlag = '\x00'
        self.transport.sendPacket(MSG_GLOBAL_REQUEST,
                                  common.NS(request) + replyFlag + data)
        if not wantReply:
            return None
        d = defer.Deferred()
        self.deferreds['global'].append(d)
        return d

    def openChannel(self, channel, extra=''):
        """
        Open a new channel on this connection.  The channel is given the
        next local channel ID and registered in C{self.channels}.

        @type channel:  subclass of C{SSHChannel}
        @type extra:    C{str}
        """
        localChannel = self.localChannelID
        log.msg('opening channel %s with %s %s'%(localChannel,
                channel.localWindowSize, channel.localMaxPacket))
        self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
                    + struct.pack('>3L', localChannel,
                    channel.localWindowSize, channel.localMaxPacket)
                    + extra)
        channel.id = localChannel
        self.channels[localChannel] = channel
        self.localChannelID += 1

    def sendRequest(self, channel, requestType, data, wantReply=0):
        """
        Send a request to a channel.  Nothing is sent if we have already
        closed the channel.

        @type channel:      subclass of C{SSHChannel}
        @type requestType:  C{str}
        @type data:         C{str}
        @type wantReply:    C{bool}
        @rtype              C{Deferred}/C{None}
        @return: a C{Deferred} that fires with the reply if C{wantReply} is
            true and the request was sent, otherwise C{None}.
        """
        if channel.localClosed:
            return
        log.msg('sending request %s' % requestType)
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_REQUEST,
                                  struct.pack('>L', remoteChannel)
                                  + common.NS(requestType) + chr(wantReply)
                                  + data)
        if not wantReply:
            return None
        d = defer.Deferred()
        self.deferreds.setdefault(channel.id, []).append(d)
        return d

    def adjustWindow(self, channel, bytesToAdd):
        """
        Tell the other side that we will receive more data.  This should not
        normally need to be called as it is managed automatically.

        @type channel:      subclass of L{SSHChannel}
        @type bytesToAdd:   C{int}
        """
        if channel.localClosed:
            return # we're already closed
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST,
                                  struct.pack('>2L', remoteChannel,
                                              bytesToAdd))
        log.msg('adding %i to %i in channel %i' % (bytesToAdd,
            channel.localWindowLeft, channel.id))
        channel.localWindowLeft += bytesToAdd

    def sendData(self, channel, data):
        """
        Send data to a channel.  This should not normally be used: instead use
        channel.write(data) as it manages the window automatically.

        @type channel:  subclass of L{SSHChannel}
        @type data:     C{str}
        """
        if channel.localClosed:
            return # we're already closed
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_DATA,
                                  struct.pack('>L', remoteChannel) +
                                  common.NS(data))

    def sendExtendedData(self, channel, dataType, data):
        """
        Send extended data to a channel.  This should not normally be used:
        instead use channel.writeExtendedData(data, dataType) as it manages
        the window automatically.

        @type channel:  subclass of L{SSHChannel}
        @type dataType: C{int}
        @type data:     C{str}
        """
        if channel.localClosed:
            return # we're already closed
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA,
                                  struct.pack('>2L', remoteChannel, dataType)
                                  + common.NS(data))

    def sendEOF(self, channel):
        """
        Send an EOF (End of File) for a channel.

        @type channel:  subclass of L{SSHChannel}
        """
        if channel.localClosed:
            return # we're already closed
        log.msg('sending eof')
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_EOF,
                                  struct.pack('>L', remoteChannel))

    def sendClose(self, channel):
        """
        Close a channel.  If the other side has already closed it too, the
        channel is finished off with self.channelClosed().

        @type channel:  subclass of L{SSHChannel}
        """
        if channel.localClosed:
            return # we're already closed
        log.msg('sending close %i' % channel.id)
        remoteChannel = self.channelsToRemoteChannel[channel]
        self.transport.sendPacket(MSG_CHANNEL_CLOSE,
                                  struct.pack('>L', remoteChannel))
        channel.localClosed = True
        if channel.localClosed and channel.remoteClosed:
            self.channelClosed(channel)

    # methods to override
    def getChannel(self, channelType, windowSize, maxPacket, data):
        """
        The other side requested a channel of some sort.
        channelType is the type of channel being requested,
        windowSize is the initial size of the remote window,
        maxPacket is the largest packet we should send,
        data is any other packet data (often nothing).

        We return a subclass of L{SSHChannel}.

        If the transport has an avatar (we are a server), the avatar's
        lookupChannel() is asked for the channel.  Otherwise this dispatches
        to a method 'channel_channelType' with any non-alphanumerics in the
        channelType replaced with _'s.  If no channel is produced, an
        OPEN_UNKNOWN_CHANNEL_TYPE error is raised.
        The method is called with arguments of windowSize, maxPacket, data.

        @type channelType:  C{str}
        @type windowSize:   C{int}
        @type maxPacket:    C{int}
        @type data:         C{str}
        @rtype:             subclass of L{SSHChannel}/C{tuple}
        @raises ConchError: if no channel could be found for C{channelType}.
        """
        log.msg('got channel %s request' % channelType)
        if hasattr(self.transport, "avatar"): # this is a server!
            chan = self.transport.avatar.lookupChannel(channelType,
                                                       windowSize,
                                                       maxPacket,
                                                       data)
        else:
            channelType = channelType.translate(TRANSLATE_TABLE)
            factory = getattr(self, 'channel_%s' % channelType, None)
            chan = None
            if factory is not None:
                chan = factory(windowSize, maxPacket, data)
        if chan is None:
            raise error.ConchError('unknown channel',
                    OPEN_UNKNOWN_CHANNEL_TYPE)
        chan.conn = self
        return chan

    def gotGlobalRequest(self, requestType, data):
        """
        We got a global request.  pretty much, this is just used by the client
        to request that we forward a port from the server to the client.
        Returns either:
            - 1: request accepted
            - 1, <data>: request accepted with request specific data
            - 0: request denied

        If the transport has an avatar (we are a server), the request is
        handed to the avatar's gotGlobalRequest().  Otherwise this dispatches
        to a method 'global_requestType' with -'s in requestType replaced
        with _'s.  The found method is passed data.
        If this method cannot be found, this method returns 0.  Otherwise, it
        returns the return value of that method.

        @type requestType:  C{str}
        @type data:         C{str}
        @rtype:             C{int}/C{tuple}
        """
        log.msg('got global %s request' % requestType)
        if hasattr(self.transport, 'avatar'): # this is a server!
            return self.transport.avatar.gotGlobalRequest(requestType, data)

        handler = getattr(self, 'global_%s' % requestType.replace('-', '_'),
                          None)
        if not handler:
            return 0
        return handler(data)

    def channelClosed(self, channel):
        """
        Called when a channel is closed.
        It clears the local state related to the channel, errbacks any
        channel requests that are still waiting for a reply, and calls
        channel.closed().
        MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
        If you don't, things will break mysteriously.

        Channels that are not in C{self.channelsToRemoteChannel} are ignored.

        @type channel: L{SSHChannel}
        """
        if channel not in self.channelsToRemoteChannel: # not actually open
            return
        channel.localClosed = channel.remoteClosed = True
        del self.localToRemoteChannel[channel.id]
        del self.channels[channel.id]
        del self.channelsToRemoteChannel[channel]
        # Remove the channel's entry altogether instead of leaving an empty
        # list behind: channel IDs are never reused, so leftover entries
        # would pile up for as long as the connection lives.
        for d in self.deferreds.pop(channel.id, []):
            d.errback(error.ConchError("Channel closed."))
        log.callWithLogger(channel, channel.closed)
605
# Message numbers of the ssh-connection protocol handled by SSHConnection.
# Connection-wide (global) requests and their replies:
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
# Per-channel messages:
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100

# Reason codes carried in a MSG_CHANNEL_OPEN_FAILURE packet.
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4

# Type code for extended data (presumably marks the data as stderr output,
# going by the name).
EXTENDED_DATA_STDERR = 1

# Reverse table: message number -> 'MSG_*' constant name.  The namespace is
# snapshotted first because binding the loop variables adds names to it.
messages = {}
for name, value in list(locals().items()):
    if name.startswith('MSG_'):
        messages[value] = name # doesn't handle doubles
632
import string

# The characters that may be kept in a channel type when it is turned into
# a method name.  ascii_letters is used rather than letters: letters changes
# with the process locale and does not exist on newer Python versions,
# whereas the set of characters kept here must always be the same.
alphanums = string.ascii_letters + string.digits


def _buildTranslateTable(keep):
    """
    Build a 256-character table for C{str.translate} that maps every
    character in C{keep} to itself and every other character to C{'_'}.

    @param keep: the characters to leave untouched.
    @type keep: C{str}
    @rtype: C{str}
    """
    table = []
    for code in range(256):
        char = chr(code)
        if char not in keep:
            char = '_'
        table.append(char)
    return ''.join(table)


# Used by SSHConnection.getChannel to replace non-alphanumerics with '_'.
TRANSLATE_TABLE = _buildTranslateTable(alphanums)
# Attach the message number -> 'MSG_*' name table built above to the class.
# NOTE(review): presumably the service base class / transport uses this to
# route an incoming message number to the matching ssh_* method -- confirm
# against service.SSHService.
SSHConnection.protocolMessages = messages