# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
from __future__ import nested_scopes
"""Generic TCP tests."""
import socket, time
from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet import error
from twisted.internet.address import IPv4Address
class ClosingProtocol(protocol.Protocol):
    """Protocol that hangs up as soon as the connection is established."""

    def connectionMade(self):
        # Nothing is ever written; the connection is dropped right away.
        transport = self.transport
        transport.loseConnection()

    def connectionLost(self, reason):
        # Only a clean close is acceptable; trap() re-raises any other
        # kind of failure.
        reason.trap(error.ConnectionDone)
class ClosingFactory(protocol.ServerFactory):
    """Factory that closes its listening port immediately.

    Every call to buildProtocol closes C{self.port}. The test must assign
    that attribute after listening. The accepted connection is handled by
    a L{ClosingProtocol}, which hangs up at once.
    """

    def buildProtocol(self, conn):
        listeningPort = self.port
        listeningPort.loseConnection()
        return ClosingProtocol()
class MyProtocol(protocol.Protocol):
    """Protocol that records whether it was connected and disconnected."""

    # Class-level defaults. Instances overwrite made/closed with 1 as the
    # corresponding events happen; nothing in this class sets failed.
    made = closed = failed = 0

    def connectionMade(self):
        self.made = 1

    def connectionLost(self, reason):
        self.closed = 1
class MyServerFactory(protocol.ServerFactory):
    """Server factory that counts protocols built and keeps the latest one."""

    called = 0  # number of buildProtocol calls on this instance

    def buildProtocol(self, addr):
        self.called = self.called + 1
        self.protocol = proto = MyProtocol()
        return proto
class MyClientFactory(protocol.ClientFactory):
    """Client factory that records failure, loss and shutdown events.

    Attributes set along the way: C{protocol} (latest MyProtocol built),
    C{failed}/C{reason} (connection failure), C{lostReason} (connection
    loss) and C{stopped} (stopFactory was called).
    """

    failed = 0
    stopped = 0

    def buildProtocol(self, addr):
        self.protocol = proto = MyProtocol()
        return proto

    def clientConnectionFailed(self, connector, reason):
        # Keep the failure so tests can trap() the expected error type.
        self.failed, self.reason = 1, reason

    def clientConnectionLost(self, connector, reason):
        self.lostReason = reason

    def stopFactory(self):
        self.stopped = 1
class PortCleanerUpper(unittest.TestCase):
    """Base test case that stops listening ports left open by a test.

    Tests append every port they create to C{self.ports}. tearDown then
    stops each one that is still listening, so a failing test does not
    leak a bound port into later tests.
    """

    def __init__(self):
        self.ports = []

    def tearDown(self):
        """Stop every registered port that is still listening.

        Cleanup is best-effort. A failure to stop one port must not keep
        the remaining ports from being stopped.
        """
        for p in self.ports:
            try:
                # Ask the port itself whether it is still listening; ports
                # the test already stopped are skipped.
                if p.connected:
                    p.stopListening()
            except Exception:
                pass
        # Forget the ports so a reused test instance does not revisit them.
        self.ports = []
        reactor.iterate()
class ListeningTestCase(PortCleanerUpper):
    """Tests for starting and stopping TCP listeners."""

    def testListen(self):
        listener = reactor.listenTCP(0, MyServerFactory(), interface="127.0.0.1")
        listener.stopListening()

    def testStopListening(self):
        """A port number released by stopListening can be listened on again."""
        factory = MyServerFactory()
        port = reactor.listenTCP(0, factory, interface="127.0.0.1")
        portNumber = port.getHost().port
        self.ports.append(port)
        stopped = []
        d = defer.maybeDeferred(port.stopListening)
        d.addCallback(stopped.append)
        # Spin the reactor until the port has really been released.
        while len(stopped) == 0:
            reactor.iterate(0.1)
        port = reactor.listenTCP(portNumber, factory, interface="127.0.0.1")
        self.ports.append(port)

    def testNumberedInterface(self):
        # Bind to the loopback interface given as a dotted-quad address.
        listener = reactor.listenTCP(0, MyServerFactory(), interface='127.0.0.1')
        listener.stopListening()

    def testNamedInterface(self):
        # Bind to an interface given by host name rather than by address.
        # listenTCP may raise here if the reactor cannot resolve names.
        listener = reactor.listenTCP(0, MyServerFactory(), interface='localhost')
        listener.stopListening()
