"""
XXX: I AM INCORRECT, BROKEN CODE
An IOCP-based event loop.
This requires win32all to be installed.
TODO:
1. Pass tests.
2. Switch everyone to a decent OS so we don't have to deal with insane APIs.
3. Process support, SSL, UDP.
"""
# Win32 imports
from win32file import WSAEventSelect, FD_READ, FD_WRITE, FD_CLOSE, \
FD_ACCEPT, FD_CONNECT
from win32event import CreateEvent, WaitForMultipleObjects, \
WAIT_OBJECT_0, WAIT_TIMEOUT, INFINITE
import win32api
import win32con
import win32event
import win32file
import win32pipe
import win32process
import win32security
import pywintypes
import msvcrt
# Twisted imports
from twisted.internet import default, abstract
from twisted.internet.interfaces import IReactorCore, IReactorTime, IReactorTCP
from twisted.python import log, threadable
from twisted.protocols import protocol
from twisted.persisted import styles
# System imports
import os
import threading
import Queue
import string
import time
import socket
import sys
import struct
# globals
# Registry of objects served by the I/O completion port.  Keys are the
# integer handle values used as completion keys; values are the wrapper
# objects whose methods are invoked when a completion for that key arrives.
files = {}
class Win32Reactor(default.PosixReactorBase):
    """Reactor built on a Win32 I/O completion port (IOCP).

    Strictly speaking this follows the Proactor pattern: overlapped
    operations are started up front and the event loop dispatches their
    completions, instead of waiting for handles to become ready.
    """

    __implements__ = IReactorCore, IReactorTime, IReactorTCP

    def __init__(self, handleSignals=1):
        """Initialise the base reactor and create the completion port."""
        default.PosixReactorBase.__init__(self, handleSignals)
        self.iocp = win32file.CreateIoCompletionPort(
            win32file.INVALID_HANDLE_VALUE, None, 0, 1)

    def installWaker(self):
        """Create the OVERLAPPED structure posted by wakeUp()."""
        self.wakeupOverlapped = pywintypes.OVERLAPPED()

    def wakeUp(self):
        """Wake up the event loop.

        Posts a packet with completion key 0, which doIteration() treats as
        "nothing to dispatch".  Does nothing when called from the I/O thread.
        """
        if not threadable.isInIOThread():
            win32file.PostQueuedCompletionStatus(
                self.iocp, 0, 0, self.wakeupOverlapped)

    def removeAll(self):
        """Return an empty list; this reactor tracks no removable readers."""
        return []

    def registerFile(self, file, wrapper):
        """Register an object that will be handled by the I/O completion port.

        file is the handle to associate with the port; int(file) becomes its
        completion key.  wrapper is the object whose methods are called when
        completions for that handle are dequeued.
        """
        print("registering %d for %s" % (int(file), wrapper))
        files[int(file)] = wrapper
        self.iocp = win32file.CreateIoCompletionPort(
            file, self.iocp, int(file), 1)

    def doIteration(self, timeout):
        """Wait for one completion packet and dispatch it.

        timeout is in seconds; None means wait for up to ten seconds.  The
        dequeued OVERLAPPED's ``object`` attribute names the method to call
        on the wrapper registered under the completion key.
        """
        if timeout is None:
            timeout = 10000
        else:
            timeout = int(1000 * timeout)
        rc, numBytes, key, overlapped = win32file.GetQueuedCompletionStatus(
            self.iocp, timeout)
        # A wait that times out yields no OVERLAPPED at all, so only read its
        # ``object`` attribute when one was actually dequeued.
        if overlapped is None:
            methodName = None
        else:
            methodName = overlapped.object
        print("GQCS %s %s %s %s %r"
              % (rc, numBytes, key, overlapped, methodName))
        if key == 0 or overlapped is None:
            # Key 0 is the wake-up packet (or a timeout): nothing to run.
            return
        handler = files[key]
        action = getattr(handler, methodName)
        # Deliberate catch-all: this is the event-loop boundary, so a failing
        # handler is logged and its connection torn down, never propagated.
        try:
            action()
        except:
            log.deferr()
            try:
                handler.connectionLost()
            except:
                log.deferr()

    # IReactorTCP

    def listenTCP(self, port, factory, backlog=5, interface=''):
        """See twisted.internet.interfaces.IReactorTCP.listenTCP
        """
        p = Port(port, factory, backlog, interface)
        p.startListening()
        return p

    def clientTCP(self, host, port, protocol, timeout=30):
        """Create and return a Client connecting protocol to host:port."""
        return Client(host, port, protocol, timeout)
