Imported Upstream version 12.1.0
[contrib/python-twisted.git] / twisted / protocols / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Resource limiting policies.
7
8 @seealso: See also L{twisted.protocols.htb} for rate limiting.
9 """
10
11 # system imports
12 import sys, operator
13
14 from zope.interface import directlyProvides, providedBy
15
16 # twisted imports
17 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
18 from twisted.internet import error
19 from twisted.internet.interfaces import ILoggingContext
20 from twisted.python import log
21
22
def _wrappedLogPrefix(wrapper, wrapped):
    """
    Build the log prefix used by an object which wraps another one.

    The result names the wrapped object first and then, in parentheses, the
    class of the wrapper.  The wrapped object is named by its own
    C{logPrefix()} when it provides L{ILoggingContext}, and by its class name
    otherwise.

    @param wrapper: The wrapping object.
    @param wrapped: The object being wrapped.

    @rtype: C{str}
    """
    wrapperName = wrapper.__class__.__name__
    if not ILoggingContext.providedBy(wrapped):
        return "%s (%s)" % (wrapped.__class__.__name__, wrapperName)
    return "%s (%s)" % (wrapped.logPrefix(), wrapperName)
34
35
36
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    Events delivered by the real transport are relayed to C{wrappedProtocol}.
    C{wrappedProtocol} is connected to this wrapper as its transport, so the
    transport calls it makes land here and are relayed to the real
    C{transport}.

    @ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
        provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
        method calls onto this L{ProtocolWrapper} will be proxied.

    @ivar factory: The L{WrappingFactory} which created this
        L{ProtocolWrapper}.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol


    def logPrefix(self):
        """
        Return a log prefix naming the wrapped protocol followed by this
        wrapper's class.
        """
        return _wrappedLogPrefix(self, self.wrappedProtocol)


    def makeConnection(self, transport):
        """
        Attach this wrapper to C{transport}, then connect the wrapped protocol
        to this wrapper.

        The steps, in order:
          1. Declare every interface C{transport} provides.
          2. Run the normal L{Protocol.makeConnection}.
          3. Register with the factory.
          4. Hand this wrapper to the wrapped protocol as its transport, so
             its transport calls are intercepted here.
        """
        directlyProvides(self, providedBy(transport))
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)


    # Transport relaying: the wrapped protocol sees this object as its
    # transport; every such call is handed on to the real transport.

    def write(self, data):
        """
        Pass C{data} on to the real transport.
        """
        self.transport.write(data)


    def writeSequence(self, data):
        """
        Pass the sequence C{data} on to the real transport.
        """
        self.transport.writeSequence(data)


    def loseConnection(self):
        """
        Mark this wrapper as disconnecting and ask the real transport to
        close the connection.
        """
        self.disconnecting = 1
        self.transport.loseConnection()


    def getPeer(self):
        """
        Return the real transport's remote address.
        """
        return self.transport.getPeer()


    def getHost(self):
        """
        Return the real transport's local address.
        """
        return self.transport.getHost()


    def registerProducer(self, producer, streaming):
        """
        Register C{producer} with the real transport.
        """
        self.transport.registerProducer(producer, streaming)


    def unregisterProducer(self):
        """
        Unregister the producer from the real transport.
        """
        self.transport.unregisterProducer()


    def stopConsuming(self):
        """
        Relay C{stopConsuming} to the real transport.
        """
        self.transport.stopConsuming()


    def __getattr__(self, name):
        """
        Resolve any attribute not found on this wrapper by looking it up on
        the real transport.
        """
        return getattr(self.transport, name)


    # Protocol relaying: events from the real transport are handed on to
    # the wrapped protocol.

    def dataReceived(self, data):
        """
        Hand received C{data} to the wrapped protocol.
        """
        self.wrappedProtocol.dataReceived(data)


    def connectionLost(self, reason):
        """
        Unregister from the factory, then tell the wrapped protocol the
        connection is gone.
        """
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
124
125
126
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    Each protocol built by C{wrappedFactory} is wrapped in an instance of
    C{self.protocol}. Factory life-cycle and client-connection callbacks are
    forwarded to C{wrappedFactory}.

    @ivar wrappedFactory: The factory being wrapped.
    @ivar protocols: A C{dict} whose keys are the currently registered
        wrapper protocols (every value is C{1}).
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        self.protocols = {}


    def logPrefix(self):
        """
        Return a log prefix naming the wrapped factory followed by this
        factory's class.
        """
        return _wrappedLogPrefix(self, self.wrappedFactory)


    def doStart(self):
        """
        Start the wrapped factory first, then this one.
        """
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)


    def doStop(self):
        """
        Stop the wrapped factory first, then this one.
        """
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)


    def startedConnecting(self, connector):
        """
        Forward to the wrapped factory.
        """
        self.wrappedFactory.startedConnecting(connector)


    def clientConnectionFailed(self, connector, reason):
        """
        Forward to the wrapped factory.
        """
        self.wrappedFactory.clientConnectionFailed(connector, reason)


    def clientConnectionLost(self, connector, reason):
        """
        Forward to the wrapped factory.
        """
        self.wrappedFactory.clientConnectionLost(connector, reason)


    def buildProtocol(self, addr):
        """
        Build a protocol with the wrapped factory and wrap it in
        C{self.protocol}.
        """
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner)


    def registerProtocol(self, p):
        """
        Called by protocol to register itself.
        """
        self.protocols[p] = 1


    def unregisterProtocol(self, p):
        """
        Called by protocols when they go away.

        @raise KeyError: If C{p} is not currently registered.
        """
        del self.protocols[p]
184
185
186
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for L{ThrottlingFactory}.

    Reports every byte read and written to its factory. It also lets the
    factory pause and resume reading (through the transport) and writing
    (through the producer, if any).
    """

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        """
        Account for the total size of C{seq}, then forward it.

        C{sum} counts an empty sequence as 0 bytes. C{reduce} without an
        initial value would raise C{TypeError} for it.
        """
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)


    def registerProducer(self, producer, streaming):
        # Remember the producer so writes can be throttled through it.
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)


    def unregisterProducer(self):
        # Tolerate an unregister without a matching register; deleting a
        # missing instance attribute would otherwise raise AttributeError.
        try:
            del self.producer
        except AttributeError:
            pass
        ProtocolWrapper.unregisterProducer(self)


    def throttleReads(self):
        """
        Stop reading from the transport.
        """
        self.transport.pauseProducing()


    def unthrottleReads(self):
        """
        Resume reading from the transport.
        """
        self.transport.resumeProducing()


    def throttleWrites(self):
        """
        Pause the producer, if there is one.
        """
        # When no producer was registered on this wrapper, the lookup falls
        # through ProtocolWrapper.__getattr__ to the transport. That lookup
        # may find a producer attribute set to None, hence the explicit check.
        producer = getattr(self, "producer", None)
        if producer is not None:
            producer.pauseProducing()


    def unthrottleWrites(self):
        """
        Resume the producer, if there is one.
        """
        producer = getattr(self, "producer", None)
        if producer is not None:
            producer.resumeProducing()
