1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
6 Resource limiting policies.
8 @seealso: See also L{twisted.protocols.htb} for rate limiting.
14 from zope.interface import directlyProvides, providedBy
17 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
18 from twisted.internet import error
19 from twisted.internet.interfaces import ILoggingContext
20 from twisted.python import log
def _wrappedLogPrefix(wrapper, wrapped):
    """
    Build the log prefix shown for a wrapper object.

    @param wrapper: The wrapping object; its class name is appended in
        parentheses.
    @param wrapped: The object being wrapped.  If it provides
        L{ILoggingContext}, its own C{logPrefix()} is used as the leading
        part; otherwise its class name is used.

    @return: A string of the form C{"<wrapped prefix> (<wrapper class>)"}.
    @rtype: C{str}
    """
    innerPrefix = (wrapped.logPrefix()
                   if ILoggingContext.providedBy(wrapped)
                   else wrapped.__class__.__name__)
    outerName = wrapper.__class__.__name__
    return "%s (%s)" % (innerPrefix, outerName)
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    Events arriving from the real transport are forwarded to
    C{wrappedProtocol}.  Transport calls made by C{wrappedProtocol} land on
    this object and are forwarded to the real transport.  Any attribute not
    defined here is looked up on the real transport (see C{__getattr__}).

    @ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
        provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
        method calls onto this L{ProtocolWrapper} will be proxied.

    @ivar factory: The factory (normally a L{WrappingFactory}) which created
        this L{ProtocolWrapper}; it is notified through C{registerProtocol}
        and C{unregisterProtocol}.
    """

    # Flipped to 1 by loseConnection().
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol


    def logPrefix(self):
        """
        Return a log prefix naming both the wrapped protocol and this
        wrapper.
        """
        return _wrappedLogPrefix(self, self.wrappedProtocol)


    def makeConnection(self, transport):
        """
        Hook this wrapper up to the real transport.

        The wrapper is registered with its factory, and the wrapped protocol
        is connected to this wrapper rather than to the real transport, so
        every transport call it makes passes through here.
        """
        # Advertise whatever interfaces the real transport provides, so
        # providedBy() checks against the wrapper behave like the transport.
        transportInterfaces = providedBy(transport)
        directlyProvides(self, transportInterfaces)
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)


    # Transport relaying: forward the wrapped protocol's calls downward.

    def write(self, data):
        """
        Forward C{data} to the real transport.
        """
        self.transport.write(data)


    def writeSequence(self, data):
        """
        Forward a sequence of strings to the real transport.
        """
        self.transport.writeSequence(data)


    def loseConnection(self):
        """
        Mark this wrapper as disconnecting and close the real transport.
        """
        self.disconnecting = 1
        self.transport.loseConnection()


    def getPeer(self):
        """
        Return the real transport's peer address.
        """
        return self.transport.getPeer()


    def getHost(self):
        """
        Return the real transport's host address.
        """
        return self.transport.getHost()


    def registerProducer(self, producer, streaming):
        """
        Register C{producer} with the real transport.
        """
        self.transport.registerProducer(producer, streaming)


    def unregisterProducer(self):
        """
        Unregister the producer from the real transport.
        """
        self.transport.unregisterProducer()


    def stopConsuming(self):
        """
        Forward C{stopConsuming} to the real transport.
        """
        self.transport.stopConsuming()


    def __getattr__(self, name):
        # Only consulted when normal lookup fails: defer to the real
        # transport so this wrapper can stand in for it.
        return getattr(self.transport, name)


    # Protocol relaying: forward transport events upward.

    def dataReceived(self, data):
        """
        Hand received bytes to the wrapped protocol.
        """
        self.wrappedProtocol.dataReceived(data)


    def connectionLost(self, reason):
        """
        Unregister from the factory, then tell the wrapped protocol the
        connection is gone.
        """
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    @ivar wrappedFactory: The factory whose protocols are wrapped and to
        which factory-level callbacks are forwarded.
    @ivar protocols: The currently registered wrapper protocols, stored as
        dictionary keys (the values are always C{1}).
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        self.protocols = {}


    def logPrefix(self):
        """
        Generate a log prefix mentioning both the wrapped factory and this one.
        """
        return _wrappedLogPrefix(self, self.wrappedFactory)


    def doStart(self):
        """
        Start the wrapped factory, then this one.
        """
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)


    def doStop(self):
        """
        Stop the wrapped factory, then this one.
        """
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)


    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)


    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)


    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)


    def buildProtocol(self, addr):
        """
        Wrap the wrapped factory's protocol for C{addr} in C{self.protocol}.

        @return: The wrapper protocol, or C{None} if the wrapped factory
            returned C{None} to refuse the connection.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Propagate the refusal.  Wrapping None would hand the caller a
            # protocol whose makeConnection fails on None.makeConnection.
            return None
        return self.protocol(self, wrappedProtocol)


    def registerProtocol(self, p):
        """
        Called by protocol to register itself.
        """
        self.protocols[p] = 1


    def unregisterProtocol(self, p):
        """
        Called by protocols when they go away.
        """
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for L{ThrottlingFactory}.

    Reports the size of every read and write to the factory, and lets the
    factory pause and resume reading (via the transport) and writing (via a
    registered producer).

    @ivar producer: The producer registered through L{registerProducer}, or
        C{None} when there is none.
    """

    # Defined on the class so the lookup never falls through to
    # ProtocolWrapper.__getattr__, which would return the transport's
    # attribute of the same name instead of our own producer.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        # Materialize first: the chunks are both measured and forwarded, so a
        # one-shot iterable must not be consumed by the measuring step.
        seq = list(seq)
        # sum() is 0 for an empty sequence, whereas reduce() without an
        # initial value raised TypeError.
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)


    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)


    def unregisterProducer(self):
        # Reset rather than delete, so unregistering without a prior
        # registerProducer does not raise AttributeError.
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)


    def throttleReads(self):
        self.transport.pauseProducing()


    def unthrottleReads(self):
        self.transport.resumeProducing()


    def throttleWrites(self):
        # Writes can only be throttled through a registered producer.
        if self.producer is not None:
            self.producer.pauseProducing()


    def unthrottleWrites(self):
        if self.producer is not None:
            self.producer.resumeProducing()
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    Once per second (while at least one connection is open) the bytes read
    and written during the past second are compared against C{readLimit} and
    C{writeLimit}.  When a limit is exceeded, all protocols are throttled
    for a period proportional to the excess.
    """

    protocol = ThrottlingProtocol

    # Attributes holding the DelayedCalls this factory may have pending.
    _delayedCallAttributes = ('unthrottleReadsID', 'checkReadBandwidthID',
                              'unthrottleWritesID', 'checkWriteBandwidthID')

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed the read bandwidth limit, throttling reads if
        so, and reschedules itself one second later.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Pause for the fraction of a second by which we overshot.
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Checks if we've passed the write bandwidth limit, throttling writes
        if so, and reschedules itself one second later.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = self.callLater(throttleTime,
                                                     self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()


    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()


    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return C{None} when the connection
        limit is reached or the wrapped factory refuses the connection.

        The per-second bandwidth checks start when the first connection is
        accepted.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None
        protocol = WrappingFactory.buildProtocol(self, addr)
        if protocol is None:
            # Refused by the wrapped factory: it will never be unregistered,
            # so it must not be counted.
            return None
        # Only start the polling loops for a connection that is actually
        # accepted; otherwise nothing would ever cancel them.
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()
        self.connectionCount += 1
        return protocol


    def unregisterProtocol(self, p):
        """
        Forget C{p}; when the last connection goes away, cancel all pending
        bandwidth checks and unthrottle calls.
        """
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            for attributeName in self._delayedCallAttributes:
                delayedCall = getattr(self, attributeName)
                if delayedCall is not None:
                    delayedCall.cancel()
                    # Forget the cancelled call.  Otherwise a later drop to
                    # zero connections would cancel it again and raise
                    # AlreadyCancelled.
                    setattr(self, attributeName, None)
class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that logs, via L{log.msg}, every chunk of data it
    receives or sends.
    """

    def dataReceived(self, data):
        log.msg("Received: %r" % data)
        ProtocolWrapper.dataReceived(self, data)


    def write(self, data):
        log.msg("Sending: %r" % data)
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        """
        Log each chunk of C{seq} the same way L{write} does, then send them.

        Without this override, L{ProtocolWrapper.writeSequence} passes the
        chunks straight to the transport and they are never logged.
        """
        # Materialize so a one-shot iterable can be both logged and written.
        seq = list(seq)
        for data in seq:
            log.msg("Sending: %r" % (data,))
        ProtocolWrapper.writeSequence(self, seq)
class SpewingFactory(WrappingFactory):
    """
    L{WrappingFactory} whose wrapper protocols are L{SpewingProtocol}s, so
    data received and written on each connection is logged.
    """

    protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
    """
    Wrapping factory that refuses a new connection when its peer host
    already has C{maxConnectionsPerPeer} open connections.

    @ivar peerConnections: Mapping from peer host to its number of open
        connections (created in L{startFactory}).
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}


    def _peerHost(self, addr):
        """
        Return the key used to group connections from the same peer.

        Both L{buildProtocol} and L{unregisterProtocol} go through this
        helper, so a connection is released under the same key it was
        counted under.
        """
        host = getattr(addr, 'host', None)
        if host is None:
            # NOTE(review): fallback for tuple-style addresses, presumably
            # (host, port) -- confirm against callers.
            host = addr[0]
        return host


    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        protocol = WrappingFactory.buildProtocol(self, addr)
        if protocol is None:
            # Refused by the wrapped factory: never unregistered, so do not
            # count it against the peer.
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return protocol


    def unregisterProtocol(self, p):
        # Let the base class drop p from self.protocols; skipping this made
        # that dict grow with every connection ever made.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    Every protocol it builds, normal or overflow, is wrapped in a
    L{ProtocolWrapper} and counted in C{connectionCount}.  The count goes
    back down when the wrapper reports the connection lost through
    L{unregisterProtocol}.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        atCapacity = (self.connectionLimit is not None and
                      self.connectionCount >= self.connectionLimit)
        if not atCapacity:
            # Room left: use the normal protocol.
            protocolClass = self.protocol
        elif self.overflowProtocol is not None:
            # Full, but an overflow protocol is configured.
            protocolClass = self.overflowProtocol
        else:
            # Full and no overflow protocol: refuse the connection.
            return None

        wrappedProtocol = protocolClass()
        wrappedProtocol.factory = self
        wrapper = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return wrapper


    def registerProtocol(self, p):
        # Counting already happened in buildProtocol.
        pass


    def unregisterProtocol(self, p):
        self.connectionCount -= 1
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Every write, writeSequence and dataReceived pushes the timeout back by
    C{timeoutPeriod} seconds.  The timer is scheduled through the factory's
    C{callLater}.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}; must also provide C{callLater}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except error.AlreadyCalled:
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        Does nothing when there is no pending timeout.  This includes the
        time during and after L{timeoutFunc}: the timed call has then
        already run, and resetting it would raise AlreadyCalled.
        """
        call = self.timeoutCall
        if call and call.active():
            call.reset(self.timeoutPeriod)


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}.

    @ivar timeoutPeriod: Idle timeout, in seconds, given to every protocol
        this factory builds (default 30 minutes).
    """
    protocol = TimeoutProtocol


    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)


    def buildProtocol(self, addr):
        """
        Wrap the wrapped factory's protocol in a timeout protocol.

        @return: The timeout protocol, or C{None} if the wrapped factory
            refused the connection.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Constructing the wrapper would schedule a timeout (and, when it
            # fires, a loseConnection) for a connection that never exists.
            return None
        return self.protocol(self, wrappedProtocol,
                             timeoutPeriod=self.timeoutPeriod)


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that appends one line to C{logfile} for each connection
    event and each chunk of data received (C{C}) or sent (C{S}/C{SV}).
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number


    def _log(self, line):
        # Flush each line so the log stays current while the connection is
        # still open.
        self.logfile.write(line + '\n')
        self.logfile.flush()


    def _mungeData(self, data):
        """
        Truncate C{data} for logging when it exceeds C{lengthLimit}, marking
        the cut with C{'<... elided>'}.
        """
        marker = '<... elided>'
        if self.lengthLimit and len(data) > self.lengthLimit:
            # Clamp at 0: a limit shorter than the marker used to produce a
            # negative slice end, which kept almost all of the data.
            keep = max(self.lengthLimit - len(marker), 0)
            data = data[:keep] + marker
        return data


    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)


    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)


    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)


    def writeSequence(self, iovec):
        # Materialize so a one-shot iterable is not consumed by the logging
        # step before it reaches the transport.
        iovec = list(iovec)
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)


    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory that logs each connection's traffic to its own file,
    named C{"<logfilePrefix>-<n>"}, where C{n} counts accepted connections.
    """
    protocol = TrafficLoggingProtocol

    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)


    def open(self, name):
        """
        Open the log file C{name} for writing.  Override for tests.
        """
        # Inside this method the bare name refers to the builtin open(); the
        # Python-2-only file() builtin is no longer needed.
        return open(name, 'w')


    def buildProtocol(self, addr):
        """
        Build a logging protocol with a fresh log file.

        @return: The logging protocol, or C{None} if the wrapped factory
            refused the connection; no log file is created in that case.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            return None
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, wrappedProtocol,
                             logfile, self.lengthLimit, self._counter)


    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
class TimeoutMixin:
    """
    Mixin for protocols which wish to timeout connections.

    A protocol using this mixin has one timeout, configured with
    L{setTimeout}.  When it expires, L{timeoutConnection} runs; by default
    that drops the connection.

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """
    timeOut = None

    # The pending timed call, or None when no timeout is scheduled.
    __timeoutCall = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)


    def resetTimeout(self):
        """
        Restart the timeout count down from C{timeOut} seconds.

        This is a no-op when nothing is scheduled, either because the
        connection already timed out or because the timeout was disabled
        (e.g. with C{setTimeout(None)}).

        Call it whenever meaningful input arrives from the peer: the other
        end is evidently still there.
        """
        pending = self.__timeoutCall
        if pending is None or self.timeOut is None:
            return
        pending.reset(self.timeOut)


    def setTimeout(self, period):
        """
        Change the timeout period.

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
            C{None} to disable the timeout.

        @return: The previous timeout period.
        """
        previous, self.timeOut = self.timeOut, period

        pending = self.__timeoutCall
        if pending is None:
            # Nothing scheduled yet: start a timer only if one is wanted.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            # Disable the running timer.
            pending.cancel()
            self.__timeoutCall = None
        else:
            # Re-arm the running timer with the new period.
            pending.reset(period)

        return previous


    def __timedOut(self):
        # The call has fired; forget it before handing off to the hook.
        self.__timeoutCall = None
        self.timeoutConnection()


    def timeoutConnection(self):
        """
        Called when the connection times out.

        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()