class LoopbackTestCase(PortCleanerUpper):
    """Test loopback connections."""

    n = 10081

    def testClosePortInProtocolFactory(self):
        """Closing the listening port from buildProtocol still serves the
        connection that triggered it, and the client sees a clean close."""
        f = ClosingFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.n = port.getHost().port
        self.ports.append(port)
        f.port = port
        clientF = MyClientFactory()
        reactor.connectTCP("localhost", self.n, clientF)
        while not clientF.protocol or not clientF.protocol.closed:
            reactor.iterate()
        reactor.iterate()
        reactor.iterate()
        self.assert_(clientF.protocol.made)
        self.assert_(port.disconnected)
        clientF.lostReason.trap(error.ConnectionDone)

    def testTcpNoDelay(self):
        """TCP_NODELAY is off by default and can be toggled on both ends."""
        f = MyServerFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.n = port.getHost().port
        self.ports.append(port)
        clientF = MyClientFactory()
        reactor.connectTCP("localhost", self.n, clientF)
        reactor.iterate()
        reactor.iterate()
        for p in clientF.protocol, f.protocol:
            transport = p.transport
            self.assertEquals(transport.getTcpNoDelay(), 0)
            transport.setTcpNoDelay(1)
            self.assertEquals(transport.getTcpNoDelay(), 1)
            transport.setTcpNoDelay(0)
            reactor.iterate()
            self.assertEquals(transport.getTcpNoDelay(), 0)
        clientF.protocol.transport.loseConnection()
        port.stopListening()
        reactor.iterate()
        reactor.iterate()
        clientF.lostReason.trap(error.ConnectionDone)

    def testTcpKeepAlive(self):
        """SO_KEEPALIVE is off by default and can be toggled on both ends."""
        f = MyServerFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.n = port.getHost().port
        self.ports.append(port)
        clientF = MyClientFactory()
        reactor.connectTCP("localhost", self.n, clientF)
        reactor.iterate()
        reactor.iterate()
        for p in clientF.protocol, f.protocol:
            transport = p.transport
            self.assertEquals(transport.getTcpKeepAlive(), 0)
            transport.setTcpKeepAlive(1)
            self.assertEquals(transport.getTcpKeepAlive(), 1)
            transport.setTcpKeepAlive(0)
            reactor.iterate()
            self.assertEquals(transport.getTcpKeepAlive(), 0)
        clientF.protocol.transport.loseConnection()
        port.stopListening()
        reactor.iterate()
        reactor.iterate()
        clientF.lostReason.trap(error.ConnectionDone)

    def testFailing(self):
        """Connecting to a closed port fails with ConnectionRefusedError."""
        clientF = MyClientFactory()
        # XXX we assume no one is listening on TCP port 69
        reactor.connectTCP("localhost", 69, clientF, timeout=5)
        while not clientF.failed:
            reactor.iterate()
        clientF.reason.trap(error.ConnectionRefusedError)

    def testConnectByService(self):
        """connectTCP resolves a service name through socket.getservbyname."""
        serv = socket.getservbyname
        s = MyServerFactory()
        port = reactor.listenTCP(0, s, interface="127.0.0.1")
        self.n = port.getHost().port
        self.ports.append(port)

        def fakeGetservbyname(name, proto, n=self.n):
            # Map ('http', 'tcp') to the port just opened; anything else
            # resolves to port 10.
            if name == 'http' and proto == 'tcp':
                return n
            return 10

        socket.getservbyname = fakeGetservbyname
        try:
            c = reactor.connectTCP('localhost', 'http', MyClientFactory())
            reactor.iterate()
            reactor.iterate()
            reactor.iterate()
            port.stopListening()
            c.disconnect()
        finally:
            # Always undo the monkeypatch so later tests see the real
            # resolver, even if something above raised.
            socket.getservbyname = serv
        self.assert_(s.called)
class StartStopFactory(protocol.Factory):
    """Server factory that insists on exactly one start followed by one stop.

    RuntimeError is raised if startFactory runs twice or after a stop, or
    if stopFactory runs before a start or twice.
    """

    started = 0
    stopped = 0

    def startFactory(self):
        if not (self.started or self.stopped):
            self.started = 1
            return
        raise RuntimeError

    def stopFactory(self):
        if self.started and not self.stopped:
            self.stopped = 1
            return
        raise RuntimeError
