# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
Test cases for twisted.protocols.smtp module.
"""
from twisted.trial import unittest
import twisted.internet.protocol
import twisted.protocols.smtp
from twisted import protocols
from twisted import internet
from twisted.protocols import loopback
from twisted.protocols import smtp
from twisted.internet import defer, protocol, reactor, interfaces
from twisted.test.test_protocols import StringIOWithoutClosing
from twisted.python import components
from twisted import cred
import twisted.cred.error
import twisted.cred.portal
import twisted.cred.checkers
import twisted.cred.credentials
try:
from ssl_helpers import ClientTLSContext, ServerTLSContext
except ImportError:
ClientTLSContext = ServerTLSContext = None
import re
try:
from cStringIO import StringIO
except ImportError:
import StringIO
def spameater(*spam, **eggs):
    """Swallow any positional and keyword arguments.

    Used in this module as a no-op replacement for a protocol's
    receivedHeader hook; the result is always None.
    """
    return
class DummyMessage:
    """In-memory message sink handed out by DummyDomain.startMessage.

    Lines of one incoming message are buffered. At end-of-message the text
    is appended to the owning domain's mailbox for the recipient's local
    part.
    """

    # The generated Received: header for a client that said HELO yyy.com is
    # thrown away so the stored text is predictable. Raw string: in a plain
    # literal '\(' is an invalid escape sequence (warns on newer Pythons).
    _receivedHeaderPattern = re.compile(
        r'Received: From yyy.com \(\[.*\]\) by localhost;')

    def __init__(self, domain, user):
        self.domain = domain
        self.user = user
        self.buffer = []

    def lineReceived(self, line):
        """Buffer one message line unless it is the generated Received: header."""
        if not self._receivedHeaderPattern.match(line):
            self.buffer.append(line)

    def eomReceived(self):
        """Store the joined message and return an already-fired Deferred."""
        message = '\n'.join(self.buffer) + '\n'
        self.domain.messages[self.user.dest.local].append(message)
        return defer.succeed("saved")
class DummyDomain:
    """Fake mail domain that only accepts a fixed set of local users."""

    def __init__(self, names):
        # local part -> list of delivered message texts
        self.messages = {}
        for name in names:
            self.messages[name] = []

    def exists(self, user):
        """Return a Deferred for a recipient check.

        The Deferred fires with a zero-argument message factory when
        user.dest.local is a known mailbox. Otherwise it fails with
        smtp.SMTPBadRcpt.
        """
        # 'in' instead of dict.has_key, which does not exist on Python 3.
        if user.dest.local in self.messages:
            return defer.succeed(lambda: self.startMessage(user))
        return defer.fail(smtp.SMTPBadRcpt(user))

    def startMessage(self, user):
        """Create the sink that will collect one message for user."""
        return DummyMessage(self, user)
class SMTPTestCase(unittest.TestCase):
    """Drive a DomainSMTP server by hand (no network) and check delivery."""

    # (sender, recipients, raw text sent after DATA)
    messages = [('foo@bar.com', ['foo@baz.com', 'qux@baz.com'],
                 'Subject: urgent\r\n\r\nSomeone set up us the bomb!\r\n')]

    # Expected mailboxes of 'baz.com' afterwards. Only 'foo' exists there
    # (see setUp), so the copy addressed to qux@baz.com is not stored.
    mbox = {'foo': ['Subject: urgent\n\nSomeone set up us the bomb!\n']}

    def setUp(self):
        self.factory = smtp.SMTPFactory()
        self.factory.domains = {}
        self.factory.domains['baz.com'] = DummyDomain(['foo'])
        self.output = StringIOWithoutClosing()
        self.transport = internet.protocol.FileWrapper(self.output)

    def testMessages(self):
        # Distinct local names so the module-level 'protocol' and
        # 'protocols' imports are not shadowed.
        from twisted.mail import protocols as mailProtocols
        server = mailProtocols.DomainSMTP()
        server.service = self.factory
        server.factory = self.factory
        server.receivedHeader = spameater
        server.makeConnection(self.transport)
        try:
            server.lineReceived('HELO yyy.com')
            for sender, recipients, body in self.messages:
                server.lineReceived('MAIL FROM:<%s>' % sender)
                for target in recipients:
                    server.lineReceived('RCPT TO:<%s>' % target)
                server.lineReceived('DATA')
                server.dataReceived(body)
                server.lineReceived('.')
            server.lineReceived('QUIT')
            self.assertEquals(self.factory.domains['baz.com'].messages,
                              self.mbox)
        finally:
            # Always clear the protocol's timeout, even when the assertion
            # fails, so no delayed call is left behind for later tests.
            server.setTimeout(None)
# Canned message sent by MyClient. A message needs an empty line between
# its header section and its body; without it "Goodbye" would be read as a
# (malformed) header line. SMTPClientTestCase.expected_output likewise
# expects 'Subject: hello', '', 'Goodbye'. Escapes are spelled out so the
# blank line cannot be lost to whitespace mangling.
mail = 'Subject: hello\n\nGoodbye\n'
class MyClient:
    """Mixin that feeds one canned message to an SMTP client protocol.

    self.mail holds (sender, recipient list, message text). sentMail resets
    every field to None once the client reports completion.
    """

    def __init__(self):
        sender = 'moshez@foo.bar'
        self.mail = (sender, [sender], mail)

    def getMailFrom(self):
        return self.mail[0]

    def getMailTo(self):
        return self.mail[1]

    def getMailData(self):
        # NOTE(review): if cStringIO is missing, the module-level ImportError
        # fallback binds the StringIO *module* rather than the class, so
        # this call would fail in that environment.
        return StringIO(self.mail[2])

    def sentMail(self, code, resp, numOk, addresses, log):
        self.mail = (None,) * 3
class MySMTPClient(MyClient, smtp.SMTPClient):
    """Plain SMTP client that greets as 'foo.baz' and sends MyClient's mail."""

    def __init__(self):
        # Protocol base first, then install the canned message.
        smtp.SMTPClient.__init__(self, 'foo.baz')
        MyClient.__init__(self)
