import time
import struct
import crypt
from twisted.python import context
from twisted.python import components
# Record Layer Content Types
RL_CT_CHANGE_CIPHER_SPEC = 20
RL_CT_ALERT = 21
RL_CT_HANDSHAKE = 22
RL_CT_APPLICATION_DATA = 23
def twos(s):
return zip(*[iter(s)] * 2)
class IEncodable(components.Interface):
def encode(self):
"""Encode this record fully for transmission.
"""
class Random:
__implements__ = (IEncodable,)
def __init__(self, tstamp=None, rbytes=None):
self.time = tstamp or int(time.time())
self.bytes = rbytes or crypt.getRandomBytes(28)
def encode(self):
return struct.pack('>I', self.time) + self.bytes
def NullCipherMethod(record):
return record
def NullCompressionMethod(record):
return record
class SecurityParameters(object):
macAlgorithm = staticmethod(crypt.HMAC_NULL)
bulkEncryptionAlgorithm = staticmethod(NullCipherMethod)
compressionAlgorithm = staticmethod(NullCompressionMethod)
masterSecret = None
clientRandom = None
serverRandom = None
def getCipherSuites(self):
return [0]
def getCompressionMethods(self):
return [0]
# Handshake Content Types
HS_CT_HELLO_REQUEST = 0
HS_CT_CLIENT_HELLO = 1
HS_CT_SERVER_HELLO = 2
HS_CT_CERTIFICATE = 11
HS_CT_SERVER_KEY_EXCHANGE = 12
HS_CT_CERTIFICATE_REQUEST = 13
HS_CT_SERVER_HELLO_DONE = 14
HS_CT_CERTIFICATE_VERIFY = 15
HS_CT_CLIENT_KEY_EXCHANGE = 16
HS_CT_FINISHED = 20
def logbytes(name, bytes):
print name, len(bytes), ':', ' '.join(map(''.join, zip(*[iter(bytes.encode('hex'))] * 2)))
class RecordLayer(object):
MAX_FRAGMENT_SIZE = 2 ** 14
def __init__(self, record):
self.record = record
def encode(self):
S = self.MAX_FRAGMENT_SIZE
packets = []
bytes = self.record.encode()
hdr = struct.pack('>BBB', self.record.type, *self.record.version)
while bytes:
toenc, bytes = bytes[:S], bytes[S:]
packets.append(hdr + struct.pack('>H', len(toenc)) + toenc)
logbytes('Record Layer', ''.join(packets))
return ''.join(packets)
class Plaintext(object):
def __init__(self, record):
self.type = record.type
self.version = record.version
self.record = record
def encode(self):
r = self.record.encode()
logbytes('Plaintext', r)
return r
class Handshake(object):
type = RL_CT_HANDSHAKE
version = (3, 1)
handshakeType = None
def encode(self):
body = self.handshake_encode()
assert len(body) < 2 ** 24
high = len(body) >> 8
low = len(body) & 0xff
r = struct.pack('>BHB', self.handshakeType, high, low) + body
logbytes('Handshake', r)
return r
class ClientHello(Handshake):
version = (3, 1)
handshakeType = HS_CT_CLIENT_HELLO
def __init__(self, version, random, session_id, ciphers, compressors):
self.version = version
self.random = random
self.sessionID = session_id
self.ciphers = ciphers
self.compressors = compressors
def handshake_encode(self):
ver = ''.join(map(chr, self.version))
rand = self.random.encode()
sess = chr(len(self.sessionID)) + self.sessionID
ciph = ''.join([struct.pack('>H', c) for c in self.ciphers])
ciph = struct.pack('>H', len(ciph)) + ciph
comp = chr(len(self.compressors)) + ''.join(map(chr, self.compressors))
r = ver + rand + sess + ciph + comp
logbytes('ClientHello', r)
return r
class ServerHello(ClientHello):
handshakeType = HS_CT_SERVER_HELLO
class Certificate(Handshake):
handshakeType = HS_CT_CERTIFICATE
def handshake_encode(self):
pass
import sys
sys.path.append('../../pahan/statefulprotocol')
from stateful import StatefulProtocol
class RecordProtocol(StatefulProtocol):
currentReadSecurity = None
currentWriteSecurity = None
pendingReadSecurity = None
pendingWriteSecurity = None
def connectionMade(self):
for cp in ('pending', 'current'):
for rw in ('Read', 'Write'):
setattr(self, cp + rw + 'Security', SecurityParameters())
def _write(self, record):
comp = self.currentWriteSecurity.compressionAlgorithm
ciph = self.currentWriteSecurity.bulkEncryptionAlgorithm
bytes = RecordLayer(ciph(comp(record))).encode()
logbytes('Sending', bytes)
self.transport.write(bytes)
def send(self, record):
ctx = {'SecurityParameters': self.currentWriteSecurity}
rec = Plaintext(record)
context.call(ctx, self._write, rec)
class TLSClient(RecordProtocol):
buffer = ''
CONTENT_TYPE_MAP = {chr(20): 'ChangeCipherSpec',
chr(21): 'Alert',
chr(22): 'Handshake',
chr(23): 'ApplicationData'}
def getInitialState(self):
m = self.state_RecordType
return m, m.byteCount
def dataReceived(self, data):
print 'Received', repr(data)
StatefulProtocol.dataReceived(self, data)
def connectionMade(self):
RecordProtocol.connectionMade(self)
sp = getattr(self, cp + rw + 'Security')
cipherSuites = sp.getCipherSuites()
compMethods = sp.getCompressionMethods()
self.send(ClientHello((3, 1), Random(), '', cipherSuites, compMethods))
def state_RecordType(self, data):
method = self.CONTENT_TYPE_MAP[data]
method = getattr(self, 'rt_' + method)
return method, method.byteCount
state_RecordType.byteCount = 1
def rt_ChangeCipherSpec(self, data):
self.changeCipherSpec()
return self.state_RecordType, self.state_RecordType.byteCount
rt_ChangeCipherSpec.byteCount = 1
def rt_Alert(self, data):
print 'Alert'
def rt_Handshake(self, data):
print 'Handshake'
def rt_ApplicationData(self, data):
print 'ApplicationData'
class TLSServerProtocol(RecordProtocol):
CONTENT_TYPE_MAP = {chr(20): 'ChangeCipherSpec',
chr(21): 'Alert',
chr(22): 'Handshake',
chr(23): 'ApplicationData'}
HANDSHAKE_TYPE_MAP = {chr(0): 'HelloRequest',
chr(1): 'ClientHello',
chr(2): 'ServerHello',
chr(11): 'Certificate',
chr(12): 'ServerKeyExchange',
chr(13): 'CertificateRequest',
chr(14): 'ServerHelloDone',
chr(15): 'CertificateVerify',
chr(16): 'ClientKeyExchange',
chr(20): 'HandshakeFinished'}
def dataReceived(self, bytes):
logbytes("Server Received", bytes)
StatefulProtocol.dataReceived(self, bytes)
def getInitialState(self):
m = self.state_RecordTypeAndVersionAndLength
return m, m.byteCount
def state_RecordTypeAndVersionAndLength(self, data):
logbytes("RTAV", data)
recordType = data[0]
self.recordVersion = map(ord, data[1:3])
recordLength = struct.unpack('>H', data[3:])[0]
print 'Version is', self.recordVersion
m = getattr(self, "rt_" + self.CONTENT_TYPE_MAP[recordType])
print 'Next state is', m, recordLength
return m, m.byteCount
state_RecordTypeAndVersionAndLength.byteCount = 5
def rt_Handshake(self, data):
logbytes("Handshake", data)
# Determine the type of handshake record this is
m = getattr(self, 'hs_' + self.HANDSHAKE_TYPE_MAP[data[0]])
bytes = struct.unpack('>I', '\0' + data[1:4])[0]
print 'Next state is', m, bytes
return m, bytes
rt_Handshake.byteCount = 4
def rt_Alert(self, data):
print 'Alert!', data
def hs_HelloRequest(self, data):
logbytes("HelloRequest", data)
print 'whaaaat'
pass
def hs_ClientHello(self, data):
logbytes("ClientHello", data)
fmt = '>BBI28spH'
L = struct.calcsize(fmt)
front, data = data[:L], data[L:]
cv1, cv2, time, random, sessionID, nCiphs = struct.unpack(fmt, front)
logbytes("Ciphers", data)
ciphs, data = data[:nCiphs], data[nCiphs:]
ciphers = [ord(a) << 8 | ord(b) for (a, b) in twos(ciphs)]
nComps = ord(data[0])
compressors = map(ord, data[1:nComps+1])
ch = ClientHello((3, 1), Random(time, random), sessionID, ciphers, compressors)
self.handshakeMessage(ch)
m = self.state_RecordTypeAndVersionAndLength
return m, m.byteCount
class TLSServer(TLSServerProtocol):
sessions = {}
def generateSessionID(self):
return 'sessionID'
def handshakeMessage(self, msg):
f = self.HANDSHAKE_TYPE_MAP[chr(msg.handshakeType)]
return getattr(self, 'handshake_' + f)(msg)
def handshake_ClientHello(self, msg):
if msg.sessionID in self.sessions:
return self.resumeSession(msg)
# XXX Check timestamp
self.pendingWriteSecurity.clientRandom = msg.random.bytes
self.pendingReadSecurity.clientRandom = msg.random.bytes
random = crypt.getRandomBytes(28)
self.pendingWriteSecurity.serverRandom = random
self.pendingReadSecurity.serverRandom = random
sp = self.currentWriteSecurity
sessionID = self.generateSessionID()
cs = sp.getCipherSuites()
cm = sp.getCompressionMethods()
h = ServerHello((3, 1), Random(rbytes=random), sessionID, cs, cm)
self.send(h)
c = Certificate()
self.send(c)
if __name__ == '__main__':
from twisted.internet import ssl
from twisted.internet import reactor
from twisted.internet import protocol
from twisted.python import log
import sys
log.startLogging(sys.stdout)
class ClientContextFactory(ssl.ClientContextFactory):
method = ssl.SSL.TLSv1_METHOD
class HexBytePrintingProtocol(protocol.Protocol):
def dataReceived(self, bytes):
logbytes("Received", bytes)
OpenSSLServerFactory = protocol.ServerFactory()
OpenSSLServerFactory.protocol = protocol.Protocol
OpenSSLClientFactory = protocol.ClientFactory()
OpenSSLClientFactory.protocol = protocol.Protocol
PlainServerFactory = protocol.ServerFactory()
PlainServerFactory.protocol = HexBytePrintingProtocol
PlainClientFactory = protocol.ClientFactory()
PlainClientFactory.protocol = HexBytePrintingProtocol
PythonTLSServerFactory = protocol.ServerFactory()
PythonTLSServerFactory.protocol = TLSServer
PythonTLSClientFactory = protocol.ClientFactory()
PythonTLSClientFactory.protocol = TLSClient
pem = '/home/exarkun/projects/python/Twisted/twisted/test/server.pem'
OpenSSLPort = reactor.listenSSL(0, OpenSSLServerFactory, ssl.DefaultOpenSSLContextFactory(pem, pem), interface='127.0.0.1')
PlainPort = reactor.listenTCP(0, PlainServerFactory)
PythonTLSPort = reactor.listenTCP(0, PythonTLSServerFactory)
OpenSSLConn = reactor.connectSSL('127.0.0.1', PythonTLSPort.getHost()[2], OpenSSLClientFactory, ClientContextFactory())
# PythonTLSConn = reactor.connectTCP('127.0.0.1', OpenSSLPort.getHost()[2], PythonTLSClientFactory)
reactor.callLater(1, reactor.stop)
reactor.run()
syntax highlighted by Code2HTML, v. 0.9.1