class ClientStartStopFactory(MyClientFactory):
    """MyClientFactory that enforces a single start followed by a single stop.

    RuntimeError is raised on a repeated or out-of-order start/stop.
    """

    started = 0
    stopped = 0

    def startFactory(self):
        alreadyUsed = self.started or self.stopped
        if alreadyUsed:
            raise RuntimeError
        self.started = 1

    def stopFactory(self):
        canStop = self.started and not self.stopped
        if not canStop:
            raise RuntimeError
        self.stopped = 1
class FactoryTestCase(PortCleanerUpper):
    """Tests for factories."""

    def testServerStartStop(self):
        """One factory shared by several ports is started and stopped once.

        The factory is started when the first port listens. It is stopped
        only after the last of the three ports stops listening.
        """
        f = StartStopFactory()
        # listen on port
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.n1 = p1.getHost().port
        self.ports.append(p1)
        reactor.iterate()
        reactor.iterate()
        self.assertEquals((f.started, f.stopped), (1, 0))
        # listen on two more ports
        p2 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.n2 = p2.getHost().port
        self.ports.append(p2)
        p3 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.n3 = p3.getHost().port
        self.ports.append(p3)
        reactor.iterate()
        reactor.iterate()
        self.assertEquals((f.started, f.stopped), (1, 0))
        # close two ports
        p1.stopListening()
        p2.stopListening()
        reactor.iterate()
        reactor.iterate()
        self.assertEquals((f.started, f.stopped), (1, 0))
        # close last port
        p3.stopListening()
        reactor.iterate()
        reactor.iterate()
        self.assertEquals((f.started, f.stopped), (1, 1))

    def testClientStartStop(self):
        """A client factory is started by connectTCP and stopped later.

        The server side (ClosingFactory/ClosingProtocol) hangs up at once.
        The loop then waits until the client factory's stopFactory runs.
        """
        f = ClosingFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.n = p.getHost().port
        self.ports.append(p)
        f.port = p
        reactor.iterate()
        reactor.iterate()
        factory = ClientStartStopFactory()
        reactor.connectTCP("127.0.0.1", self.n, factory)
        # startFactory must already have run synchronously inside connectTCP.
        self.assert_(factory.started)
        reactor.iterate()
        reactor.iterate()
        # Unbounded wait: the test hangs (rather than fails) if stopFactory
        # is never called.
        while not factory.stopped:
            reactor.iterate()
class ConnectorTestCase(PortCleanerUpper):
    """Tests for the connector object returned by reactor.connectTCP."""

    def testConnectorIdentity(self):
        """The connector returned by connectTCP is passed to factory callbacks.

        Checks that startedConnecting and clientConnectionLost both receive
        that same connector. Also checks that getDestination() compares
        equal to ('INET', host, port).
        """
        f = ClosingFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = p.getHost().port
        self.ports.append(p)
        f.port = p
        reactor.iterate()
        reactor.iterate()
        # l collects connectors from the callbacks; m collects loss reasons.
        l = []; m = []
        factory = ClientStartStopFactory()
        factory.clientConnectionLost = lambda c, r: (l.append(c), m.append(r))
        factory.startedConnecting = lambda c: l.append(c)
        connector = reactor.connectTCP("127.0.0.1", n, factory)
        self.assertEquals(connector.getDestination(), ('INET', "127.0.0.1", n))
        # Wait up to about 50 * 0.1s for the client factory to stop.
        i = 0
        while i < 50 and not factory.stopped:
            reactor.iterate(0.1)
            i += 1
        m[0].trap(error.ConnectionDone)
        self.assertEquals(l, [connector, connector])

    def testUserFail(self):
        """stopConnecting() from startedConnecting fails the attempt.

        The failure reason is UserError, and the factory is stopped.
        """
        f = MyServerFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = p.getHost().port
        self.ports.append(p)
        def startedConnecting(connector):
            connector.stopConnecting()
        factory = ClientStartStopFactory()
        factory.startedConnecting = startedConnecting
        reactor.connectTCP("127.0.0.1", n, factory)
        while not factory.stopped:
            reactor.iterate()
        self.assertEquals(factory.failed, 1)
        factory.reason.trap(error.UserError)
        p.stopListening()
        reactor.iterate()

    def testReconnect(self):
        """connector.connect() from clientConnectionLost reconnects.

        ClosingFactory closes the listening port on the first connection,
        so the reconnection attempt is expected to be refused.
        """
        f = ClosingFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = p.getHost().port
        self.ports.append(p)
        f.port = p
        reactor.iterate()
        reactor.iterate()
        factory = MyClientFactory()
        def clientConnectionLost(c, reason):
            c.connect()
        factory.clientConnectionLost = clientConnectionLost
        reactor.connectTCP("127.0.0.1", n, factory)
        # Wait up to about 50 * 0.1s for the reconnection attempt to fail.
        i = 0
        while i < 50 and not factory.failed:
            reactor.iterate(0.1)
            i += 1
        # Note: p is rebound here from the listening port to the client's
        # first protocol instance.
        p = factory.protocol
        self.assertEquals((p.made, p.closed), (1, 1))
        factory.reason.trap(error.ConnectionRefusedError)
        self.assertEquals(factory.stopped, 1)