class MyESMTPClient(MyClient, smtp.ESMTPClient):
    """ESMTP client that greets as 'foo.baz' and sends MyClient's mail.

    secret and contextFactory are forwarded unchanged to ESMTPClient.
    """

    def __init__(self, secret='', contextFactory=None):
        smtp.ESMTPClient.__init__(self, secret, contextFactory, 'foo.baz')
        MyClient.__init__(self)
class LoopbackMixin:
    """Mixin giving test cases a helper to connect two protocols."""

    def loopback(self, server, client):
        """Run server and client against each other via loopback.loopbackTCP."""
        loopback.loopbackTCP(server, client)
class LoopbackTestCase(LoopbackMixin):
    """Shared loopback test body; concrete subclasses supply clientClass."""

    def testMessages(self):
        serverFactory = smtp.SMTPFactory()
        serverFactory.domains = {'foo.bar': DummyDomain(['moshez'])}
        from twisted.mail.protocols import DomainSMTP
        server = DomainSMTP()
        server.service = server.factory = serverFactory
        self.loopback(server, self.clientClass())
        # NOTE(review): nothing is asserted about what was delivered; this
        # only checks that the exchange completes without error.
class LoopbackSMTPTestCase(LoopbackTestCase, unittest.TestCase):
    """Loopback delivery using the plain SMTP client."""
    clientClass = MySMTPClient
class LoopbackESMTPTestCase(LoopbackTestCase, unittest.TestCase):
    """Loopback delivery using the ESMTP client."""
    clientClass = MyESMTPClient
class FakeSMTPServer(protocols.basic.LineReceiver):
    """Scripted SMTP server that records every line the client sends.

    Replies from clientData are sent one per received line, in order.
    '.' gets a 250, QUIT gets a 221 and a disconnect, and RSET disconnects.
    """

    clientData = [
        '220 hello', '250 nice to meet you',
        '250 great', '250 great', '354 go on, lad'
    ]

    def connectionMade(self):
        self.buffer = []
        # Per-connection reversed copy, so pop() yields replies in order
        # without touching the class-level list.
        self.clientData = self.clientData[::-1]
        self.sendLine(self.clientData.pop())

    def lineReceived(self, line):
        self.buffer.append(line)
        if line == "QUIT":
            self.transport.write("221 see ya around\r\n")
            self.transport.loseConnection()
        elif line == ".":
            self.transport.write("250 gotcha\r\n")
        elif line == "RSET":
            self.transport.loseConnection()
        # Any remaining scripted reply is sent after each line, whatever it was.
        if self.clientData:
            self.sendLine(self.clientData.pop())
