# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""
Test cases for twisted.protocols.smtp module.
"""

from twisted.trial import unittest
import twisted.internet.protocol
import twisted.protocols.smtp

from twisted import protocols
from twisted import internet
from twisted.protocols import loopback
from twisted.protocols import smtp
from twisted.internet import defer, protocol, reactor, interfaces
from twisted.test.test_protocols import StringIOWithoutClosing
from twisted.python import components

from twisted import cred
import twisted.cred.error
import twisted.cred.portal
import twisted.cred.checkers
import twisted.cred.credentials

try:
    from ssl_helpers import ClientTLSContext, ServerTLSContext
except ImportError:
    ClientTLSContext = ServerTLSContext = None

import re

try:
    from cStringIO import StringIO
except ImportError:
    # Import the class, not the module: MyClient.getMailData calls
    # StringIO(text), which would fail on the bare module object.
    from StringIO import StringIO


def spameater(*spam, **eggs):
    """Accept and ignore any arguments; stands in for receivedHeader."""
    return None


class DummyMessage:
    """Message sink that files each completed message into the owning
    DummyDomain's mailbox list for the recipient's local part."""

    def __init__(self, domain, user):
        self.domain = domain
        self.user = user
        self.buffer = []

    def lineReceived(self, line):
        # Throw away the generated Received: header
        if not re.match(r'Received: From yyy.com \(\[.*\]\) by localhost;', line):
            self.buffer.append(line)

    def eomReceived(self):
        """Store the buffered lines as one newline-joined message."""
        message = '\n'.join(self.buffer) + '\n'
        self.domain.messages[self.user.dest.local].append(message)
        return defer.succeed("saved")


class DummyDomain:
    """Fake mail domain that accepts only the local parts given in *names*."""

    def __init__(self, names):
        self.messages = {}
        for name in names:
            self.messages[name] = []

    def exists(self, user):
        """Return a Deferred firing with a message factory for known users,
        or failing with SMTPBadRcpt otherwise."""
        if user.dest.local in self.messages:
            return defer.succeed(lambda: self.startMessage(user))
        return defer.fail(smtp.SMTPBadRcpt(user))

    def startMessage(self, user):
        return DummyMessage(self, user)


class SMTPTestCase(unittest.TestCase):
    """Drive a DomainSMTP server line by line and check the delivered mbox."""

    messages = [('foo@bar.com', ['foo@baz.com', 'qux@baz.com'], '''\
Subject: urgent\015
\015
Someone set up us the bomb!\015
''')]

    # Only 'foo' exists in baz.com, so qux@baz.com is rejected.
    mbox = {'foo': ['Subject: urgent\n\nSomeone set up us the bomb!\n']}

    def setUp(self):
        self.factory = smtp.SMTPFactory()
        self.factory.domains = {}
        self.factory.domains['baz.com'] = DummyDomain(['foo'])
        self.output = StringIOWithoutClosing()
        self.transport = internet.protocol.FileWrapper(self.output)

    def testMessages(self):
        from twisted.mail.protocols import DomainSMTP
        proto = DomainSMTP()
        proto.service = self.factory
        proto.factory = self.factory
        proto.receivedHeader = spameater
        proto.makeConnection(self.transport)
        try:
            proto.lineReceived('HELO yyy.com')
            for message in self.messages:
                proto.lineReceived('MAIL FROM:<%s>' % message[0])
                for target in message[1]:
                    proto.lineReceived('RCPT TO:<%s>' % target)
                proto.lineReceived('DATA')
                proto.dataReceived(message[2])
                proto.lineReceived('.')
            proto.lineReceived('QUIT')
            self.assertEquals(self.factory.domains['baz.com'].messages,
                              self.mbox)
        finally:
            # Cancel the idle timeout even if an assertion above failed.
            proto.setTimeout(None)


mail = '''\
Subject: hello

Goodbye
'''


class MyClient:
    """Client mixin that sends the module-level `mail` exactly once."""

    def __init__(self):
        self.mail = 'moshez@foo.bar', ['moshez@foo.bar'], mail

    def getMailFrom(self):
        return self.mail[0]

    def getMailTo(self):
        return self.mail[1]

    def getMailData(self):
        return StringIO(self.mail[2])

    def sentMail(self, code, resp, numOk, addresses, log):
        # Clear the queued mail so it is not sent again.
        self.mail = None, None, None


class MySMTPClient(MyClient, smtp.SMTPClient):
    def __init__(self):
        smtp.SMTPClient.__init__(self, 'foo.baz')
        MyClient.__init__(self)


class MyESMTPClient(MyClient, smtp.ESMTPClient):
    def __init__(self, secret = '', contextFactory = None):
        smtp.ESMTPClient.__init__(self, secret, contextFactory, 'foo.baz')
        MyClient.__init__(self)


class LoopbackMixin:
    """Connect a server and client protocol over a loopback TCP pair."""

    def loopback(self, server, client):
        loopback.loopbackTCP(server, client)


class LoopbackTestCase(LoopbackMixin):
    """Run a full client/server exchange; subclasses set clientClass."""

    def testMessages(self):
        factory = smtp.SMTPFactory()
        factory.domains = {}
        factory.domains['foo.bar'] = DummyDomain(['moshez'])
        from twisted.mail.protocols import DomainSMTP
        proto = DomainSMTP()
        proto.service = factory
        proto.factory = factory
        clientProtocol = self.clientClass()
        self.loopback(proto, clientProtocol)


class LoopbackSMTPTestCase(LoopbackTestCase, unittest.TestCase):
    clientClass = MySMTPClient


class LoopbackESMTPTestCase(LoopbackTestCase, unittest.TestCase):
    clientClass = MyESMTPClient


class FakeSMTPServer(protocols.basic.LineReceiver):
    """Scripted server: replies from clientData in order and records
    every line the client sends in self.buffer."""

    clientData = [
        '220 hello', '250 nice to meet you',
        '250 great', '250 great', '354 go on, lad'
    ]

    def connectionMade(self):
        self.buffer = []
        # Copy before reversing so the class-level script is not mutated.
        self.clientData = self.clientData[:]
        self.clientData.reverse()
        self.sendLine(self.clientData.pop())

    def lineReceived(self, line):
        self.buffer.append(line)
        if line == "QUIT":
            self.transport.write("221 see ya around\r\n")
            self.transport.loseConnection()
        elif line == ".":
            self.transport.write("250 gotcha\r\n")
        elif line == "RSET":
            self.transport.loseConnection()

        if self.clientData:
            self.sendLine(self.clientData.pop())


class SMTPClientTestCase(unittest.TestCase, LoopbackMixin):

    expected_output = [
        'HELO foo.baz', 'MAIL FROM:<moshez@foo.bar>',
        'RCPT TO:<moshez@foo.bar>', 'DATA',
        'Subject: hello', '', 'Goodbye', '.', 'RSET'
    ]

    def testMessages(self):
        # NOTE(review): an older comment said this test was "disabled
        # temporarily", but nothing here disables it; it runs normally.
        client = MySMTPClient()
        server = FakeSMTPServer()
        self.loopback(server, client)
        self.assertEquals(server.buffer, self.expected_output)


class DummySMTPMessage:
    """Collects message lines and records the finished message on the
    owning protocol's `message` dict, keyed by the recipient tuple."""

    def __init__(self, protocol, users):
        self.protocol = protocol
        self.users = users
        self.buffer = []

    def lineReceived(self, line):
        self.buffer.append(line)

    def eomReceived(self):
        message = '\n'.join(self.buffer) + '\n'
        helo, origin = self.users[0].helo[0], str(self.users[0].orig)
        recipients = [str(user) for user in self.users]
        self.protocol.message[tuple(recipients)] = (helo, origin, recipients, message)
        return defer.succeed("saved")