class CannotBindTestCase(PortCleanerUpper):
    """Tests for correct behavior when a reactor cannot bind to the required TCP port."""

    def testCannotBind(self):
        """Listening on a port that is already in use raises CannotListenError."""
        f = MyServerFactory()
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        n = p1.getHost().port
        self.ports.append(p1)
        # getHost() must compare equal to the ('INET', host, port) tuple form.
        self.assertEquals(p1.getHost(), ("INET", "127.0.0.1", n,))
        # make sure new listen raises error
        self.assertRaises(error.CannotListenError, reactor.listenTCP, n, f, interface='127.0.0.1')
        p1.stopListening()

    def testClientBind(self):
        """A client bindAddress that is already in use fails with ConnectBindError."""
        f = MyServerFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.ports.append(p)
        factory = MyClientFactory()
        # First client: bind to an OS-chosen local port (0).
        reactor.connectTCP("127.0.0.1", p.getHost().port, factory, bindAddress=("127.0.0.1", 0))
        while not factory.protocol:
            reactor.iterate()
        self.assertEquals(factory.protocol.made, 1)
        # Second client: try to bind to the local port the first one got.
        port = factory.protocol.transport.getHost().port
        f2 = MyClientFactory()
        reactor.connectTCP("127.0.0.1", p.getHost().port, f2, bindAddress=("127.0.0.1", port))
        reactor.iterate()
        reactor.iterate()
        self.assertEquals(f2.failed, 1)
        f2.reason.trap(error.ConnectBindError)
        self.assertEquals(f2.stopped, 1)
        p.stopListening()
        factory.protocol.transport.loseConnection()
        reactor.iterate()
        reactor.iterate()
        reactor.iterate()
        self.assertEquals(factory.stopped, 1)
class MyOtherClientFactory(protocol.ClientFactory):
    """Client factory that remembers the address passed to buildProtocol."""

    def buildProtocol(self, address):
        self.address = address
        proto = MyProtocol()
        self.protocol = proto
        return proto
class LocalRemoteAddressTestCase(PortCleanerUpper):
    """Tests for correct getHost/getPeer values and that the correct address
    is passed to buildProtocol.
    """

    def testHostAddress(self):
        """The server's getHost() matches what the client sees as its peer.

        It must equal both the address passed to the client's buildProtocol
        and the client transport's getPeer().
        """
        f1 = MyServerFactory()
        p1 = reactor.listenTCP(0, f1, interface='127.0.0.1')
        n = p1.getHost().port
        self.ports.append(p1)
        f2 = MyOtherClientFactory()
        p2 = reactor.connectTCP('127.0.0.1', n, f2)
        # Give the connection a few short reactor passes to complete.
        for i in range(5):
            reactor.iterate(0.01)
        self.assertEquals(p1.getHost(), f2.address)
        self.assertEquals(p1.getHost(), f2.protocol.transport.getPeer())
        p1.stopListening()
        p2.disconnect()
class WriterProtocol(protocol.Protocol):
def connectionMade(self):
# use everything ITransport claims to provide. If something here
# fails, the exception will be written to the log, but it will not
# directly flunk the test. The test will fail when maximum number of
# iterations have passed and the writer's factory.done has not yet
# been set.
self.transport.write("Hello Cleveland!\n")
seq = ["Goodbye", " cruel", " world", "\n"]
self.transport.writeSequence(seq)
peer = self.transport.getPeer()
if peer[0] != "INET":
print "getPeer returned non-INET socket:", peer
self.factory.problem = 1
us = self.transport.getHost()
if us[0] != "INET":
print "getHost returned non-INET socket:", us
self.factory.problem = 1
self.factory.done = 1
self.transport.loseConnection()
class ReaderProtocol(protocol.Protocol):
    """Appends received bytes to its factory and flags disconnection."""

    def dataReceived(self, data):
        factory = self.factory
        factory.data = factory.data + data

    def connectionLost(self, reason):
        self.factory.done = 1