class SMTPClientTestCase(unittest.TestCase, LoopbackMixin):
    """Check the exact line sequence MySMTPClient sends to a scripted server."""

    expected_output = [
        'HELO foo.baz', 'MAIL FROM:<moshez@foo.bar>',
        'RCPT TO:<moshez@foo.bar>', 'DATA',
        'Subject: hello', '', 'Goodbye', '.', 'RSET'
    ]

    def testMessages(self):
        # NOTE(review): an earlier comment called this test "disabled
        # temporarily", but nothing visible here skips it.
        client = MySMTPClient()
        server = FakeSMTPServer()
        self.loopback(server, client)
        self.assertEquals(server.buffer, self.expected_output)
class DummySMTPMessage:
    """Message sink that records each finished message on its protocol.

    The record goes into protocol.message, keyed by the tuple of recipient
    address strings. Its value is (helo, origin, recipients, text).
    """

    def __init__(self, protocol, users):
        self.protocol = protocol
        self.users = users
        self.buffer = []

    def lineReceived(self, line):
        self.buffer.append(line)

    def eomReceived(self):
        """Store the message record and return an already-fired Deferred."""
        message = '\n'.join(self.buffer) + '\n'
        # HELO and origin are taken from the first recipient's User object.
        first = self.users[0]
        helo, origin = first.helo[0], str(first.orig)
        recipients = [str(user) for user in self.users]
        self.protocol.message[tuple(recipients)] = (
            helo, origin, recipients, message)
        return defer.succeed("saved")
class DummyProto:
    """Mixin for SMTP/ESMTP servers that accepts everyone and keeps messages.

    Concrete classes set dummyMixinBase to the real server class; its
    connectionMade runs first, then self.message is reset to an empty dict.
    """

    def connectionMade(self):
        self.dummyMixinBase.connectionMade(self)
        self.message = {}

    def startMessage(self, users):
        return DummySMTPMessage(self, users)

    def receivedHeader(*ignored):
        # Takes any arguments (self included) and always yields None.
        return None

    def validateTo(self, user):
        self.delivery = DummyDelivery()

        def messageFactory():
            return self.startMessage([user])
        return messageFactory

    def validateFrom(self, helo, origin):
        return origin
class DummySMTP(DummyProto, smtp.SMTP):
    """Plain-SMTP flavour of the recording test server."""
    dummyMixinBase = smtp.SMTP
class DummyESMTP(DummyProto, smtp.ESMTP):
    """ESMTP flavour of the recording test server."""
    dummyMixinBase = smtp.ESMTP
