# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
Test cases for twisted.protocols package.
"""
from twisted.trial import unittest
from twisted.protocols import basic, wire
from twisted.internet import reactor, protocol
import string, struct
import StringIO
class StringIOWithoutClosing(StringIO.StringIO):
    """In-memory file whose close() does nothing.

    The tests read what a protocol wrote through getvalue(); ignoring
    close() keeps that buffer readable even if something (presumably the
    FileWrapper transport when the connection is lost) closes the file.
    """

    def close(self):
        """Ignore the close request so getvalue() keeps working."""
        return None
class LineTester(basic.LineReceiver):
    """LineReceiver that records lines and can switch to raw mode.

    A line of the form ``len N`` announces N bytes of raw data.  A
    following empty line switches into raw mode, and the raw bytes are
    appended to the most recently recorded line (that empty line).
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        self.received = []

    def lineReceived(self, line):
        self.received.append(line)
        if not line:
            # An empty line begins the raw section announced by 'len N'.
            self.setRawMode()
        elif line.startswith('len '):
            self.length = int(line[4:])

    def rawDataReceived(self, data):
        wanted = self.length
        chunk, leftover = data[:wanted], data[wanted:]
        self.length = wanted - len(chunk)
        self.received[-1] += chunk
        if not self.length:
            # Raw section complete: whatever follows is line data again.
            self.setLineMode(leftover)

    def lineLengthExceeded(self, line):
        # Drop the first MAX_LENGTH + 1 characters and resume line parsing
        # on the remainder; shorter overflows are simply ignored here.
        cutoff = self.MAX_LENGTH + 1
        if len(line) > cutoff:
            self.setLineMode(line[cutoff:])
class LineOnlyTester(basic.LineOnlyReceiver):
    """LineOnlyReceiver that records every complete line it receives."""

    MAX_LENGTH = 64
    delimiter = '\n'

    def connectionMade(self):
        # Fresh record for each connection.
        self.received = []

    def lineReceived(self, line):
        self.received.append(line)
class WireTestCase(unittest.TestCase):
    """Tests for the small protocols in twisted.protocols.wire.

    Each test connects a protocol to an in-memory transport and checks
    exactly what the protocol wrote to it.
    """

    def _connect(self, proto):
        """Connect proto to a StringIOWithoutClosing-backed transport.

        Returns the StringIO so the caller can inspect what proto wrote.
        """
        sink = StringIOWithoutClosing()
        proto.makeConnection(protocol.FileWrapper(sink))
        return sink

    def _feed(self, proto):
        """Deliver five chunks, one dataReceived call per chunk."""
        for chunk in ("hello", "world", "how", "are", "you"):
            proto.dataReceived(chunk)

    def testEcho(self):
        # Echo must write every received chunk straight back.
        echo = wire.Echo()
        sink = self._connect(echo)
        self._feed(echo)
        self.failUnlessEqual(sink.getvalue(), "helloworldhowareyou")

    def testWho(self):
        sink = self._connect(wire.Who())
        self.failUnlessEqual(sink.getvalue(), "root\r\n")

    def testQOTD(self):
        sink = self._connect(wire.QOTD())
        self.failUnlessEqual(sink.getvalue(),
                             "An apple a day keeps the doctor away.\r\n")

    def testDiscard(self):
        # Discard must swallow everything and write nothing.
        discard = wire.Discard()
        sink = self._connect(discard)
        self._feed(discard)
        self.failUnlessEqual(sink.getvalue(), "")
class LineReceiverTestCase(unittest.TestCase):
    """Feed LineTester a mixed line/raw stream in packets of 1..9 bytes.

    Whatever the packet size, the recorded lines must equal ``output``.
    """

    # Built from explicit pieces so the empty lines -- which are what
    # switch LineTester into raw mode -- are visible and cannot be lost.
    # Without them the stream never enters raw mode and cannot produce
    # ``output``.
    buffer = ('len 10\n'
              '\n'
              '0123456789len 5\n'
              '\n'
              '1234\n'
              'len 20\n'
              'foo 123\n'
              '\n'
              '0123456789\n'
              '012345678len 0\n'
              'foo 5\n'
              '\n' +
              # A 70-character line: longer than LineTester.MAX_LENGTH (64),
              # so LineTester.lineLengthExceeded keeps only the part after
              # the first 65 characters ('67890').
              '1234567890' * 7 + '\n' +
              'len 1\n'
              '\n'
              'a')

    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def testBuffer(self):
        """Every packetisation of ``buffer`` must yield ``output``."""
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = LineTester()
            a.makeConnection(protocol.FileWrapper(t))
            # Stepped range: no reliance on Python 2 integer '/', and no
            # trailing empty packet.
            for start in range(0, len(self.buffer), packet_size):
                a.dataReceived(self.buffer[start:start + packet_size])
            self.failUnlessEqual(self.output, a.received)
