# -*- test-case-name: twisted.test.test_tcp -*- # Twisted, the Framework of Your Internet # Copyright (C) 2001 Matthew W. Lefkowitz # # This library is free software; you can redistribute it and/or # modify it under the terms of version 2.1 of the GNU Lesser General Public # License as published by the Free Software Foundation. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA """Various asynchronous TCP/IP classes. End users shouldn't use this module directly - use the reactor APIs instead. Maintainer: U{Itamar Shtull-Trauring} """ # System Imports import os import stat import types import exceptions import socket import sys import select import operator import warnings try: import fcntl except ImportError: fcntl = None try: from OpenSSL import SSL except ImportError: SSL = None if os.name == 'nt': # we hardcode these since windows actually wants e.g. # WSAEALREADY rather than EALREADY. Possibly we should # just be doing "from errno import WSAEALREADY as EALREADY". EPERM = 10001 EINVAL = 10022 EWOULDBLOCK = 10035 EINPROGRESS = 10036 EALREADY = 10037 ECONNRESET = 10054 EISCONN = 10056 ENOTCONN = 10057 EINTR = 10004 elif os.name != 'java': from errno import EPERM from errno import EINVAL from errno import EWOULDBLOCK from errno import EINPROGRESS from errno import EALREADY from errno import ECONNRESET from errno import EISCONN from errno import ENOTCONN from errno import EINTR from errno import EAGAIN # Twisted Imports from twisted.internet import protocol, defer, base from twisted.persisted import styles from twisted.python import log, failure, reflect from twisted.python.runtime import platform, platformType from twisted.internet.error import CannotListenError # Sibling Imports import abstract import main import interfaces import error class _TLSMixin: writeBlockedOnRead = 0 readBlockedOnWrite = 0 sslShutdown = 0 def getPeerCertificate(self): return self.socket.get_peer_certificate() def doRead(self): if self.writeBlockedOnRead: self.writeBlockedOnRead = 0 self.startWriting() try: return Connection.doRead(self) except SSL.ZeroReturnError: # close SSL layer, since other side has done so, if we haven't if not self.sslShutdown: try: self.socket.shutdown() self.sslShutdown = 1 except SSL.Error: pass print 'losting conn1' return main.CONNECTION_DONE except SSL.WantReadError: return except SSL.WantWriteError: self.readBlockedOnWrite = 1 self.startWriting() return except SSL.Error: print 'losting conn2' log.err() return main.CONNECTION_LOST def loseConnection(self): Connection.loseConnection(self) if self.connected: self.startReading() def doWrite(self): if self.writeBlockedOnRead: self.stopWriting() return if self.readBlockedOnWrite: self.readBlockedOnWrite = 0 # XXX - This is touching internal guts bad bad bad if not self.dataBuffer: self.stopWriting() return self.doRead() return Connection.doWrite(self) def writeSomeData(self, data): try: return _BufferFlushBase.writeSomeData(self, data) except SSL.WantWriteError: return 0 except SSL.WantReadError: self.writeBlockedOnRead = 1 return 0 except SSL.SysCallError, e: if e[0] == -1 and data == "": # errors when writing empty strings are expected # and can be ignored return 0 else: print 'losting conn3' return main.CONNECTION_LOST except SSL.Error: print 'losting conn4' log.err() return main.CONNECTION_LOST def write(self, data): """Reliably write some data. If there is no buffered data this tries to write this data immediately, otherwise this adds data to be written the next time this file descriptor is ready for writing. """ assert isinstance(data, str), "Data must be a string." if not self.connected or not data: return if (not self.dataBuffer) and (self.producer is None): l = self.writeSomeData(data) if l == len(data): # all data was sent, our work here is done return elif not isinstance(l, Exception) and l > 0: # some data was sent self.dataBuffer = data self.offset = l else: # either no data was sent, or we were disconnected. # if we were disconnected we still continue, so that # the event loop can figure it out later on. self.dataBuffer = data else: self.dataBuffer = self.dataBuffer + data if self.producer is not None: if len(self.dataBuffer) > self.bufferSize: self.producerPaused = 1 self.producer.pauseProducing() self.startWriting() def writeSequence(self, iovec): self.write("".join(iovec)) def _closeSocket(self): try: self.socket.sock_shutdown(2) except: pass try: self.socket.close() except: pass def _postLoseConnection(self): """Gets called after loseConnection(), after buffered data is sent. We close the SSL transport layer, and if the other side hasn't closed it yet we start reading, waiting for a ZeroReturnError which will indicate the SSL shutdown has completed. """ try: done = self.socket.shutdown() self.sslShutdown = 1 except SSL.Error: log.err() return main.CONNECTION_LOST if done: return main.CONNECTION_DONE else: # we wait for other side to close SSL connection - # this will be signaled by SSL.ZeroReturnError when reading # from the socket self.stopWriting() self.startReading() # don't close socket just yet return None class _BufferFlushBase(abstract.FileDescriptor): def writeSomeData(self, data): """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST This writes as much data as possible to the socket and returns either the number of bytes read (which is positive) or a connection error code (which is negative) """ try: return self.socket.send(data) except socket.error, se: if se.args[0] == EINTR: return self.writeSomeData(data) elif se.args[0] == EWOULDBLOCK: return 0 else: return main.CONNECTION_LOST def _flattenForSSL(self): pass class _IOVecFlushBase(abstract.FileDescriptor): def _flattenForSSL(self): self.dataBuffer = ''.join(self.vector) self.offset = 0 del self.vector def writeVector(self, vector): written, errno = iovec.writev(self.fileno(), vector) if written == -1: if errno == EINTR: return self.writeVector(vector) elif errno == EWOULDBLOCK: return 0 else: log.msg("writev() failed with errno = %d" % (errno,)) return main.CONNECTION_LOST w = written i = 0 L = len(vector) while i < L and w >= len(vector[i]): w -= len(vector[i]) i += 1 del vector[:i] if w: vector[0] = vector[0][w:] return written try: from twisted.python import iovec except ImportError: _FlushBase = _BufferFlushBase else: _FlushBase = _IOVecFlushBase class Connection(_FlushBase): """I am the superclass of all socket-based FileDescriptors. This is an abstract superclass of all objects which represent a TCP/IP connection based socket. """ __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport TLS = 0 def __init__(self, skt, protocol, reactor=None): abstract.FileDescriptor.__init__(self, reactor=reactor) self.socket = skt self.socket.setblocking(0) self.fileno = skt.fileno self.protocol = protocol if SSL: __implements__ = __implements__ + (interfaces.ITLSTransport,) def startTLS(self, ctx): assert not self.TLS self.stopReading() self.stopWriting() self._startTLS() self.socket = SSL.Connection(ctx.getContext(), self.socket) self.fileno = self.socket.fileno self.startReading() def _startTLS(self): self.TLS = 1 class TLSConnection(_TLSMixin, _BufferFlushBase, self.__class__): pass self._flattenForSSL() self.__class__ = TLSConnection def doRead(self): """Calls self.protocol.dataReceived with all available data. This reads up to self.bufferSize bytes of data from its socket, then calls self.dataReceived(data) to process it. If the connection is not lost through an error in the physical recv(), this function will return the result of the dataReceived call. """ try: data = self.socket.recv(self.bufferSize) except socket.error, se: if se.args[0] == EWOULDBLOCK: return else: return main.CONNECTION_LOST except SSL.SysCallError, (retval, desc): # Yes, SSL might be None, but self.socket.recv() can *only* # raise socket.error, if anything else is raised, it must be an # SSL socket, and so SSL can't be None. (That's my story, I'm # stickin' to it) if retval == -1 and desc == 'Unexpected EOF': return main.CONNECTION_DONE raise if not data: return main.CONNECTION_DONE return self.protocol.dataReceived(data) def _closeSocket(self): """Called to close our socket.""" # This used to close() the socket, but that doesn't *really* close if # there's another reference to it in the TCP/IP stack, e.g. if it was # was inherited by a subprocess. And we really do want to close the # connection. So we use shutdown() instead. try: self.socket.shutdown(2) except socket.error: pass def connectionLost(self, reason): """See abstract.FileDescriptor.connectionLost(). """ abstract.FileDescriptor.connectionLost(self, reason) self._closeSocket() protocol = self.protocol del self.protocol del self.socket del self.fileno try: protocol.connectionLost(reason) except TypeError, e: # while this may break, it will only break on deprecated code # as opposed to other approaches that might've broken on # code that uses the new API (e.g. inspect). if e.args and e.args[0] == "connectionLost() takes exactly 1 argument (2 given)": warnings.warn("Protocol %s's connectionLost should accept a reason argument" % protocol, category=DeprecationWarning, stacklevel=2) protocol.connectionLost() else: raise logstr = "Uninitialized" def logPrefix(self): """Return the prefix to log with when I own the logging thread. """ return self.logstr def getTcpNoDelay(self): return operator.truth(self.socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) def setTcpNoDelay(self, enabled): self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled) def getTcpKeepAlive(self): return operator.truth(self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)) def setTcpKeepAlive(self, enabled): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, enabled) class BaseClient(Connection): """A base class for client TCP (and similiar) sockets. """ addressFamily = socket.AF_INET socketType = socket.SOCK_STREAM def _finishInit(self, whenDone, skt, error, reactor): """Called by base classes to continue to next stage of initialization.""" if whenDone: Connection.__init__(self, skt, None, reactor) self.doWrite = self.doConnect self.doRead = self.doConnect reactor.callLater(0, whenDone) else: reactor.callLater(0, self.failIfNotConnected, error) def startTLS(self, ctx, client=1): holder = Connection.startTLS(self, ctx) if client: self.socket.set_connect_state() else: self.socket.set_accept_state() return holder def stopConnecting(self): """Stop attempt to connect.""" self.failIfNotConnected(error.UserError()) def failIfNotConnected(self, err): if (self.connected or self.disconnected or not (hasattr(self, "connector"))): return self.connector.connectionFailed(failure.Failure(err)) if hasattr(self, "reactor"): # this doesn't happens if we failed in __init__ self.stopReading() self.stopWriting() del self.connector def createInternetSocket(self): """(internal) Create a non-blocking socket using self.addressFamily, self.socketType. """ s = socket.socket(self.addressFamily, self.socketType) s.setblocking(0) if fcntl and hasattr(fcntl, 'FD_CLOEXEC'): old = fcntl.fcntl(s.fileno(), fcntl.F_GETFD) fcntl.fcntl(s.fileno(), fcntl.F_SETFD, old | fcntl.FD_CLOEXEC) return s def resolveAddress(self): if abstract.isIPAddress(self.addr[0]): self._setRealAddress(self.addr[0]) else: d = self.reactor.resolve(self.addr[0]) d.addCallbacks(self._setRealAddress, self.failIfNotConnected) def _setRealAddress(self, address): self.realAddress = (address, self.addr[1]) self.doConnect() def doConnect(self): """I connect the socket. Then, call the protocol's makeConnection, and start waiting for data. """ if not hasattr(self, "connector"): # this happens when connection failed but doConnect # was scheduled via a callLater in self._finishInit return # on windows failed connects are reported on exception # list, not write or read list. if platformType == "win32": r, w, e = select.select([], [], [self.fileno()], 0.0) if e: err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) self.failIfNotConnected(error.getConnectError((err, os.strerror(err)))) return try: connectResult = self.socket.connect_ex(self.realAddress) except socket.error, se: connectResult = se.args[0] if connectResult: if connectResult == EISCONN: pass # on Windows EINVAL means sometimes that we should keep trying: # http://msdn.microsoft.com/library/default.asp?url=/library/en-us/winsock/winsock/connect_2.asp elif ((connectResult in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (connectResult == EINVAL and platformType == "win32")): self.startReading() self.startWriting() return else: self.failIfNotConnected(error.getConnectError((connectResult, os.strerror(connectResult)))) return # If I have reached this point without raising or returning, that means # that the socket is connected. del self.doWrite del self.doRead # we first stop and then start, to reset any references to the old doRead self.stopReading() self.stopWriting() self._connectDone() def _connectDone(self): self.protocol = self.connector.buildProtocol(self.getPeer()) self.connected = 1 self.protocol.makeConnection(self) self.logstr = self.protocol.__class__.__name__+",client" self.startReading() def connectionLost(self, reason): if not self.connected: self.failIfNotConnected(error.ConnectError(string=reason)) else: Connection.connectionLost(self, reason) self.connector.connectionLost(reason) class Client(BaseClient): """A TCP client.""" def __init__(self, host, port, bindAddress, connector, reactor=None): # BaseClient.__init__ is invoked later self.connector = connector self.addr = (host, port) whenDone = self.resolveAddress err = None skt = None try: skt = self.createInternetSocket() except socket.error, se: err = error.ConnectBindError(se[0], se[1]) whenDone = None if whenDone and bindAddress is not None: try: skt.bind(bindAddress) except socket.error, se: err = error.ConnectBindError(se[0], se[1]) whenDone = None self._finishInit(whenDone, skt, err, reactor) def getHost(self): """Returns a tuple of ('INET', hostname, port). This indicates the address from which I am connecting. """ return ('INET',)+self.socket.getsockname() def getPeer(self): """Returns a tuple of ('INET', hostname, port). This indicates the address that I am connected to. """ return ('INET',)+self.addr def __repr__(self): s = '<%s to %s at %x>' % (self.__class__, self.addr, id(self)) return s class Server(Connection): """Serverside socket-stream connection class. I am a serverside network connection transport; a socket which came from an accept() on a server. """ def __init__(self, sock, protocol, client, server, sessionno): """Server(sock, protocol, client, server, sessionno) Initialize me with a socket, a protocol, a descriptor for my peer (a tuple of host, port describing the other end of the connection), an instance of Port, and a session number. """ Connection.__init__(self, sock, protocol) self.server = server self.client = client self.sessionno = sessionno self.hostname = client[0] self.logstr = "%s,%s,%s" % (self.protocol.__class__.__name__, sessionno, self.hostname) self.repstr = "<%s #%s on %s>" % (self.protocol.__class__.__name__, self.sessionno, self.server.port) self.startReading() self.connected = 1 def __repr__(self): """A string representation of this connection. """ return self.repstr def startTLS(self, ctx, server=1): holder = Connection.startTLS(self, ctx) if server: self.socket.set_accept_state() else: self.socket.set_connect_state() return holder def getHost(self): """Returns a tuple of ('INET', hostname, port). This indicates the servers address. """ return ('INET',)+self.socket.getsockname() def getPeer(self): """ Returns a tuple of ('INET', hostname, port), indicating the connected client's address. """ return ('INET',)+self.client class Port(base.BasePort): """I am a TCP server port, listening for connections. When a connection is accepted, I will call my factory's buildProtocol with the incoming connection as an argument, according to the specification described in twisted.internet.interfaces.IProtocolFactory. If you wish to change the sort of transport that will be used, my `transport' attribute will be called with the signature expected for Server.__init__, so it can be replaced. """ addressFamily = socket.AF_INET socketType = socket.SOCK_STREAM transport = Server sessionno = 0 interface = '' backlog = 5 def __init__(self, port, factory, backlog=5, interface='', reactor=None): """Initialize with a numeric port to listen on. """ base.BasePort.__init__(self, reactor=reactor) self.port = port self.factory = factory self.backlog = backlog self.interface = interface def __repr__(self): return "<%s on %s>" % (self.factory.__class__, self.port) def createInternetSocket(self): s = base.BasePort.createInternetSocket(self) if platformType == "posix": s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s def startListening(self): """Create and bind my socket, and begin listening on it. This is called on unserialization, and must be called after creating a server to begin listening on the specified port. """ log.msg("%s starting on %s"%(self.factory.__class__, self.port)) try: skt = self.createInternetSocket() skt.bind((self.interface, self.port)) except socket.error, le: raise CannotListenError, (self.interface, self.port, le) self.factory.doStart() skt.listen(self.backlog) self.connected = 1 self.socket = skt self.fileno = self.socket.fileno self.numberAccepts = 100 self.startReading() def doRead(self): """Called when my socket is ready for reading. This accepts a connection and calls self.protocol() to handle the wire-level protocol. """ try: if platformType == "posix": numAccepts = self.numberAccepts else: # win32 event loop breaks if we do more than one accept() # in an iteration of the event loop. numAccepts = 1 for i in range(numAccepts): # we need this so we can deal with a factory's buildProtocol # calling our loseConnection if self.disconnecting: return try: skt, addr = self.socket.accept() except socket.error, e: if e.args[0] in (EWOULDBLOCK, EAGAIN): self.numberAccepts = i break elif e.args[0] == EPERM: continue raise protocol = self.factory.buildProtocol(addr) if protocol is None: skt.close() continue s = self.sessionno self.sessionno = s+1 transport = self.transport(skt, protocol, addr, self, s) transport = self._preMakeConnection(transport) protocol.makeConnection(transport) else: self.numberAccepts = self.numberAccepts+20 except: # Note that in TLS mode, this will possibly catch SSL.Errors # raised by self.socket.accept() # # There is no "except SSL.Error:" above because SSL may be # None if there is no SSL support. In any case, all the # "except SSL.Error:" suite would probably do is log.deferr() # and return, so handling it here works just as well. log.deferr() def _preMakeConnection(self, transport): return transport def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)): """Stop accepting connections on this port. This will shut down my socket and call self.connectionLost(). It returns a deferred which will fire successfully when the port is actually closed. """ self.disconnecting = 1 self.stopReading() if self.connected: self.deferred = defer.Deferred() self.reactor.callLater(0, self.connectionLost, connDone) return self.deferred stopListening = loseConnection def connectionLost(self, reason): """Cleans up my socket. """ log.msg('(Port %r Closed)' % self.port) base.BasePort.connectionLost(self, reason) self.connected = 0 self.socket.close() del self.socket del self.fileno self.factory.doStop() if hasattr(self, "deferred"): self.deferred.callback(None) del self.deferred def logPrefix(self): """Returns the name of my class, to prefix log entries with. """ return reflect.qual(self.factory.__class__) def getHost(self): """Returns a tuple of ('INET', hostname, port). This indicates the server's address. """ return ('INET',)+self.socket.getsockname() class Connector(base.BaseConnector): def __init__(self, host, port, factory, timeout, bindAddress, reactor=None): self.host = host if isinstance(port, types.StringTypes): try: port = socket.getservbyname(port, 'tcp') except socket.error, e: raise error.ServiceNameUnknownError(string=str(e)) self.port = port self.bindAddress = bindAddress base.BaseConnector.__init__(self, factory, timeout, reactor) def _makeTransport(self): return Client(self.host, self.port, self.bindAddress, self, self.reactor) def getDestination(self): return ('INET', self.host, self.port)