# -*- test-case-name: twisted.test.test_policies -*-
# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
"""Resource limiting policies."""
# system imports
import sys, operator, time
# twisted imports
from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet.interfaces import ITransport
from twisted.internet import reactor, error
from twisted.python import log
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    Protocol events arriving at the wrapper are relayed to
    C{wrappedProtocol}; transport calls made on the wrapper are relayed to
    the real transport.  The wrapped protocol is handed the wrapper itself
    as its transport (see L{connectionMade}).
    """

    __implements__ = ITransport,

    # Set to 1 once loseConnection() has been requested on the wrapper.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        """Remember the owning factory and the protocol being wrapped.

        @param factory: The factory that built this wrapper; it is told
            when the wrapper is connected and disconnected.
        @param wrappedProtocol: The protocol instance to relay events to.
        """
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """Adopt the transport's declared interfaces, then connect.

        The instance is re-classed to a throwaway subclass whose
        C{__implements__} is copied from C{transport}, so the wrapper
        advertises whatever the real transport advertises.
        """
        class _MyClass(self.__class__):
            __implements__ = transport.__implements__

        self.__class__ = _MyClass
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        """Relay C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Relay a sequence of data chunks to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Flag the wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as its peer."""
        return self.transport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as its host."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Relay producer registration to the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Relay producer removal to the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Relay C{stopConsuming} to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        """Look up attributes missing on the wrapper on the transport."""
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        """Register with the factory and connect the wrapped protocol.

        The wrapped protocol receives this wrapper as its transport.
        """
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Relay received C{data} to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory and notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Factory lifecycle and client-connection events are relayed to
    C{wrappedFactory}; every protocol it builds is wrapped in an instance
    of C{self.protocol}.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        """
        @param wrappedFactory: The factory whose protocols get wrapped.
        """
        self.wrappedFactory = wrappedFactory
        # Live wrappers, used as a set: wrapper -> 1.
        self.protocols = {}

    def doStart(self):
        """Start the wrapped factory first, then this factory."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory first, then this factory."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and wrap it."""
        wrapped = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, wrapped)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Reports every byte read or written to the factory, and lets the
    factory pause and resume reading and producing on this connection.
    """

    # wrap API for tracking bandwidth

    def write(self, data):
        """Account for C{len(data)} written bytes, then write."""
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Account for the total length of C{seq}, then write it.

        The explicit initial value of 0 keeps an empty sequence from
        raising C{TypeError} in C{reduce}.
        """
        self.factory.registerWritten(reduce(operator.add, map(len, seq), 0))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        """Account for C{len(data)} read bytes, then deliver the data."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember the producer so writes can be throttled, then relay."""
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """Forget the remembered producer, then relay."""
        del self.producer
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the transport."""
        self.transport.stopReading()

    def unthrottleReads(self):
        """Resume reading from the transport."""
        self.transport.startReading()

    def throttleWrites(self):
        """Pause the registered producer; a no-op when there is none."""
        if hasattr(self, "producer"):
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer; a no-op when there is none."""
        if hasattr(self, "producer"):
            self.producer.resumeProducing()
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
        """
        @param wrappedFactory: The factory whose protocols get wrapped.
        @param maxConnectionCount: Connections beyond this count are refused.
        @param readLimit: Max bytes we should read per second, or C{None}
            for no read throttling.
        @param writeLimit: Max bytes we should write per second, or C{None}
            for no write throttling.
        """
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Pending delayed calls; None whenever nothing is scheduled.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written."""
        self.writtenThisSecond += length

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Checks if we've passed bandwidth limits.

        Runs once a second.  When more than C{readLimit} bytes were read
        since the last check, reads are throttled for long enough to bring
        the average back down to the limit.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        # reset for next round
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Write-side counterpart of L{checkReadBandwidth}."""
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = reactor.callLater(throttleTime,
                                                        self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def throttleWrites(self):
        """Throttle writes on all protocols."""
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()

    def unthrottleWrites(self):
        """Stop throttling writes on all protocols."""
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """Build a throttled protocol, or refuse the connection.

        @return: A wrapped protocol, or C{None} once C{maxConnectionCount}
            connections are already open.
        """
        # Refuse before starting any timers: a rejected connection never
        # reaches unregisterProtocol, so nothing would ever cancel them.
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        if self.connectionCount == 0:
            # First connection: start the once-a-second bandwidth checks.
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Forget C{p}; stop all timers once the last connection is gone."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            for name in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                delayedCall = getattr(self, name)
                if delayedCall is not None:
                    # Clear the handle so a later drop to zero connections
                    # does not try to cancel the same call a second time.
                    setattr(self, name, None)
                    delayedCall.cancel()
class SpewingProtocol(ProtocolWrapper):
    """Protocol wrapper that logs the repr of all data passing through."""

    def dataReceived(self, data):
        """Log incoming C{data}, then deliver it to the wrapped protocol."""
        log.msg("Received: %r" % data)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        """Log outgoing C{data}, then write it to the transport."""
        log.msg("Sending: %r" % data)
        ProtocolWrapper.write(self, data)
class SpewingFactory(WrappingFactory):
    """Factory whose connections are wrapped in L{SpewingProtocol}."""

    protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
    """Refuses connections from a peer that already has too many open.

    Stability: Unstable
    """

    # Connections allowed at once from any single peer host.
    maxConnectionsPerPeer = 5

    def startFactory(self):
        """Start with no connections recorded for any peer."""
        # peer host -> number of currently open connections
        self.peerConnections = {}

    def buildProtocol(self, addr):
        """Build a wrapped protocol unless the peer is at its limit.

        @param addr: The peer address; C{addr[0]} is used as the peer host.
        @return: A wrapped protocol, or C{None} when the peer already has
            C{maxConnectionsPerPeer} connections open.
        """
        peerHost = addr[0]
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Forget C{p} and release its slot in the per-peer count."""
        # Chain to the base class so the wrapper is also dropped from
        # self.protocols; without this that dict grows without bound.
        WrappingFactory.unregisterProtocol(self, p)
        # NOTE(review): the host is read from getPeer()[1] here but from
        # addr[0] in buildProtocol -- confirm both yield the same value.
        peerHost = p.getPeer()[1]
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]
class TimeoutProtocol(ProtocolWrapper):
    """Protocol that automatically disconnects when the connection is idle.

    Stability: Unstable
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """Constructor.

        @param factory: An L{IFactory}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = reactor.callLater(self.timeoutPeriod, self.timeoutFunc)

    def cancelTimeout(self):
        """Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except error.AlreadyCalled:
                # The timer fired before we got here; nothing to cancel.
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """Reset the timeout, usually because some activity just happened.

        If the timeout has already fired, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except error.AlreadyCalled:
                # The timer already fired (timeoutFunc has run) but the
                # connection is still seeing traffic.  Drop the spent call
                # instead of letting the error escape from write() or
                # dataReceived().
                self.timeoutCall = None

    def write(self, data):
        """Count the write as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Count the write as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        """Count the received data as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        """Stop the timer, then relay the disconnection."""
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def timeoutFunc(self):
        """This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
class TimeoutFactory(WrappingFactory):
    """Factory for TimeoutProtocol.

    Stability: Unstable
    """

    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        """
        @param wrappedFactory: The factory whose protocols get wrapped.
        @param timeoutPeriod: Idle period, in seconds, handed to every
            protocol this factory builds.
        """
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        """Wrap the wrapped factory's protocol with this factory's timeout."""
        wrapped = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, wrapped, timeoutPeriod=self.timeoutPeriod)