226
227
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    While at least one connection exists, a check runs once per second for
    each configured limit. If more than C{readLimit}/C{writeLimit} bytes were
    transferred in the last second, all protocols are paused. The pause lasts
    C{bytes / limit - 1} seconds, which brings the average rate back to the
    limit.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Pending delayed calls for the throttling timers, or None.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)


    def _cancelDelayedCall(self, name):
        """
        Cancel the delayed call stored in the attribute C{name}, if any, and
        reset that attribute to C{None}.

        Calls that already ran or were already cancelled are ignored.
        Resetting the attribute means a later cancellation never acts on a
        stale call.

        @param name: Name of one of the C{*ID} attributes.
        """
        call = getattr(self, name)
        if call is not None:
            try:
                call.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                pass
            setattr(self, name, None)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed the read bandwidth limit during the last
        second, throttling reads if so, then schedules the next check.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Checks if we've passed the write bandwidth limit during the last
        second, throttling writes if so, then schedules the next check.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = self.callLater(throttleTime,
                                                     self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        # Iterate over a snapshot in case the registry changes meanwhile.
        for p in list(self.protocols):
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in list(self.protocols):
            p.unthrottleReads()


    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in list(self.protocols):
            p.throttleWrites()


    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in list(self.protocols):
            p.unthrottleWrites()


    def buildProtocol(self, addr):
        """
        Start the bandwidth checks with the first connection, then build a
        wrapped protocol unless C{maxConnectionCount} is reached, in which
        case C{None} is returned.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None


    def unregisterProtocol(self, p):
        """
        Forget C{p}. When it was the last connection, stop every pending
        throttling timer.
        """
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            # The IDs must be reset, not just cancelled. Otherwise an
            # already-cancelled unthrottle call can survive into the next
            # batch of connections, and cancelling it again would raise
            # AlreadyCancelled.
            for name in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                self._cancelDelayedCall(name)
364
365
366
class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper which logs, with L{log.msg}, every chunk of data it
    receives and every chunk written through it.
    """

    def dataReceived(self, data):
        """
        Log C{data}, then relay it to the wrapped protocol.
        """
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)


    def write(self, data):
        """
        Log C{data}, then relay it to the real transport.
        """
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
375
376
377
class SpewingFactory(WrappingFactory):
    """
    Wrapping factory whose protocols are L{SpewingProtocol} instances, which
    log all received and sent data.
    """

    protocol = SpewingProtocol
