# -*- test-case-name: twisted.test.test_protocols -*-
# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""Basic protocols, such as line-oriented, netstring, and 32-bit-int prefixed strings.
API Stability: semi-stable.
Maintainer: U{Itamar Shtull-Trauring<mailto:twisted@itamarst.org>}
"""
# System imports
import re
import struct
# Twisted imports
from twisted.internet import protocol, defer, interfaces, error
from twisted.python import log
LENGTH, DATA, COMMA = range(3)
NUMBER = re.compile('(\d*)(:?)')
DEBUG = 0
class NetstringParseError(ValueError):
"""The incoming data is not in valid Netstring format."""
pass
class NetstringReceiver(protocol.Protocol):
"""This uses djb's Netstrings protocol to break up the input into strings.
Each string makes a callback to stringReceived, with a single
argument of that string.
Security features:
1. Messages are limited in size, useful if you don't want someone
sending you a 500MB netstring (change MAX_LENGTH to the maximum
length you wish to accept).
2. The connection is lost if an illegal message is received.
"""
MAX_LENGTH = 99999
brokenPeer = 0
_readerState = LENGTH
_readerLength = 0
def stringReceived(self, line):
"""
Override this.
"""
raise NotImplementedError
def doData(self):
buffer,self.__data = self.__data[:int(self._readerLength)],self.__data[int(self._readerLength):]
self._readerLength = self._readerLength - len(buffer)
self.__buffer = self.__buffer + buffer
if self._readerLength != 0:
return
self.stringReceived(self.__buffer)
self._readerState = COMMA
def doComma(self):
self._readerState = LENGTH
if self.__data[0] != ',':
if DEBUG:
raise NetstringParseError(repr(self.__data))
else:
raise NetstringParseError
self.__data = self.__data[1:]
def doLength(self):
m = NUMBER.match(self.__data)
if not m.end():
if DEBUG:
raise NetstringParseError(repr(self.__data))
else:
raise NetstringParseError
self.__data = self.__data[m.end():]
if m.group(1):
try:
self._readerLength = self._readerLength * (10**len(m.group(1))) + long(m.group(1))
except OverflowError:
raise NetstringParseError, "netstring too long"
if self._readerLength > self.MAX_LENGTH:
raise NetstringParseError, "netstring too long"
if m.group(2):
self.__buffer = ''
self._readerState = DATA
def dataReceived(self, data):
self.__data = data
try:
while self.__data:
if self._readerState == DATA:
self.doData()
elif self._readerState == COMMA:
self.doComma()
elif self._readerState == LENGTH:
self.doLength()
else:
raise RuntimeError, "mode is not DATA, COMMA or LENGTH"
except NetstringParseError:
self.transport.loseConnection()
self.brokenPeer = 1
def sendString(self, data):
self.transport.write('%d:%s,' % (len(data), data))
class SafeNetstringReceiver(NetstringReceiver):
"""This class is deprecated, use NetstringReceiver instead.
"""
class LineOnlyReceiver(protocol.Protocol):
"""A protocol that receives only lines.
This is purely a speed optimisation over LineReceiver, for the cases that raw mode is known to be unnecessary.
"""
_buffer = ''
delimiter = '\r\n'
MAX_LENGTH = 16384
def dataReceived(self, data):
"""Protocol.dataReceived.
Translates bytes into lines, and calls lineReceived.
"""
lines = (self._buffer+data).split(self.delimiter)
self._buffer = lines[-1]
for line in lines[:-1]:
if self.transport.disconnecting:
# this is necessary because the transport may be told to lose
# the connection by a line within a larger packet, and it is
# important to disregard all the lines in that packet following
# the one that told it to close.
return
if len(line) > self.MAX_LENGTH:
return self.lineLengthExceeded(line)
else:
self.lineReceived(line)
if len(self._buffer) > self.MAX_LENGTH:
return self.lineLengthExceeded(self._buffer)
def lineReceived(self, line):
"""Override this for when each line is received.
"""
raise NotImplementedError
def sendLine(self, line):
"""Sends a line to the other end of the connection.
"""
return self.transport.writeSequence((line,self.delimiter))
def lineLengthExceeded(self, line):
"""Called when the maximum line length has been reached.
Override if it needs to be dealt with in some special way.
"""
return error.ConnectionLost('Line length exceeded')
class LineReceiver(protocol.Protocol):
"""A protocol that receives lines and/or raw data, depending on mode.
In line mode, each line that's received becomes a callback to
L{lineReceived}. In raw data mode, each chunk of raw data becomes a
callback to L{rawDataReceived}. The L{setLineMode} and L{setRawMode}
methods switch between the two modes.
This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.
@cvar delimiter: The line-ending delimiter to use. By default this is
'\\r\\n'.
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
sent line is longer than this, the connection is dropped).
Default is 16384.
"""
line_mode = 1
__buffer = ''
delimiter = '\r\n'
MAX_LENGTH = 16384
def clearLineBuffer(self):
"""Clear buffered data."""
self.__buffer = ""
def dataReceived(self, data):
"""Protocol.dataReceived.
Translates bytes into lines, and calls lineReceived (or
rawDataReceived, depending on mode.)
"""
self.__buffer = self.__buffer+data
while self.line_mode:
try:
line, self.__buffer = self.__buffer.split(self.delimiter, 1)
except ValueError:
if len(self.__buffer) > self.MAX_LENGTH:
line, self.__buffer = self.__buffer, ''
return self.lineLengthExceeded(line)
break
else:
linelength = len(line)
if linelength > self.MAX_LENGTH:
line, self.__buffer = self.__buffer, ''
return self.lineLengthExceeded(line)
why = self.lineReceived(line)
if why or self.transport and self.transport.disconnecting:
return why
else:
data, self.__buffer = self.__buffer, ''
if data:
return self.rawDataReceived(data)
def setLineMode(self, extra=''):
"""Sets the line-mode of this receiver.
If you are calling this from a rawDataReceived callback,
you can pass in extra unhandled data, and that data will
be parsed for lines. Further data received will be sent
to lineReceived rather than rawDataReceived.
"""
self.line_mode = 1
return self.dataReceived(extra)
def setRawMode(self):
"""Sets the raw mode of this receiver.
Further data received will be sent to rawDataReceived rather
than lineReceived.
"""
self.line_mode = 0
def rawDataReceived(self, data):
"""Override this for when raw data is received.
"""
raise NotImplementedError
def lineReceived(self, line):
"""Override this for when each line is received.
"""
raise NotImplementedError
def sendLine(self, line):
"""Sends a line to the other end of the connection.
"""
return self.transport.write(line + self.delimiter)
def lineLengthExceeded(self, line):
"""Called when the maximum line length has been reached.
Override if it needs to be dealt with in some special way.
"""
return self.transport.loseConnection()
class Int32StringReceiver(protocol.Protocol):
"""A receiver for int32-prefixed strings.
An int32 string is a string prefixed by 4 bytes, the 32-bit length of
the string encoded in network byte order.
This class publishes the same interface as NetstringReceiver.
"""
MAX_LENGTH = 99999
recvd = ""
def stringReceived(self, msg):
"""Override this.
"""
raise NotImplementedError
def dataReceived(self, recd):
"""Convert int32 prefixed strings into calls to stringReceived.
"""
self.recvd = self.recvd + recd
while len(self.recvd) > 3:
length ,= struct.unpack("!i",self.recvd[:4])
if length > self.MAX_LENGTH:
self.transport.loseConnection()
return
if len(self.recvd) < length+4:
break
packet = self.recvd[4:length+4]
self.recvd = self.recvd[length+4:]
self.stringReceived(packet)
def sendString(self, data):
"""Send an int32-prefixed string to the other end of the connection.
"""
self.transport.write(struct.pack("!i",len(data))+data)
class Int16StringReceiver(protocol.Protocol):
"""A receiver for int16-prefixed strings.
An int16 string is a string prefixed by 2 bytes, the 16-bit length of
the string encoded in network byte order.
This class publishes the same interface as NetstringReceiver.
"""
recvd = ""
def stringReceived(self, msg):
"""Override this.
"""
raise NotImplementedError
def dataReceived(self, recd):
"""Convert int16 prefixed strings into calls to stringReceived.
"""
self.recvd = self.recvd + recd
while len(self.recvd) > 1:
length = (ord(self.recvd[0]) * 256) + ord(self.recvd[1])
if len(self.recvd) < length+2:
break
packet = self.recvd[2:length+2]
self.recvd = self.recvd[length+2:]
self.stringReceived(packet)
def sendString(self, data):
"""Send an int16-prefixed string to the other end of the connection.
"""
assert len(data) < 65536, "message too long"
self.transport.write(struct.pack("!h",len(data)) + data)
class StatefulStringProtocol:
"""A stateful string protocol.
This is a mixin for string protocols (Int32StringReceiver,
NetstringReceiver) which translates stringReceived into a callback
(prefixed with 'proto_') depending on state."""
state = 'init'
def stringReceived(self,string):
"""Choose a protocol phase function and call it.
Call back to the appropriate protocol phase; this begins with
the function proto_init and moves on to proto_* depending on
what each proto_* function returns. (For example, if
self.proto_init returns 'foo', then self.proto_foo will be the
next function called when a protocol message is received.
"""
try:
pto = 'proto_'+self.state
statehandler = getattr(self,pto)
except AttributeError:
log.msg('callback',self.state,'not found')
else:
self.state = statehandler(string)
if self.state == 'done':
self.transport.loseConnection()
class FileSender:
"""A producer that sends the contents of a file to a consumer.
This is a helper for protocols that, at some point, will take a
file-like object, read its contents, and write them out to the network,
optionally performing some transformation on the bytes in between.
This API is unstable.
"""
__implements__ = (interfaces.IProducer,)
CHUNK_SIZE = 2 ** 14
lastSent = ''
deferred = None
def beginFileTransfer(self, file, consumer, transform = None):
"""Begin transferring a file
@type file: Any file-like object
@param file: The file object to read data from
@type consumer: Any implementor of IConsumer
@param consumer: The object to write data to
@param transform: A callable taking one string argument and returning
the same. All bytes read from the file are passed through this before
being written to the consumer.
@rtype: C{Deferred}
@return: A deferred whose callback will be invoked when the file has been
completely written to the consumer. The last byte written to the consumer
is passed to the callback.
"""
self.file = file
self.consumer = consumer
self.transform = transform
self.consumer.registerProducer(self, False)
self.deferred = defer.Deferred()
return self.deferred
def resumeProducing(self):
chunk = ''
if self.file:
chunk = self.file.read(self.CHUNK_SIZE)
if not chunk:
self.file = None
self.consumer.unregisterProducer()
if self.deferred:
self.deferred.callback(self.lastSent)
self.deferred = None
return
if self.transform:
chunk = self.transform(chunk)
self.consumer.write(chunk)
self.lastSent = chunk[-1]
def pauseProducing(self):
pass
def stopProducing(self):
if self.deferred:
self.deferred.errback(Exception("Consumer asked us to stop producing"))
self.deferred = None
syntax highlighted by Code2HTML, v. 0.9.1