class WriterClientFactory(protocol.ClientFactory):
    """Client factory that holds the data collected by its ReaderProtocol."""

    def __init__(self):
        self.done, self.data = 0, ""

    def buildProtocol(self, addr):
        reader = ReaderProtocol()
        reader.factory = self
        self.protocol = reader
        return reader
class WriteDataTestCase(PortCleanerUpper):
    """Test that connected TCP sockets can actually write data. Try to
    exercise the entire ITransport interface.
    """

    def testWriter(self):
        """The server writes via WriterProtocol; the client must receive it all."""
        f = protocol.Factory()
        f.protocol = WriterProtocol
        # Flags that WriterProtocol.connectionMade sets on the server factory.
        f.done = 0
        f.problem = 0
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = p.getHost().port
        self.ports.append(p)
        clientF = WriterClientFactory()
        reactor.connectTCP("localhost", n, clientF)
        # Bounded wait: at most about 21 reactor iterations for both sides
        # to finish.
        count = 0
        while not ((count > 20) or (f.done and clientF.done)):
            reactor.iterate()
            count += 1
        self.failUnless(f.done, "writer didn't finish, it probably died")
        self.failUnless(f.problem == 0, "writer indicated an error")
        self.failUnless(clientF.done, "client didn't see connection dropped")
        expected = "".join(["Hello Cleveland!\n",
                            "Goodbye", " cruel", " world", "\n"])
        self.failUnless(clientF.data == expected,
                        "client didn't receive all the data it expected")
        p.stopListening()
class ConnectionLosingProtocol(protocol.Protocol):
    """Writes one byte, hangs up, then reports the connection to C{master}.

    C{master} is a class attribute; ProperlyCloseFilesTestCase.setUp
    assigns it to the running test case.
    """

    def connectionMade(self):
        t = self.transport
        t.write("1")
        t.loseConnection()
        self.master._connectionMade()
class ProperlyCloseFilesTestCase(unittest.TestCase):
    """Open and immediately close many connections in sequence.

    If closed connections leaked file descriptors, the process would run
    out of them before numberRounds connections were made.
    """

    # Number of sequential connections to make. The module-level resource
    # check may raise this to just above the process's open-file limit.
    numberRounds = 2048
    # Wall-clock budget, in seconds, for the whole run.
    timeLimit = 200

    def setUp(self):
        f = protocol.ServerFactory()
        f.protocol = protocol.Protocol
        self.listener = reactor.listenTCP(0, f, interface="127.0.0.1")
        f = protocol.ClientFactory()
        f.protocol = ConnectionLosingProtocol
        # This sets a class attribute, so every ConnectionLosingProtocol
        # instance reports back to this test case.
        f.protocol.master = self
        def connector():
            p = self.listener.getHost().port
            return reactor.connectTCP('127.0.0.1', p, f)
        self.connector = connector
        self.totalConnections = 0

    def testProperlyCloseFiles(self):
        # Start the chain; each connection made triggers the next one via
        # _connectionMade.
        self.connector()
        timeLimit = time.time() + self.timeLimit
        while (self.totalConnections < self.numberRounds and
               time.time() < timeLimit):
            reactor.iterate(0.01)
        reactor.iterate(0.01)
        self.failUnlessEqual(self.totalConnections, self.numberRounds)

    def _connectionMade(self):
        # Called by ConnectionLosingProtocol.connectionMade: count the
        # connection and start the next one until numberRounds is reached.
        self.totalConnections += 1
        if self.totalConnections<self.numberRounds:
            self.connector()

    def tearDown(self):
        self.listener.stopListening()
