# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
"""Test code for policies."""
from __future__ import nested_scopes
from StringIO import StringIO
from twisted.trial import unittest
import time
from twisted.internet import protocol, reactor
from twisted.protocols import policies
class StringIOWithoutClosing(StringIO):
    """In-memory file whose close() is ignored.

    Lets a test keep reading back what was written to it (e.g. with
    getvalue()) even if something calls close() on it.
    """

    def close(self):
        # Intentionally a no-op so the buffer outlives any "close".
        return None
class SimpleProtocol(protocol.Protocol):
    """Protocol that records its lifecycle and buffers received data.

    The class-level defaults are what tests see before the corresponding
    event fires; the handlers below shadow them on the instance.
    """

    connected = 0       # becomes 1 in connectionMade
    disconnected = 0    # becomes 1 in connectionLost
    buffer = ""         # concatenation of every chunk received

    def connectionMade(self):
        self.connected = 1

    def connectionLost(self, reason):
        self.disconnected = 1

    def dataReceived(self, data):
        self.buffer = self.buffer + data
class SillyFactory(protocol.ClientFactory):
    """Factory that always hands out one pre-built protocol instance.

    Because every buildProtocol call returns the same object, a test can
    keep a reference to it and inspect it afterwards.
    """

    def __init__(self, p):
        # Tests read this attribute directly (e.g. ``server.p.finish``),
        # so its name is part of the interface.
        self.p = p

    def buildProtocol(self, addr):
        # The peer address is ignored; the stored instance is reused.
        proto = self.p
        return proto
class EchoProtocol(protocol.Protocol):
    """Protocol that echoes input and timestamps producer events.

    Tests register an instance as its own producer; the wall-clock times of
    the latest pause/resume requests land in ``paused`` / ``resume``.
    Neither attribute exists until the first such call, which tests
    detect with hasattr().
    """

    def dataReceived(self, data):
        # Send every chunk straight back to the peer.
        self.transport.write(data)

    def pauseProducing(self):
        self.paused = time.time()

    def resumeProducing(self):
        self.resume = time.time()

    def stopProducing(self):
        # Nothing to release.
        pass
class Server(protocol.ServerFactory):
    """Server factory that builds one EchoProtocol per connection."""

    # This class attribute shadows the ``protocol`` module name only inside
    # this class body; the base class above was resolved before it.
    protocol = EchoProtocol
class SimpleSenderProtocol(SimpleProtocol):
    """Protocol that writes 'foo' once a second until told to finish.

    An unexpected disconnect (one not requested via finish()) is reported
    to the owning test case. Whenever the connection is found closed, the
    reactor is crashed, so tests can drive this with ``reactor.run()``.
    """

    finished = 0    # set to 1 by finish(): marks the close as expected
    data = ''       # everything received from the peer

    def __init__(self, testcase):
        # Test case whose ``failed`` flag is set on an unexpected disconnect.
        self.testcase = testcase

    def connectionMade(self):
        SimpleProtocol.connectionMade(self)
        self.writeSomething()

    def writeSomething(self):
        """Write 'foo' and reschedule in 1s, or wind down if disconnected."""
        if self.disconnected:
            if not self.finished:
                # The connection went away before finish() was called.
                self.fail()
            # Crash the reactor in both cases. If only the expected path
            # did so, an unexpected disconnect would leave reactor.run()
            # blocked forever instead of letting the test report failure.
            reactor.crash()
            return
        self.transport.write('foo')
        reactor.callLater(1, self.writeSomething)

    def finish(self):
        """Mark the close as expected and drop the connection."""
        self.finished = 1
        self.transport.loseConnection()

    def fail(self):
        # Flag the owning test case; tests assert on ``failed`` afterwards.
        self.testcase.failed = 1

    def dataReceived(self, data):
        self.data += data