380
381
382
class LimitConnectionsByPeer(WrappingFactory):
    """
    Wrapping factory which refuses new connections from a peer host that
    already has C{maxConnectionsPerPeer} connections open.

    @cvar maxConnectionsPerPeer: Maximum number of simultaneous connections
        accepted from a single host.
    @ivar peerConnections: A C{dict} mapping peer host to the number of its
        connections currently counted; created by L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}


    @staticmethod
    def _peerHost(address):
        """
        Return the host part of C{address}.

        Uses the C{host} attribute when C{address} has one. Otherwise falls
        back to the first item, for C{(host, port)} tuples.
        """
        host = getattr(address, "host", None)
        if host is None:
            host = address[0]
        return host


    def buildProtocol(self, addr):
        """
        Build a wrapped protocol for C{addr}. Return C{None} when that peer
        host already has C{maxConnectionsPerPeer} connections.
        """
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        wrapper = WrappingFactory.buildProtocol(self, addr)
        # Remember exactly which key was incremented, so unregisterProtocol
        # decrements the same one even if getPeer() is formatted differently.
        wrapper._limitConnectionsPeerHost = peerHost
        return wrapper


    def unregisterProtocol(self, p):
        """
        Release C{p}'s slot in its peer's connection count. Also unregister
        it from the base factory, so C{self.protocols} does not keep it.
        """
        peerHost = p._limitConnectionsPeerHost
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]
        WrappingFactory.unregisterProtocol(self, p)
403
404
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    While fewer than C{connectionLimit} connections are open (or no limit is
    set), new connections get C{self.protocol}. Past the limit they get
    C{overflowProtocol}, or are refused when that is C{None}. Every accepted
    protocol is wrapped in a L{ProtocolWrapper} so this factory hears about
    disconnections.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        withinLimit = (self.connectionLimit is None
                       or self.connectionCount < self.connectionLimit)
        if withinLimit:
            # Build the normal protocol
            wrappedProtocol = self.protocol()
        elif self.overflowProtocol is not None:
            # Too many connections, so build the overflow protocol
            wrappedProtocol = self.overflowProtocol()
        else:
            # Just drop the connection
            return None

        wrappedProtocol.factory = self
        wrapper = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return wrapper


    def registerProtocol(self, p):
        """
        Nothing to record: connections are counted in L{buildProtocol}.
        """


    def unregisterProtocol(self, p):
        """
        A wrapped connection went away; release its slot.
        """
        self.connectionCount -= 1
444
445
446
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Any read or write restarts the countdown. When it expires,
    L{timeoutFunc} is called.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}. Its C{callLater} method is used to
            schedule the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except error.AlreadyCalled:
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        If the timeout has already fired, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except error.AlreadyCalled:
                # timeoutFunc has already run. By default it has begun closing
                # the connection, but the wrapped protocol may still write
                # before connectionLost arrives. That must not raise.
                pass


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
531
532
533
class TimeoutFactory(WrappingFactory):
    """
    Factory wrapping each protocol of C{wrappedFactory} in a
    L{TimeoutProtocol}, which disconnects idle connections.

    @ivar timeoutPeriod: Seconds of inactivity allowed before a connection
        times out; passed to every L{TimeoutProtocol} built.
    """
    protocol = TimeoutProtocol


    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)


    def buildProtocol(self, addr):
        """
        Wrap the wrapped factory's protocol in C{self.protocol}, configured
        with this factory's C{timeoutPeriod}.
        """
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner, timeoutPeriod=self.timeoutPeriod)


    def callLater(self, period, func):
        """
        Schedule C{func} in C{period} seconds on the global reactor.

        This is a separate method so tests can substitute their own clock.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)