class LineOnlyReceiverTestCase(unittest.TestCase):
    """Tests for basic.LineOnlyReceiver, driven through LineOnlyTester."""

    buffer = ('foo\n'
              'bleakness\n'
              'desolation\n'
              'plastic forks\n')

    def _connected(self):
        """Return a LineOnlyTester attached to an in-memory transport."""
        receiver = LineOnlyTester()
        receiver.makeConnection(
            protocol.FileWrapper(StringIOWithoutClosing()))
        return receiver

    def testBuffer(self):
        # Deliver one character at a time; every complete line must appear.
        receiver = self._connected()
        for char in self.buffer:
            receiver.dataReceived(char)
        expected = self.buffer.split('\n')[:-1]
        self.failUnlessEqual(receiver.received, expected)

    def testLineTooLong(self):
        # A 200-byte line exceeds MAX_LENGTH (64).  The check relies on
        # dataReceived returning a non-None value in that case.
        receiver = self._connected()
        result = receiver.dataReceived('x' * 200)
        self.failIfEqual(result, None)
class TestMixin:
    """Mixin that records delivered strings and notes connection loss.

    Listed before a basic.*Receiver base class (see TestNetstring and
    TestInt32), so its MAX_LENGTH and callbacks are the ones looked up first.
    """

    # Small default limit; individual tests raise it when needed.
    MAX_LENGTH = 50
    # Becomes 1 once connectionLost() has been called.
    closed = 0

    def connectionMade(self):
        self.received = []

    def stringReceived(self, s):
        self.received.append(s)

    def connectionLost(self, reason):
        self.closed = 1
class TestNetstring(TestMixin, basic.NetstringReceiver):
    """NetstringReceiver that records strings via TestMixin (MAX_LENGTH 50)."""
class LPTestCaseMixin:
    """Shared tests for length-prefixed string receivers.

    Subclasses set ``protocol`` to the receiver class under test and
    ``illegal_strings`` to inputs that must make it drop the connection.
    """

    illegal_strings = []
    # This class attribute shadows the twisted ``protocol`` module only in
    # the class body; the methods below still resolve the module globally.
    protocol = None

    def getProtocol(self):
        """Return a new ``self.protocol`` instance connected to a
        StringIOWithoutClosing-backed FileWrapper transport."""
        sink = StringIOWithoutClosing()
        proto = self.protocol()
        proto.makeConnection(protocol.FileWrapper(sink))
        return proto

    def testIllegal(self):
        """Byte-at-a-time delivery of each illegal string must close the
        transport."""
        for bad in self.illegal_strings:
            proto = self.getProtocol()
            for byte in bad:
                proto.dataReceived(byte)
            self.assertEquals(proto.transport.closed, 1)
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """Round-trip and illegal-input tests for basic.NetstringReceiver."""

    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]
    illegal_strings = ['9999999999999999999999', 'abc', '4:abcde',
                       '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',]
    protocol = TestNetstring

    def testBuffer(self):
        """Strings sent with sendString must be parsed back unchanged,
        whatever the packet size."""
        for packet_size in range(1, 10):
            sink = StringIOWithoutClosing()
            receiver = TestNetstring()
            # The 515-byte entry in ``strings`` exceeds TestMixin's limit.
            receiver.MAX_LENGTH = 699
            receiver.makeConnection(protocol.FileWrapper(sink))
            for s in self.strings:
                receiver.sendString(s)
            encoded = sink.getvalue()
            # Replay the encoded stream back into the same receiver.
            for start in range(0, len(encoded), packet_size):
                receiver.dataReceived(encoded[start:start + packet_size])
            self.assertEquals(receiver.received, self.strings)
class TestInt32(TestMixin, basic.Int32StringReceiver):
    """Int32StringReceiver that records strings via TestMixin."""

    # Same as TestMixin's value; repeated here to make the limit explicit.
    MAX_LENGTH = 50
class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
    """Tests for basic.Int32StringReceiver (signed 32-bit network-order
    length prefix, as packed with struct format "!i")."""

    protocol = TestInt32
    strings = ["a", "b" * 16]
    # Prefix 0x10000000 is far above TestInt32.MAX_LENGTH (50).
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
    # Inputs that never complete a message: a truncated prefix, a prefix
    # ("hell" == 0x68656c6c) announcing far more data than follows, and
    # no data at all.
    partial_strings = ["\x00\x00\x00", "hello there", ""]

    def testPartial(self):
        """Incomplete input must deliver nothing and keep the connection
        open."""
        for s in self.partial_strings:
            r = self.getProtocol()
            # Largest value a signed 32-bit prefix can carry, so no prefix
            # can trip the length limit and only the "wait for more data"
            # path is exercised.  The previous 99999999 was smaller than the
            # 0x68656c6c announced by "hello there", so that case took the
            # length-limit (connection-dropping) path instead, as in
            # testIllegal.
            r.MAX_LENGTH = 0x7fffffff
            for c in s:
                r.dataReceived(c)
            self.assertEquals(r.received, [])
            self.assertEquals(r.transport.closed, 0)

    def testReceive(self):
        """Prefix-framed strings delivered byte by byte arrive intact."""
        r = self.getProtocol()
        for s in self.strings:
            for c in struct.pack("!i", len(s)) + s:
                r.dataReceived(c)
        self.assertEquals(r.received, self.strings)
# syntax highlighted by Code2HTML, v. 0.9.1