class AnotherTestCase:
    """Feed raw SMTP commands to a server protocol and check its replies.

    Subclasses set serverClass (a DummyProto-based server) and clientClass.
    Each entry of data is a tuple of:
      (text to send, regex the reply must match,
       message body to send after a 354 reply, or None,
       (regex for the post-body reply,
        expected (helo, sender, recipients, body) record), or None).
    """
    serverClass = None
    clientClass = None

    # (helo, MAIL FROM argument, RCPT TO arguments,
    #  expected sender, expected recipients, message body)
    messages = [
        ('foo.com', 'moshez@foo.com', ['moshez@bar.com'],
         'moshez@foo.com', ['moshez@bar.com'],
         'From: Moshe\nTo: Moshe\nHi,\nhow are you?\n'),
        ('foo.com', 'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'],
         'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'],
         'Subject: pass\n..rrrr..\n'),
        ('foo.com', '@this,@is,@ignored:foo@bar.com',
         ['@ignore,@this,@too:bar@foo.com'],
         'foo@bar.com', ['bar@foo.com'],
         'Subject: apa\nTo: foo\n123\n.\n456\n'),
    ]

    def _buildData(messages):
        # Class-body helper (deleted below). Building the dialogue inside a
        # function keeps loop variables such as 'msg' out of the class
        # namespace.
        data = [
            ('', '220.*\r\n$', None, None),
            ('HELO foo.com\r\n', '250.*\r\n$', None, None),
            ('RSET\r\n', '250.*\r\n$', None, None),
        ]
        for helo_, from_, to_, realfrom, realto, msg in messages:
            data.append(('MAIL FROM:<%s>\r\n' % from_, '250.*\r\n',
                         None, None))
            for rcpt in to_:
                data.append(('RCPT TO:<%s>\r\n' % rcpt, '250.*\r\n',
                             None, None))
            data.append(('DATA\r\n', '354.*\r\n',
                         msg, ('250.*\r\n',
                               (helo_, realfrom, realto, msg))))
        return data

    data = _buildData(messages)
    del _buildData

    def testBuffer(self):
        output = StringIOWithoutClosing()
        a = self.serverClass()

        class fooFactory:
            domain = 'foo.com'

        a.factory = fooFactory()
        a.makeConnection(protocol.FileWrapper(output))
        try:
            for (send, expect, msg, msgexpect) in self.data:
                if send:
                    a.dataReceived(send)
                data = output.getvalue()
                output.truncate(0)
                if not re.match(expect, data):
                    raise AssertionError((send, expect, data))
                if data[:3] == '354':
                    # Send the body with SMTP dot-stuffing, then the
                    # terminating '.' line.
                    for line in msg.splitlines():
                        if line.startswith('.'):
                            line = '.' + line
                        a.dataReceived(line + '\r\n')
                    a.dataReceived('.\r\n')
                    # Special case for DATA: expect a 250, then compare the
                    # stored message records.
                    data = output.getvalue()
                    # truncate(0), not truncate(): with no argument the
                    # buffer is cut at the current position (the end after
                    # writing), so nothing would be cleared and this reply
                    # would leak into the next check.
                    output.truncate(0)
                    resp, msgdata = msgexpect
                    if not re.match(resp, data):
                        raise AssertionError((resp, data))
                    # validateTo starts one message per recipient, so each
                    # recipient has its own record.
                    for recip in msgdata[2]:
                        expected = list(msgdata)
                        expected[2] = [recip]
                        self.assertEquals(a.message[(recip,)],
                                          tuple(expected))
        finally:
            # Clear the protocol's timeout even if a check above failed.
            a.setTimeout(None)
class AnotherESMTPTestCase(AnotherTestCase, unittest.TestCase):
    """Scripted-dialogue test against the ESMTP recording server."""
    serverClass = DummyESMTP
    clientClass = MyESMTPClient
class AnotherSMTPTestCase(AnotherTestCase, unittest.TestCase):
    """Scripted-dialogue test against the plain-SMTP recording server."""
    serverClass = DummySMTP
    clientClass = MySMTPClient
# XXX - These need to be moved
from twisted.protocols import imap4
class DummyChecker:
    """Credentials checker with a single hard-coded account."""

    __implements__ = (cred.checkers.ICredentialsChecker,)

    users = {
        'testuser': 'testpassword'
    }

    credentialInterfaces = (cred.credentials.IUsernameHashedPassword,)

    def requestAvatarId(self, credentials):
        """Return a Deferred for the avatar id of credentials.username.

        The Deferred fires with the username when the password checks out.
        It fails with cred.error.UnauthorizedLogin for a wrong password or
        an unknown user.
        """
        try:
            password = self.users[credentials.username]
        except KeyError:
            # Report an unknown user through the Deferred like a bad
            # password, instead of letting KeyError escape synchronously.
            return defer.fail(cred.error.UnauthorizedLogin())
        return defer.maybeDeferred(
            credentials.checkPassword, password
        ).addCallback(self._cbCheck, credentials.username)

    def _cbCheck(self, result, username):
        if result:
            return username
        raise cred.error.UnauthorizedLogin()
class DummyDelivery:
    """IMessageDelivery stub that hands every sender and recipient back."""

    __implements__ = (smtp.IMessageDelivery,)

    def validateTo(self, user):
        return user

    def validateFrom(self, helo, origin):
        return origin

    def receivedHeader(*ignored):
        # Takes any arguments (self included) and always yields None.
        return None