def install():
    """Install a Win32Reactor as the reactor and start the thread dispatcher.

    Threading support is initialised first, then a new Win32Reactor is
    handed to main.installReactor(), and finally the threadtask dispatcher
    is started.
    """
    threadable.init(1)
    reactor = Win32Reactor()
    # NOTE(review): ``main`` and ``threadtask`` are imported by bare name;
    # presumably they are siblings of this module in its package -- confirm.
    import main
    main.installReactor(reactor)
    import threadtask
    dispatcher = threadtask.theDispatcher
    dispatcher.start()
class Connection(protocol.Transport, styles.Ephemeral):
    """A TCP connection for the Proactor pattern.

    Reads and writes are issued as overlapped operations on the socket's
    handle.  The reactor calls doRead() / finishedWriting() when the
    corresponding operation's completion is dequeued.
    """

    connected = 0
    producerPaused = 0
    streamingProducer = 0
    unsent = ""
    producer = None
    disconnected = 0
    disconnecting = 0
    # ``**`` is right-associative: 2**(2**(2**2)) == 2**16 == 65536 bytes.
    bufferSize = 2**2**2**2
    writing = 0

    def __init__(self, skt, protocol):
        """Wrap the connected socket skt and register it with the reactor.

        protocol is the object that receives dataReceived() and
        connectionLost() notifications.
        """
        self.socket = skt
        self.socket.setblocking(0)
        self.protocol = protocol
        # setup win32 objects
        self.winSocket = skt.fileno()
        self.outOverlapped = pywintypes.OVERLAPPED()
        self.outOverlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        self.inOverlapped = pywintypes.OVERLAPPED()
        self.inOverlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        self.readData = win32file.AllocateReadBuffer(self.bufferSize)
        from twisted.internet import reactor
        reactor.registerFile(self.winSocket, self)
        # The reactor looks these names up with getattr() when the matching
        # overlapped operation completes.
        self.outOverlapped.object = "finishedWriting"
        self.inOverlapped.object = "doRead"

    def registerProducer(self, producer, streaming):
        """Register to receive data from a producer.

        A non-streaming producer is asked to resumeProducing() immediately.
        Later, whenever the outgoing buffer drains, finishedWriting() asks a
        non-streaming (or paused streaming) producer for more data.
        """
        self.producer = producer
        self.streamingProducer = streaming
        if not streaming:
            producer.resumeProducing()

    def unregisterProducer(self):
        """Stop consuming data from a producer, without disconnecting.
        """
        self.producer = None

    def stopConsuming(self):
        """Stop consuming data.

        Drops the producer and then asks for the connection to be closed.
        """
        self.unregisterProducer()
        self.loseConnection()

    def connectionLost(self):
        """The connection was lost.

        Marks the connection as disconnected, stops any producer, shuts the
        socket down, releases the socket and protocol references and finally
        notifies the protocol.  Safe to call more than once: only the first
        call has any effect.
        """
        if self.disconnected:
            # Already torn down; the socket and protocol attributes are gone,
            # so a second pass would only raise AttributeError.
            return
        self.disconnected = 1
        self.connected = 0
        if self.producer is not None:
            self.producer.stopProducing()
            self.producer = None
        try:
            self.socket.shutdown(2)
        except socket.error:
            pass
        protocol = self.protocol
        del self.protocol
        del self.socket
        protocol.connectionLost()

    def write(self, data):
        """Queue data for sending and start a write if none is in progress.

        If a producer is registered and the backlog exceeds bufferSize, the
        producer is paused.
        """
        self.unsent = self.unsent + data
        if not self.writing:
            self.startWriting()
        if self.producer is not None:
            if len(self.unsent) > self.bufferSize:
                self.producerPaused = 1
                self.producer.pauseProducing()

    def loseConnection(self):
        """Close the connection, after any write in progress has finished."""
        if self.connected:
            if self.writing:
                # finishedWriting() closes once the buffer has drained.
                self.disconnecting = 1
            else:
                self.connectionLost()

    def startWriting(self):
        """Issue an overlapped write of up to bufferSize queued bytes."""
        self.writing = 1
        size = min(len(self.unsent), self.bufferSize)
        data, self.unsent = self.unsent[:size], self.unsent[size:]
        try:
            win32file.WriteFile(self.winSocket, data, self.outOverlapped)
        except win32api.error:
            self.connectionLost()

    def finishedWriting(self):
        """Handle completion of an overlapped write.

        Sends the next chunk if data is still queued.  Otherwise asks the
        producer for more, or closes the connection if that was requested,
        or simply marks the connection as no longer writing.
        """
        if self.disconnected:
            return
        if self.unsent:
            self.startWriting()
            return
        if self.producer is not None and ((not self.streamingProducer)
                                          or self.producerPaused):
            # tell them to supply some more.
            self.writing = 0
            self.producer.resumeProducing()
            self.producerPaused = 0
            return
        elif self.disconnecting:
            # But if I was previously asked to let the connection die, do
            # so.
            self.connectionLost()
            return
        self.writing = 0

    def startReading(self):
        """Issue an overlapped read into the preallocated read buffer.

        If the read call raises win32api.error, try to fetch the overlapped
        result and deliver whatever was read; a failure there is ignored.
        """
        try:
            result, readData = win32file.ReadFile(
                self.winSocket, self.readData, self.inOverlapped)
            assert self.readData is readData
        except win32api.error:
            try:
                length = win32file.GetOverlappedResult(
                    self.winSocket, self.inOverlapped, 0)
                self.protocol.dataReceived(self.readData[:length])
            except win32api.error:
                pass

    def doRead(self):
        """Handle completion of an overlapped read.

        Delivers the received bytes to the protocol and issues the next
        read.  A zero-length result is treated as the connection closing.
        """
        if self.disconnected:
            return
        length = win32file.GetOverlappedResult(
            self.winSocket, self.inOverlapped, 0)
        if length:
            self.protocol.dataReceived(self.readData[:length])
            self.startReading()
        else:
            self.connectionLost()