class ThrottlingTestCase(unittest.TestCase):
    """Tests for policies.ThrottlingFactory connection and bandwidth limits.

    These tests drive the global reactor by hand with reactor.iterate(), so
    the number and placement of iterations (and, for the bandwidth tests,
    wall-clock time) matter.
    """

    def doIterations(self, count=5):
        """Run ``count`` iterations of the reactor loop."""
        for i in range(count):
            reactor.iterate()

    def testLimit(self):
        """At most 2 concurrent connections; a freed slot can be reused."""
        server = Server()
        c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
        # Wrap the echo server in a factory limited to 2 connections.
        tServer = policies.ThrottlingFactory(server, 2)
        p = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        # Element 2 of getHost() is used as the port to connect to below.
        n = p.getHost()[2]
        self.doIterations()

        # Connect one client at a time so the order of acceptance is fixed.
        for c in c1, c2, c3:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        # All three connect at the TCP level, but the third is expected to
        # be dropped and only two are tracked by the factory.
        self.assertEquals([c.connected for c in c1, c2, c3], [1, 1, 1])
        self.assertEquals([c.disconnected for c in c1, c2, c3], [0, 0, 1])
        self.assertEquals(len(tServer.protocols.keys()), 2)

        # Disconnect one protocol; now another should be able to connect.
        c1.transport.loseConnection()
        self.doIterations()
        reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
        self.doIterations()

        self.assertEquals(c4.connected, 1)
        self.assertEquals(c4.disconnected, 0)

        # Clean up the remaining connections and the listening port.
        for c in c2, c4: c.transport.loseConnection()
        p.stopListening()
        self.doIterations()

    def testWriteLimit(self):
        """Writing above writeLimit pauses producers ~1s in, resumes ~1s later."""
        server = Server()
        c1, c2 = SimpleProtocol(), SimpleProtocol()

        # The throttling factory starts checking bandwidth immediately,
        # so take the reference time before creating it.
        now = time.time()
        tServer = policies.ThrottlingFactory(server, writeLimit=10)
        port = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        n = port.getHost()[2]
        reactor.iterate(); reactor.iterate()
        for c in c1, c2:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        # Register each server-side EchoProtocol as the producer for its own
        # transport, so pause/resume requests are timestamped on it.
        for p in tServer.protocols.keys():
            p = p.wrappedProtocol
            self.assert_(isinstance(p, EchoProtocol))
            p.transport.registerProducer(p, 1)

        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()

        # Both chunks were echoed back, 20 bytes in total.
        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.writtenThisSecond, 20)

        # At this point the server should have written 20 bytes, 10 bytes
        # above the limit. Writing should be paused around 1 second after
        # 'now' and resumed a second after that.
        for p in tServer.protocols.keys():
            self.assert_(not hasattr(p.wrappedProtocol, "paused"))
            self.assert_(not hasattr(p.wrappedProtocol, "resume"))

        # ``p`` is still the last key from the loop above; spin until it
        # has been paused.
        while not hasattr(p.wrappedProtocol, "paused"):
            reactor.iterate()

        self.assertEquals(tServer.writtenThisSecond, 0)

        # Every producer was paused (within 0.1s of the 1s mark), none resumed.
        for p in tServer.protocols.keys():
            self.assert_(hasattr(p.wrappedProtocol, "paused"))
            self.assert_(not hasattr(p.wrappedProtocol, "resume"))
            self.assert_(abs(p.wrappedProtocol.paused - now - 1.0) < 0.1)

        while not hasattr(p.wrappedProtocol, "resume"):
            reactor.iterate()

        # Each producer was resumed about 1s after it was paused.
        for p in tServer.protocols.keys():
            self.assert_(hasattr(p.wrappedProtocol, "resume"))
            self.assert_(abs(p.wrappedProtocol.resume -
                             p.wrappedProtocol.paused - 1.0) < 0.1)

        c1.transport.loseConnection()
        c2.transport.loseConnection()
        port.stopListening()
        for p in tServer.protocols.keys():
            p.loseConnection()
        self.doIterations()

    def testReadLimit(self):
        """Reading above readLimit stops reads for ~1s, then resumes."""
        server = Server()
        c1, c2 = SimpleProtocol(), SimpleProtocol()
        now = time.time()
        tServer = policies.ThrottlingFactory(server, readLimit=10)
        port = reactor.listenTCP(0, tServer, interface="127.0.0.1")
        n = port.getHost()[2]
        self.doIterations()

        for c in c1, c2:
            reactor.connectTCP("127.0.0.1", n, SillyFactory(c))
            self.doIterations()

        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()

        # The echo server read (and echoed back) all 20 bytes.
        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.readThisSecond, 20)

        # We wrote 20 bytes, so after one second the server should stop
        # reading and then, a second later, start reading again.
        # Busy-wait until just past the 1 second mark.
        while time.time() - now < 1.05:
            reactor.iterate()

        self.assertEquals(tServer.readThisSecond, 0)

        # Write some more: the data should *not* be read (and so not
        # echoed) for another second.
        c1.transport.write("0123456789")
        c2.transport.write("abcdefghij")
        self.doIterations()
        self.assertEquals(c1.buffer, "0123456789")
        self.assertEquals(c2.buffer, "abcdefghij")
        self.assertEquals(tServer.readThisSecond, 0)

        # Past the 2 second mark, reading resumes and the second batch is
        # echoed back.
        while time.time() - now < 2.05:
            reactor.iterate()

        self.assertEquals(c1.buffer, "01234567890123456789")
        self.assertEquals(c2.buffer, "abcdefghijabcdefghij")

        c1.transport.loseConnection()
        c2.transport.loseConnection()
        port.stopListening()
        for p in tServer.protocols.keys():
            p.loseConnection()
        self.doIterations()

    # These depend on wall-clock timing and fail intermittently, so both
    # are marked as skipped.
    testReadLimit.skip = "Inaccurate tests are worse than no tests."
    testWriteLimit.skip = "Inaccurate tests are worse than no tests."