class TimeoutMixin:
    """Mixin for protocols which wish to timeout connections

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """

    timeOut = None

    # Pending delayed call for the next timeout check, if any.
    __timeoutCall = None
    # time.time() of the most recent activity, if any was recorded.
    __lastReceived = None

    def resetTimeout(self):
        """Reset the timeout count down"""
        self.__lastReceived = time.time()

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
            C{None} to disable the timeout.
        @return: The timeout period that was in effect before this call.
        """
        previous = self.timeOut
        self.timeOut = period
        self.__lastReceived = time.time()

        # Drop any check scheduled under the old period.
        if self.__timeoutCall:
            self.__timeoutCall.cancel()
            self.__timeoutCall = None

        if period is not None:
            self.__timeoutCall = reactor.callLater(period, self.__timedOut)

        return previous

    def __timedOut(self):
        """Time out the connection, or reschedule if there was activity."""
        self.__timeoutCall = None
        now = time.time()
        idle = now - (self.__lastReceived or now)
        if idle > self.timeOut:
            self.timeoutConnection()
            return
        # Activity was seen since this check was scheduled: look again at
        # the moment the current idle stretch would reach the limit.
        when = self.__lastReceived - now + self.timeOut
        self.__timeoutCall = reactor.callLater(when, self.__timedOut)

    def timeoutConnection(self):
        """Called when the connection times out.

        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()
# Syntax highlighted by Code2HTML, v. 0.9.1