# this code is identical to tcp.Connection:
class Server(Connection):
    """Server-side end of an accepted TCP connection.

    Instances are created by Port.doRead() for each accepted socket (via
    Port's ``transport`` attribute); application code should not normally
    need to build one directly.
    """

    def __init__(self, sock, protocol, client, server, sessionno):
        """Server(sock, protocol, client, server, sessionno)

        sock is the accepted socket, protocol the protocol instance that
        will use it, client the peer address (a (host, port) tuple), server
        the Port that accepted the connection and sessionno its session
        number.
        """
        protocolName = protocol.__class__.__name__
        # repstr has to exist before Connection.__init__ runs: registering
        # with the reactor formats this object, which goes through __repr__.
        self.repstr = "<%s #%s on %s>" % (protocolName, sessionno, server.port)
        Connection.__init__(self, sock, protocol)
        self.server = server
        self.client = client
        self.sessionno = sessionno
        self.hostname = client[0]
        self.logstr = "%s,%s,%s" % (protocolName, sessionno, self.hostname)
        self.startReading()
        self.connected = 1

    def __repr__(self):
        """Return the representation string built in __init__.
        """
        return self.repstr

    def getHost(self):
        """Return ('INET', hostname, port) for the local end of the socket.
        """
        localAddress = self.socket.getsockname()
        return ('INET',) + localAddress

    def getPeer(self):
        """Return ('INET', hostname, port) for the connected client.
        """
        peerAddress = self.client
        return ('INET',) + peerAddress