class DummyProto:
    """Mixin making an SMTP/ESMTP server accept every sender and recipient
    and store messages via DummySMTPMessage; subclasses set dummyMixinBase."""

    def connectionMade(self):
        self.dummyMixinBase.connectionMade(self)
        self.message = {}

    def startMessage(self, users):
        return DummySMTPMessage(self, users)

    def receivedHeader(*spam):
        return None

    def validateTo(self, user):
        self.delivery = DummyDelivery()
        return lambda: self.startMessage([user])

    def validateFrom(self, helo, origin):
        return origin


class DummySMTP(DummyProto, smtp.SMTP):
    dummyMixinBase = smtp.SMTP


class DummyESMTP(DummyProto, smtp.ESMTP):
    dummyMixinBase = smtp.ESMTP


class AnotherTestCase:
    """Feed raw protocol bytes to serverClass and match each reply
    against a regex; subclasses set serverClass/clientClass."""

    serverClass = None
    clientClass = None

    # (helo, MAIL FROM, RCPT TOs, expected origin, expected recipients, body)
    messages = [ ('foo.com', 'moshez@foo.com', ['moshez@bar.com'],
                  'moshez@foo.com', ['moshez@bar.com'], '''\
From: Moshe
To: Moshe

Hi,
how are you?
'''),
                 ('foo.com', 'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'],
                  'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'], '''\
Subject: pass

..rrrr..
'''),
                 ('foo.com', '@this,@is,@ignored:foo@bar.com',
                  ['@ignore,@this,@too:bar@foo.com'],
                  'foo@bar.com', ['bar@foo.com'], '''\
Subject: apa
To: foo

123
.
456
'''),
              ]

    # (bytes to send, reply regex, message body or None,
    #  (post-body reply regex, expected stored message) or None)
    data = [
        ('', '220.*\r\n$', None, None),
        ('HELO foo.com\r\n', '250.*\r\n$', None, None),
        ('RSET\r\n', '250.*\r\n$', None, None),
        ]
    for helo_, from_, to_, realfrom, realto, msg in messages:
        data.append(('MAIL FROM:<%s>\r\n' % from_, '250.*\r\n',
                     None, None))
        for rcpt in to_:
            data.append(('RCPT TO:<%s>\r\n' % rcpt, '250.*\r\n',
                         None, None))

        data.append(('DATA\r\n','354.*\r\n',
                     msg, ('250.*\r\n',
                           (helo_, realfrom, realto, msg))))

    def testBuffer(self):
        output = StringIOWithoutClosing()
        a = self.serverClass()
        class fooFactory:
            domain = 'foo.com'

        a.factory = fooFactory()
        a.makeConnection(protocol.FileWrapper(output))
        try:
            for (send, expect, msg, msgexpect) in self.data:
                if send:
                    a.dataReceived(send)
                data = output.getvalue()
                output.truncate(0)
                if not re.match(expect, data):
                    raise AssertionError((send, expect, data))
                if data[:3] == '354':
                    for line in msg.splitlines():
                        # Dot-stuff lines starting with '.' (RFC 2821 4.5.2).
                        if line and line[0] == '.':
                            line = '.' + line
                        a.dataReceived(line + '\r\n')
                    a.dataReceived('.\r\n')
                    # Special case for DATA. Now we want a 250, and then
                    # we compare the messages
                    data = output.getvalue()
                    # truncate(0), not truncate(): with no argument StringIO
                    # truncates at the current (end) position, which would
                    # leave this reply in the buffer for the next step.
                    output.truncate(0)
                    resp, msgdata = msgexpect
                    if not re.match(resp, data):
                        raise AssertionError((resp, data))
                    for recip in msgdata[2]:
                        expected = list(msgdata[:])
                        expected[2] = [recip]
                        self.assertEquals(
                            a.message[(recip,)],
                            tuple(expected)
                        )
        finally:
            # Cancel the idle timeout even if an assertion above failed.
            a.setTimeout(None)