class DummyRealm:
    """Realm giving every avatar a fresh DummyDelivery and a no-op logout."""

    def requestAvatar(self, avatarId, mind, *interfaces):
        def logout():
            return None
        return smtp.IMessageDelivery, DummyDelivery(), logout
class AuthTestCase(unittest.TestCase, LoopbackMixin):
    """CRAM-MD5 authentication between MyESMTPClient and DummyESMTP."""

    def testAuth(self):
        portal = cred.portal.Portal(DummyRealm())
        portal.registerChecker(DummyChecker())

        server = DummyESMTP({'CRAM-MD5': cred.credentials.CramMD5Credentials})
        server.portal = portal

        client = MyESMTPClient('testpassword')
        client.registerAuthenticator(
            imap4.CramMD5ClientAuthenticator('testuser'))

        self.loopback(server, client)
        self.assertEquals(server.authenticated, 1)
class SMTPHelperTestCase(unittest.TestCase):
    """Unit tests for small helpers in twisted.protocols.smtp."""

    def testMessageID(self):
        # 1000 consecutive ids must all be distinct.
        seen = {}
        for _ in range(1000):
            msgid = smtp.messageid('testcase')
            self.failIf(msgid in seen)
            seen[msgid] = None

    def testQuoteAddr(self):
        for address, quoted in [
            ['user@host.name', '<user@host.name>'],
            ['"User Name" <user@host.name>', '<user@host.name>'],
            [smtp.Address('someguy@someplace'), '<someguy@someplace>'],
        ]:
            self.assertEquals(smtp.quoteaddr(address), quoted)

    def testUser(self):
        user = smtp.User('user@host', 'helo.host.name', None, None)
        self.assertEquals(str(user), 'user@host')

    def testXtextEncoding(self):
        # Relies on an 'xtext' codec being registered (presumably by
        # twisted.protocols.smtp).
        pairs = [
            ('Hello world', 'Hello+20world'),
            ('Hello+world', 'Hello+2Bworld'),
            ('\0\1\2\3\4\5', '+00+01+02+03+04+05'),
            ('e=mc2@example.com', 'e+3Dmc2@example.com')
        ]
        for plain, encoded in pairs:
            self.assertEquals(plain.encode('xtext'), encoded)
            self.assertEquals(encoded.decode('xtext'), plain)
class NoticeTLSClient(MyESMTPClient):
    """ESMTP test client that remembers whether its STARTTLS handler ran."""

    tls = False

    def esmtpState_starttls(self, code, resp):
        MyESMTPClient.esmtpState_starttls(self, code, resp)
        # Set only after the base handler returns without raising.
        self.tls = True
class TLSTestCase(unittest.TestCase, LoopbackMixin):
    """Both ends of an ESMTP loopback connection must switch to TLS."""

    def testTLS(self):
        clientContext, serverContext = ClientTLSContext(), ServerTLSContext()
        client = NoticeTLSClient(contextFactory=clientContext)
        server = DummyESMTP(contextFactory=serverContext)
        self.loopback(server, client)
        self.assertEquals(client.tls, True)
        self.assertEquals(server.startedTLS, True)
# Skip the TLS test when it cannot run. ClientTLSContext is None when the
# ssl_helpers import failed. If both checks trigger, the reactor message
# (assigned last) is the one reported.
if ClientTLSContext is None:
    TLSTestCase.skip = "OpenSSL not present"
if not components.implements(reactor, interfaces.IReactorSSL):
    TLSTestCase.skip = "Reactor doesn't support SSL"
class EmptyLineTestCase(unittest.TestCase):
    """An empty command line must get a 500 syntax-error reply."""

    def testEmptyLineSyntaxError(self):
        server = smtp.SMTP()
        output = StringIOWithoutClosing()
        server.makeConnection(internet.protocol.FileWrapper(output))
        server.lineReceived('')
        server.setTimeout(None)

        # Expect exactly the 220 greeting followed by the error reply.
        replies = output.getvalue().splitlines()
        self.assertEquals(len(replies), 2)
        self.failUnless(replies[0].startswith('220'))
        self.assertEquals(replies[1], "500 Error: bad syntax")
# syntax highlighted by Code2HTML, v. 0.9.1