class Port:
    """I am a TCP server port, listening for connections.

    When a connection is accepted, I will call my factory's buildProtocol with
    the peer address as an argument and wrap the accepted socket in a
    transport.

    If you wish to change the sort of transport that will be used, my
    `transport' attribute will be called with the signature expected for
    Server.__init__, so it can be replaced.
    """

    transport = Server
    sessionno = 0
    interface = ''
    backlog = 5
    # Default so that loseConnection() / doRead() before startListening()
    # see "not connected" instead of raising AttributeError.
    connected = 0

    def __init__(self, port, factory, backlog=5, interface='', reactor=None):
        """Initialize with a numeric port to listen on.

        factory builds a protocol for each accepted connection; backlog and
        interface are passed to listen() and bind().  When reactor is None
        the global twisted.internet.reactor is used.
        """
        self.port = port
        self.factory = factory
        self.backlog = backlog
        self.interface = interface
        if reactor is None:
            from twisted.internet import reactor
        self.reactor = reactor
        self.overlapped = pywintypes.OVERLAPPED()
        self.overlapped.hEvent = win32event.CreateEvent(None, 0, 0, None)
        # Buffer handed to AcceptEx and later to GetAcceptExSockaddrs.
        self.buffer = win32file.AllocateReadBuffer(64)

    def __repr__(self):
        return "<%s on %s>" % (self.factory.__class__, self.port)

    def createInternetSocket(self):
        """(internal) create an AF_INET socket.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        return s

    def __getstate__(self):
        """(internal) get my state for persistence

        Returns a shallow copy of my attributes without the 'socket' and
        'fileno' entries.
        """
        # Imported here because this module does not import ``copy`` at the
        # top level; without it this method raised NameError.
        import copy
        dct = copy.copy(self.__dict__)
        for transient in ('socket', 'fileno'):
            try:
                del dct[transient]
            except KeyError:
                pass
        return dct

    def startListening(self):
        """Create and bind my socket, and begin listening on it.

        Registers the listening socket with the reactor and issues the first
        overlapped accept.
        """
        log.msg("%s starting on %s" % (self.factory.__class__, self.port))
        skt = self.createInternetSocket()
        skt.setblocking(0)
        skt.bind((self.interface, self.port))
        skt.listen(self.backlog)
        self.connected = 1
        self.socket = skt
        winSocket = skt.fileno()
        # Name of the method the reactor calls when an accept completes.
        self.overlapped.object = "doRead"
        self.reactor.registerFile(winSocket, self)
        self.startReading()

    def startReading(self):
        """Issue an overlapped AcceptEx with a freshly created socket."""
        self.newSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.newSocket.setblocking(0)
        win32file.AcceptEx(self.socket, self.newSocket, self.buffer,
                           self.overlapped)

    def doRead(self):
        """Called when an overlapped accept has completed.

        Asks the factory to build a protocol for the peer address, wraps the
        accepted socket in a transport and connects the two; the socket is
        closed if the factory returns None.  Any error is logged, and the
        next accept is issued in every case.
        """
        if not self.connected:
            return
        try:
            skt = self.newSocket
            del self.newSocket
            # make new socket inherit properties from the port's socket
            skt.setsockopt(socket.SOL_SOCKET,
                           win32file.SO_UPDATE_ACCEPT_CONTEXT,
                           struct.pack("I", self.socket.fileno()))
            # get new socket's address
            family, localaddr, addr = win32file.GetAcceptExSockaddrs(
                self.socket, self.buffer)
            # build the new protocol
            protocol = self.factory.buildProtocol(addr)
            if protocol is None:
                skt.close()
            else:
                s = self.sessionno
                self.sessionno = s + 1
                transport = self.transport(skt, protocol, addr, self, s)
                protocol.makeConnection(transport, self)
        except:
            # Deliberate catch-all: one bad connection must not stop the
            # port from accepting further ones.
            log.deferr()
        self.startReading()

    def loseConnection(self):
        """Stop accepting connections on this port.

        Schedules connectionLost() with the reactor if I am listening.
        """
        self.disconnecting = 1
        if self.connected:
            self.reactor.callLater(0, self.connectionLost)

    def connectionLost(self):
        """Cleans up my socket and stops my factory.
        """
        log.msg('(Port %s Closed)' % self.port)
        self.disconnected = 1
        self.connected = 0
        self.socket.close()
        del self.socket
        self.factory.stopFactory()

    def logPrefix(self):
        """Returns the name of my factory's class, to prefix log entries with.
        """
        return str(self.factory.__class__)

    def getHost(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the servers address.
        """
        return ('INET',) + self.socket.getsockname()