class AnotherESMTPTestCase(AnotherTestCase, unittest.TestCase):
    serverClass = DummyESMTP
    clientClass = MyESMTPClient


class AnotherSMTPTestCase(AnotherTestCase, unittest.TestCase):
    serverClass = DummySMTP
    clientClass = MySMTPClient


# XXX - These need to be moved
from twisted.protocols import imap4


class DummyChecker:
    """Credentials checker with a single hard-coded hashed-password user."""

    __implements__ = (cred.checkers.ICredentialsChecker,)

    users = {
        'testuser': 'testpassword'
    }

    credentialInterfaces = (cred.credentials.IUsernameHashedPassword,)

    def requestAvatarId(self, credentials):
        """Return a Deferred firing with the username on success, or
        failing with UnauthorizedLogin for bad or unknown credentials."""
        password = self.users.get(credentials.username)
        if password is None:
            # Unknown user: fail through the Deferred rather than letting
            # a KeyError escape synchronously to the portal.
            return defer.fail(cred.error.UnauthorizedLogin())
        return defer.maybeDeferred(
            credentials.checkPassword, password
        ).addCallback(self._cbCheck, credentials.username)

    def _cbCheck(self, result, username):
        if result:
            return username
        raise cred.error.UnauthorizedLogin()


class DummyDelivery:
    """IMessageDelivery that accepts every sender and recipient."""

    __implements__ = (smtp.IMessageDelivery,)

    def validateTo(self, user):
        return user

    def validateFrom(self, helo, origin):
        return origin

    def receivedHeader(*args):
        return None