class TimeoutTestCase(unittest.TestCase):
    """Tests for policies.TimeoutFactory over real loopback TCP connections.

    Every test stops its listening port in a ``finally`` clause, so a
    failed assertion does not leak the port into later tests.
    """

    def setUp(self):
        # Set to 1 by SimpleSenderProtocol.fail() on an unexpected disconnect.
        self.failed = 0

    def _stopPort(self, port):
        """Stop listening on ``port`` and let the reactor process it."""
        port.stopListening()
        for i in range(10):
            reactor.iterate()

    def testTimeout(self):
        # Create a server which times out connections idle for 3 seconds.
        server = policies.TimeoutFactory(Server(), 3)
        port = reactor.listenTCP(0, server, interface="127.0.0.1")
        try:
            # Create a client that sends and receives nothing.
            client = SimpleProtocol()
            f = SillyFactory(client)
            reactor.connectTCP("127.0.0.1", port.getHost()[2], f)
            for i in range(10):
                reactor.iterate()
            self.assert_(client.connected)
            # Stay idle past the server's 3 second timeout.
            time.sleep(3.5)
            for i in range(3):
                reactor.iterate()
            self.assert_(client.disconnected)
        finally:
            self._stopPort(port)

    def testThatSendingDataAvoidsTimeout(self):
        # Create a server which times out connections idle for 2 seconds.
        server = policies.TimeoutFactory(Server(), 2)
        port = reactor.listenTCP(0, server, interface="127.0.0.1")
        try:
            # Create a client that writes 'foo' every second, so the
            # connection is never idle long enough to time out.
            client = SimpleSenderProtocol(self)
            f = SillyFactory(client)
            f.protocol = client
            reactor.connectTCP("127.0.0.1", port.getHost()[2], f)
            # Close deliberately after 3.5s; the client crashes the reactor
            # once it notices the disconnect, which ends reactor.run().
            reactor.callLater(3.5, client.finish)
            reactor.run()
            self.failUnlessEqual(self.failed, 0)
            # Writes at roughly t=0, 1, 2 and 3 were all echoed back.
            self.failUnlessEqual(client.data, 'foo'*4)
        finally:
            self._stopPort(port)

    def testThatReadingDataAvoidsTimeout(self):
        # Create a server that sends 'foo' every second.
        server = SillyFactory(SimpleSenderProtocol(self))
        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        try:
            # Wrap the client in a 2 second TimeoutFactory. The client never
            # writes, so this test only passes if received data keeps the
            # connection alive. A plain WrappingFactory has no timeout and
            # could not exercise this at all.
            clientFactory = policies.TimeoutFactory(
                SillyFactory(SimpleProtocol()), 2)
            reactor.connectTCP('127.0.0.1', port.getHost()[2], clientFactory)
            reactor.iterate()
            reactor.iterate()
            reactor.callLater(5, server.p.finish)
            reactor.run()
            self.failUnlessEqual(self.failed, 0)
        finally:
            self._stopPort(port)