557
558
559
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper which writes one line to C{logfile} for each event and
    transport call passing through it.

    Line formats:
      - C{*}: connection made.
      - C{C <n>: ...}: data received, or the reason the connection was lost.
      - C{S <n>: ...}: data written.
      - C{SV <n>: ...}: a C{writeSequence} call.
      - C{S <n>: *}: a C{loseConnection} call.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number


    def _log(self, line):
        """
        Write C{line} plus a newline to the log file and flush it.
        """
        self.logfile.write(line + '\n')
        self.logfile.flush()


    def _mungeData(self, data):
        """
        Truncate C{data} for logging when it is longer than C{lengthLimit}.

        Truncated data keeps its first C{lengthLimit - 12} characters, followed
        by the 12-character marker C{'<... elided>'}. The kept length is
        clamped at zero. Otherwise a C{lengthLimit} below 12 would become a
        negative slice bound and keep almost all of the data.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            keep = max(self.lengthLimit - 12, 0)
            data = data[:keep] + '<... elided>'
        return data


    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)


    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        # NOTE(review): the logfile is not closed here, so it stays open
        # until its file object is released; confirm whether callers rely
        # on reading it after disconnection before adding a close().
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)


    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)


    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)


    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
623
624
625
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory which logs each connection's traffic to its own file.

    The file for a connection is C{logfilePrefix + '-' + str(N)}, where C{N}
    is a per-factory counter incremented for every protocol built.
    """
    protocol = TrafficLoggingProtocol

    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)


    def open(self, name):
        """
        Open the log file C{name} for writing, truncating it.

        Uses the C{open} builtin instead of the Python-2-only C{file}
        builtin. Override this to send logs elsewhere.
        """
        return open(name, 'w')


    def buildProtocol(self, addr):
        """
        Open a fresh log file and wrap the wrapped factory's protocol in a
        logging protocol numbered with the incremented counter.
        """
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             logfile, self.lengthLimit, self._counter)


    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
653
654
655
class TimeoutMixin:
    """
    Mixin for protocols which wish to timeout connections.

    A protocol using this mixin has at most one pending timeout.
      - L{setTimeout} configures or disables it.
      - L{resetTimeout} pushes it back.
      - On expiry, L{timeoutConnection} is called; by default it closes the
        transport.

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """
    timeOut = None

    __timeoutCall = None

    def callLater(self, period, func):
        """
        Schedule C{func} after C{period} seconds on the global reactor.

        This is a separate method so that tests can substitute a fake clock.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)


    def resetTimeout(self):
        """
        Restart the countdown from the full C{timeOut} period.

        Does nothing when no timeout is pending. That is the case when it has
        already fired, or has been disabled with C{setTimeout(None)}.

        Call this whenever meaningful input arrives from the other end of the
        connection, to signal that the peer is still alive.
        """
        pending = self.__timeoutCall
        if pending is None or self.timeOut is None:
            return
        pending.reset(self.timeOut)


    def setTimeout(self, period):
        """
        Change the timeout period.

        @type period: C{int} or C{NoneType}
        @param period: New period in seconds, or C{None} to disable the
            timeout.

        @return: The previous value of C{timeOut}.
        """
        previous, self.timeOut = self.timeOut, period

        pending = self.__timeoutCall
        if pending is None:
            # Nothing scheduled yet: start a countdown only if enabled.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            pending.cancel()
            self.__timeoutCall = None
        else:
            pending.reset(period)

        return previous


    def __timedOut(self):
        # The scheduled call has fired; forget it before notifying, so later
        # setTimeout/resetTimeout calls see no pending timeout.
        self.__timeoutCall = None
        self.timeoutConnection()


    def timeoutConnection(self):
        """
        Called when the connection times out.

        The default drops the connection through C{self.transport}; override
        for different behavior.
        """
        self.transport.loseConnection()