class DummyRealm:
    """Realm handing every avatar a DummyDelivery with a no-op logout."""

    def requestAvatar(self, avatarId, mind, *interfaces):
        return smtp.IMessageDelivery, DummyDelivery(), lambda: None


class AuthTestCase(unittest.TestCase, LoopbackMixin):
    def testAuth(self):
        realm = DummyRealm()
        p = cred.portal.Portal(realm)
        p.registerChecker(DummyChecker())

        server = DummyESMTP({'CRAM-MD5': cred.credentials.CramMD5Credentials})
        server.portal = p
        client = MyESMTPClient('testpassword')

        cAuth = imap4.CramMD5ClientAuthenticator('testuser')
        client.registerAuthenticator(cAuth)

        self.loopback(server, client)

        self.assertEquals(server.authenticated, 1)


class SMTPHelperTestCase(unittest.TestCase):
    def testMessageID(self):
        d = {}
        for i in range(1000):
            m = smtp.messageid('testcase')
            self.failIf(m in d)
            d[m] = None

    def testQuoteAddr(self):
        cases = [
            ['user@host.name', '<user@host.name>'],
            ['"User Name" <user@host.name>', '<user@host.name>'],
            [smtp.Address('someguy@someplace'), '<someguy@someplace>'],
        ]

        for (c, e) in cases:
            self.assertEquals(smtp.quoteaddr(c), e)

    def testUser(self):
        u = smtp.User('user@host', 'helo.host.name', None, None)
        self.assertEquals(str(u), 'user@host')

    def testXtextEncoding(self):
        cases = [
            ('Hello world', 'Hello+20world'),
            ('Hello+world', 'Hello+2Bworld'),
            ('\0\1\2\3\4\5', '+00+01+02+03+04+05'),
            ('e=mc2@example.com', 'e+3Dmc2@example.com')
        ]

        for (case, expected) in cases:
            self.assertEquals(case.encode('xtext'), expected)
            self.assertEquals(expected.decode('xtext'), case)


class NoticeTLSClient(MyESMTPClient):
    """ESMTP client that records whether the STARTTLS step ran."""

    tls = False

    def esmtpState_starttls(self, code, resp):
        MyESMTPClient.esmtpState_starttls(self, code, resp)
        self.tls = True


class TLSTestCase(unittest.TestCase, LoopbackMixin):
    def testTLS(self):
        clientCTX = ClientTLSContext()
        serverCTX = ServerTLSContext()

        client = NoticeTLSClient(contextFactory=clientCTX)
        server = DummyESMTP(contextFactory=serverCTX)

        self.loopback(server, client)

        self.assertEquals(client.tls, True)
        self.assertEquals(server.startedTLS, True)


# Report the most fundamental reason for skipping; the second check
# previously overwrote the first message.
if ClientTLSContext is None:
    for case in (TLSTestCase,):
        case.skip = "OpenSSL not present"
elif not components.implements(reactor, interfaces.IReactorSSL):
    for case in (TLSTestCase,):
        case.skip = "Reactor doesn't support SSL"


class EmptyLineTestCase(unittest.TestCase):
    def testEmptyLineSyntaxError(self):
        proto = smtp.SMTP()
        output = StringIOWithoutClosing()
        transport = internet.protocol.FileWrapper(output)
        proto.makeConnection(transport)
        proto.lineReceived('')
        proto.setTimeout(None)

        out = output.getvalue().splitlines()
        self.assertEquals(len(out), 2)
        self.failUnless(out[0].startswith('220'))
        self.assertEquals(out[1], "500 Error: bad syntax")