class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
    """TimeoutMixin-based protocol that records whether it timed out."""

    timeOut = 3     # period (seconds) armed in connectionMade
    timedOut = 0    # flipped to 1 by timeoutConnection

    def timeoutConnection(self):
        # Record the timeout instead of dropping the connection.
        self.timedOut = 1

    def connectionMade(self):
        # Arm the idle timer as soon as the connection exists.
        self.setTimeout(self.timeOut)

    def dataReceived(self, data):
        # Incoming data counts as activity and restarts the timer.
        self.resetTimeout()
        protocol.Protocol.dataReceived(self, data)

    def connectionLost(self, reason=None):
        # Disarm the timer once the connection is gone.
        self.setTimeout(None)
class TestTimeout(unittest.TestCase):
    """Unit tests for TimeoutMixin, driven through TimeoutTester."""

    def _connect(self, tester):
        """Attach ``tester`` to an in-memory transport and return it."""
        sink = StringIOWithoutClosing()
        tester.makeConnection(protocol.FileWrapper(sink))
        return tester

    def _spin(self, times=10):
        """Iterate the global reactor ``times`` times."""
        for _ in range(times):
            reactor.iterate()

    def testTimeout(self):
        tester = self._connect(TimeoutTester())
        self._spin()
        self.failIf(tester.timedOut)
        # Idle past the 3 second default period.
        time.sleep(3.5)
        self._spin(1)
        self.failUnless(tester.timedOut)

    def testNoTimeout(self):
        tester = self._connect(TimeoutTester())
        self._spin()
        self.failIf(tester.timedOut)
        time.sleep(2)
        # Activity at t=2 restarts the 3 second period.
        tester.dataReceived('hello there')
        time.sleep(1.5)
        self._spin()
        self.failIf(tester.timedOut)
        time.sleep(2)
        self._spin()
        self.failUnless(tester.timedOut)

    def testResetTimeout(self):
        tester = TimeoutTester()
        tester.timeOut = None
        self._connect(tester)
        tester.setTimeout(1)
        self.assertEquals(tester.timeOut, 1)
        self._spin()
        self.failIf(tester.timedOut)
        time.sleep(1.1)
        self._spin(1)
        self.failUnless(tester.timedOut)
        tester.connectionLost()

    def testCancelTimeout(self):
        tester = TimeoutTester()
        tester.timeOut = 5
        self._connect(tester)
        tester.setTimeout(None)
        self.assertEquals(tester.timeOut, None)
        self._spin()
        self.failIf(tester.timedOut)
        tester.connectionLost()

    def testReturn(self):
        # setTimeout returns the previously configured period.
        tester = TimeoutTester()
        tester.timeOut = 5
        self.assertEquals(tester.setTimeout(10), 5)
        self.assertEquals(tester.setTimeout(None), 10)
        self.assertEquals(tester.setTimeout(1), None)
        self.assertEquals(tester.timeOut, 1)
        tester.connectionLost()
# syntax highlighted by Code2HTML, v. 0.9.1