class AProtocol(protocol.Protocol):
    """Checks transport addresses on connect, then disconnects shortly after.

    Used on both ends of the connection in AddressTestCase. Assertions are
    made through C{self.factory.testcase}. The factory's C{ipv4addr} is the
    expected peer address.
    """

    def connectionMade(self):
        # The disconnect is scheduled before the assertions. Presumably this
        # is so the connection still closes if one of them raises.
        reactor.callLater(0.1, self.transport.loseConnection)
        # getHost()/getPeer() must compare equal to an IPv4Address("TCP", ...)
        # built from their own host and port fields.
        self.factory.testcase.assertEquals(self.transport.getHost(),
            IPv4Address("TCP", self.transport.getHost().host, self.transport.getHost().port))
        self.factory.testcase.assertEquals(self.transport.getPeer(),
            IPv4Address("TCP", self.transport.getPeer().host, self.transport.getPeer().port))
        # The peer must be the address this side's factory expects.
        self.factory.testcase.assertEquals(self.transport.getPeer(), self.factory.ipv4addr)
        # Marker checked (and deleted) by AddressTestCase.testBuildProtocol.
        self.factory.testcase.ran = 1
class AClientFactory(protocol.ClientFactory):
    """Client factory asserting that buildProtocol gets the expected address.

    @param testcase: test case whose assertEquals is used.
    @param ipv4addr: the IPv4Address the server is expected to have.
    """

    def __init__(self, testcase, ipv4addr):
        self.testcase, self.ipv4addr = testcase, ipv4addr

    def buildProtocol(self, addr):
        tc = self.testcase
        expected = self.ipv4addr
        tc.assertEquals(addr, expected)
        # The address must also compare equal to the ('INET', host, port)
        # tuple form.
        tc.assertEquals(addr, ('INET', expected.host, expected.port))
        proto = AProtocol()
        proto.factory = self
        return proto
class AServerFactory(protocol.ServerFactory):
    """Server factory asserting that buildProtocol gets the expected address.

    @param testcase: test case whose assertEquals is used.
    @param ipv4addr: the IPv4Address the connecting client is expected to have.
    """

    def __init__(self, testcase, ipv4addr):
        self.testcase, self.ipv4addr = testcase, ipv4addr

    def buildProtocol(self, addr):
        tc = self.testcase
        expected = self.ipv4addr
        tc.assertEquals(addr, expected)
        # The server side compares against the (host, port) tuple form.
        tc.assertEquals(addr, (expected.host, expected.port))
        proto = AProtocol()
        proto.factory = self
        return proto
class AddressTestCase(unittest.TestCase):
    """Tests that IPv4Address values reach buildProtocol and transports."""

    def getFreePort(self):
        """Get an empty port."""
        # Listen on an OS-chosen port, note its number, then release it.
        # NOTE(review): this is inherently racy; another process could take
        # the port before the caller binds to it.
        p = reactor.listenTCP(0, protocol.ServerFactory())
        reactor.iterate(); reactor.iterate()
        port = p.getHost().port
        p.stopListening()
        reactor.iterate(); reactor.iterate()
        return port

    def testBuildProtocol(self):
        """Both factories' buildProtocol receive the peer's IPv4Address.

        The client binds to a known local port (portno), so the server can
        predict the client address. The client expects the server's
        listening address.
        """
        portno = self.getFreePort()
        p = reactor.listenTCP(0, AServerFactory(self, IPv4Address('TCP', '127.0.0.1', portno)))
        reactor.iterate()
        reactor.connectTCP("127.0.0.1", p.getHost().port,
                           AClientFactory(self, IPv4Address("TCP", "127.0.0.1", p.getHost().port)),
                           bindAddress=("127.0.0.1", portno))
        # Presumably runs the reactor for about 0.4s; confirm runReactor's
        # semantics against trial.
        self.runReactor(0.4, True)
        p.stopListening()
        # AProtocol.connectionMade sets self.ran once its checks pass.
        self.assert_(hasattr(self, "ran"))
        del self.ran
# Make ProperlyCloseFilesTestCase open more connections than the process may
# hold open at once. A descriptor leak then makes the test fail instead of
# passing by luck.
try:
    import resource
except ImportError:
    # No resource module on this platform: keep the class default.
    pass
else:
    _softLimit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    # An unlimited soft limit (RLIM_INFINITY, -1 on some platforms) would
    # give a meaningless round count, so the default is kept in that case.
    if _softLimit != resource.RLIM_INFINITY:
        ProperlyCloseFilesTestCase.numberRounds = _softLimit + 10
    del _softLimit
# syntax highlighted by Code2HTML, v. 0.9.1