class Client(Connection):
    """A client for TCP (and similar) sockets.

    The host is resolved if necessary, the blocking connect runs in a
    dispatcher thread, and the connection is only registered with the
    reactor once the connect has succeeded.
    """

    def __init__(self, host, port, protocol, timeout=None, connector=None, reactor=None):
        """Initialize the client, setting up its socket, and request to connect.

        timeout, if not None, is the number of seconds after which the
        attempt is reported as failed unless it has connected.  connector,
        if given, is told about failure and loss of the connection.  When
        reactor is None the global twisted.internet.reactor is used.
        """
        if reactor is None:
            from twisted.internet import reactor
        self.reactor = reactor
        self.socket = self.createInternetSocket()
        self.addr = (host, port)
        self.protocol = protocol
        self.host = host
        self.port = port
        self.connector = connector
        self.logstr = self.protocol.__class__.__name__ + ",client"
        if timeout is not None:
            self.reactor.callLater(timeout, self.failIfNotConnected)
        self.reactor.callLater(0, self.resolveAddress)

    def failIfNotConnected(self, *ignored):
        """Report a failed connection attempt, at most once.

        Does nothing if the connection was established or has already been
        reported as failed or lost.
        """
        if self.connected or self.disconnected:
            return
        # Record the failure before notifying anyone: the timeout timer, a
        # resolver errback and a failed connect can all end up here, and the
        # connector/protocol must hear about the failure only once.
        self.disconnected = 1
        if self.connector:
            self.connector.connectionFailed()
        self.protocol.connectionFailed()

    def createInternetSocket(self):
        """(internal) Create an AF_INET socket.
        """
        # factored out so as to minimise the code necessary for SecureClient
        return socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def resolveAddress(self):
        """Resolve the target host, unless it already is an IP address."""
        if abstract.isIPAddress(self.addr[0]):
            self._setRealAddress(self.addr[0])
        else:
            self.reactor.resolve(self.addr[0]
                                 ).addCallbacks(
                self._setRealAddress, self.failIfNotConnected
                ).arm()

    def _setRealAddress(self, address):
        """Store the resolved address and dispatch the connect to a thread."""
        self.realAddress = (address, self.addr[1])
        import threadtask
        threadtask.theDispatcher.dispatch(log.logOwner.owner(), self._connect)

    def _connect(self):
        """Runs in thread"""
        try:
            self.socket.connect(self.realAddress)
        except socket.error:
            self.reactor.callFromThread(self.failIfNotConnected)
        else:
            self.reactor.callFromThread(self._connected)

    def _connected(self):
        """Called when connection succeeded."""
        if self.disconnected:
            # The attempt was already reported as failed (for instance the
            # timeout fired while the connect thread was still blocked), so
            # drop the late socket instead of handing it to the protocol.
            self.socket.close()
            return
        self.connected = 1
        Connection.__init__(self, self.socket, self.protocol)
        self.protocol.makeConnection(self)
        self.startReading()

    def connectionLost(self):
        """Tear the connection down, or report failure if never connected."""
        if not self.connected:
            self.failIfNotConnected()
        else:
            Connection.connectionLost(self)
            if self.connector:
                self.connector.connectionLost()

    def getHost(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the address from which I am connecting.
        """
        return ('INET',) + self.socket.getsockname()

    def getPeer(self):
        """Returns a tuple of ('INET', hostname, port).

        This indicates the address that I am connected to.
        """
        return ('INET',) + self.addr

    def __repr__(self):
        s = '<%s to %s at %x>' % (self.__class__, self.addr, id(self))
        return s
# Names exported by ``from <this module> import *``.
__all__ = ["Win32Reactor", "install"]
# syntax highlighted by Code2HTML, v. 0.9.1