# Twisted, the Framework of Your Internet # Copyright (C) 2001-2002 Matthew W. Lefkowitz # # This library is free software; you can redistribute it and/or # modify it under the terms of version 2.1 of the GNU Lesser General Public # License as published by the Free Software Foundation. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA # """Tests for twisted.enterprise.""" from twisted.trial import unittest import os import stat import random import tempfile from twisted.enterprise.row import RowObject from twisted.enterprise.reflector import * from twisted.enterprise.xmlreflector import XMLReflector from twisted.enterprise.sqlreflector import SQLReflector from twisted.enterprise.adbapi import ConnectionPool from twisted.enterprise import util from twisted.internet import defer from twisted.trial.util import deferredResult, deferredError from twisted.python import log try: import gadfly except: gadfly = None try: import sqlite except: sqlite = None try: from pyPgSQL import PgSQL except: PgSQL = None try: import MySQLdb except: MySQLdb = None try: import psycopg except: psycopg = None try: import kinterbasdb except: kinterbasdb = None tableName = "testTable" childTableName = "childTable" class TestRow(RowObject): rowColumns = [("key_string", "varchar"), ("col2", "int"), ("another_column", "varchar"), ("Column4", "varchar"), ("column_5_", "int")] rowKeyColumns = [("key_string", "varchar")] rowTableName = tableName class ChildRow(RowObject): rowColumns = [("childId", "int"), ("foo", "varchar"), ("test_key", "varchar"), ("stuff", "varchar"), ("gogogo", "int"), ("data", "varchar")] rowKeyColumns = [("childId", "int")] rowTableName = childTableName rowForeignKeys = [(tableName, [("test_key","varchar")], [("key_string","varchar")], None, 1)] main_table_schema = """ CREATE TABLE testTable ( key_string varchar(64), col2 integer, another_column varchar(64), Column4 varchar(64), column_5_ integer ) """ child_table_schema = """ CREATE TABLE childTable ( childId integer, foo varchar(64), test_key varchar(64), stuff varchar(64), gogogo integer, data varchar(64) ) """ simple_table_schema = """ CREATE TABLE simple ( x integer ) """ def randomizeRow(row, nullsOK=1, trailingSpacesOK=1): values = {} for name, type in row.rowColumns: if util.getKeyColumn(row, name): values[name] = getattr(row, name) continue elif nullsOK and random.randint(0, 9) == 0: value = None # null elif type == 'int': value = random.randint(-10000, 10000) else: if random.randint(0, 9) == 0: value = '' else: value = ''.join(map(lambda i:chr(random.randrange(32,127)), xrange(random.randint(1, 64)))) if not trailingSpacesOK: value = value.rstrip() setattr(row, name, value) values[name] = value return values def rowMatches(row, values): for name, type in row.rowColumns: if getattr(row, name) != values[name]: print ("Mismatch on column %s: |%s| (row) |%s| (values)" % (name, getattr(row, name), values[name])) return return 1 class ReflectorTestCase: """Base class for testing reflectors. Subclass and implement createReflector for the style and db you want to test. This may involve creating a new database, starting a server, etc. If createReflector returns None, the test is skipped. This allows subclasses to test for the presence of the database libraries and silently skip the test if they are not present. Implement destroyReflector if your database needs to be shutdown afterwards. """ count = 100 # a parameter used for running iterative tests nullsOK = 1 # we can put nulls into the db trailingSpacesOK = 1 # we can put strings with trailing spaces into the db def randomizeRow(self, row): return randomizeRow(row, self.nullsOK, self.trailingSpacesOK) def setUp(self): self.reflector = self.createReflector() def tearDown(self): self.destroyReflector() def destroyReflector(self): pass def testReflector(self): # create one row to work with row = TestRow() row.assignKeyAttr("key_string", "first") values = self.randomizeRow(row) # save it deferredResult(self.reflector.insertRow(row)) # now load it back in whereClause = [("key_string", EQUAL, "first")] d = self.reflector.loadObjectsFrom(tableName, whereClause=whereClause) d.addCallback(self.gotData) deferredResult(d) # make sure it came back as what we saved self.failUnless(len(self.data) == 1, "no row") parent = self.data[0] self.failUnless(rowMatches(parent, values), "no match") # create some child rows child_values = {} for i in range(0, self.count): row = ChildRow() row.assignKeyAttr("childId", i) values = self.randomizeRow(row) values['test_key'] = row.test_key = "first" child_values[i] = values deferredResult(self.reflector.insertRow(row)) row = None d = self.reflector.loadObjectsFrom(childTableName, parentRow=parent) d.addCallback(self.gotData) deferredResult(d) self.failUnless(len(self.data) == self.count, "no rows on query") self.failUnless(len(parent.childRows) == self.count, "did not load child rows: %d" % len(parent.childRows)) for child in parent.childRows: self.failUnless(rowMatches(child, child_values[child.childId]), "child %d does not match" % child.childId) # loading these objects a second time should not re-add them # to the parentRow. d = self.reflector.loadObjectsFrom(childTableName, parentRow=parent) d.addCallback(self.gotData) deferredResult(d) self.failUnless(len(self.data) == self.count, "no rows on query") self.failUnless(len(parent.childRows) == self.count, "child rows added twice!: %d" % len(parent.childRows)) # now change the parent values = self.randomizeRow(parent) deferredResult(self.reflector.updateRow(parent)) parent = None # now load it back in whereClause = [("key_string", EQUAL, "first")] d = self.reflector.loadObjectsFrom(tableName, whereClause=whereClause) d.addCallback(self.gotData) deferredResult(d) # make sure it came back as what we saved self.failUnless(len(self.data) == 1, "no row") parent = self.data[0] self.failUnless(rowMatches(parent, values), "no match") # save parent test_values = {} test_values[parent.key_string] = values parent = None # save some more test rows for i in range(0, self.count): row = TestRow() row.assignKeyAttr("key_string", "bulk%d"%i) test_values[row.key_string] = self.randomizeRow(row) deferredResult(self.reflector.insertRow(row)) row = None # now load them all back in d = self.reflector.loadObjectsFrom("testTable") d.addCallback(self.gotData) deferredResult(d) # make sure they are the same self.failUnless(len(self.data) == self.count + 1, "query did not get rows") for row in self.data: self.failUnless(rowMatches(row, test_values[row.key_string]), "child %s does not match" % row.key_string) # now change them all for row in self.data: test_values[row.key_string] = self.randomizeRow(row) deferredResult(self.reflector.updateRow(row)) self.data = None # load'em back d = self.reflector.loadObjectsFrom("testTable") d.addCallback(self.gotData) deferredResult(d) # make sure they are the same self.failUnless(len(self.data) == self.count + 1, "query did not get rows") for row in self.data: self.failUnless(rowMatches(row, test_values[row.key_string]), "child %s does not match" % row.key_string) # now delete them for row in self.data: deferredResult(self.reflector.deleteRow(row)) self.data = None # load'em back d = self.reflector.loadObjectsFrom("testTable") d.addCallback(self.gotData) deferredResult(d) self.failUnless(len(self.data) == 0, "rows were not deleted") # create one row to work with row = TestRow() row.assignKeyAttr("key_string", "first") values = self.randomizeRow(row) # save it deferredResult(self.reflector.insertRow(row)) # delete it deferredResult(self.reflector.deleteRow(row)) def gotData(self, data): self.data = data class XMLReflectorTestCase(ReflectorTestCase, unittest.TestCase): """Test cases for the XML reflector. """ count = 10 # xmlreflector is slow DB = "./xmlDB" def createReflector(self): return XMLReflector(self.DB, [TestRow, ChildRow]) class SQLReflectorTestCase(ReflectorTestCase): """Test cases for the SQL reflector. To enable this test for databases which use a central, system database, you must create a database named DB_NAME with a user DB_USER and password DB_PASS with full access rights to the database DB_NAME. """ DB_NAME = "twisted_test" DB_USER = 'twisted_test' DB_PASS = 'twisted_test' can_rollback = 1 test_failures = 1 reflectorClass = SQLReflector def createReflector(self): self.startDB() self.dbpool = self.makePool() self.dbpool.start() deferredResult(self.dbpool.runOperation(main_table_schema)) deferredResult(self.dbpool.runOperation(child_table_schema)) deferredResult(self.dbpool.runOperation(simple_table_schema)) return self.reflectorClass(self.dbpool, [TestRow, ChildRow]) def destroyReflector(self): deferredResult(self.dbpool.runOperation('DROP TABLE testTable')) deferredResult(self.dbpool.runOperation('DROP TABLE childTable')) deferredResult(self.dbpool.runOperation('DROP TABLE simple')) self.dbpool.close() self.stopDB() def testPool(self): if self.test_failures: # make sure failures are raised correctly deferredError(self.dbpool.runQuery("select * from NOTABLE")) deferredError(self.dbpool.runOperation("deletexxx from NOTABLE")) deferredError(self.dbpool.runInteraction(self.bad_interaction)) log.flushErrors() # verify simple table is empty sql = "select count(1) from simple" row = deferredResult(self.dbpool.runQuery(sql)) self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back") # add some rows to simple table (runOperation) for i in range(self.count): sql = "insert into simple(x) values(%d)" % i deferredResult(self.dbpool.runOperation(sql)) # make sure they were added (runQuery) sql = "select x from simple order by x"; rows = deferredResult(self.dbpool.runQuery(sql)) self.failUnless(len(rows) == self.count, "Wrong number of rows") for i in range(self.count): self.failUnless(len(rows[i]) == 1, "Wrong size row") self.failUnless(rows[i][0] == i, "Values not returned.") # runInteraction self.assertEquals(deferredResult(self.dbpool.runInteraction(self.interaction)), "done") # give the pool a workout ds = [] for i in range(self.count): sql = "select x from simple where x = %d" % i ds.append(self.dbpool.runQuery(sql)) dlist = defer.DeferredList(ds, fireOnOneErrback=1) result = deferredResult(dlist) for i in range(self.count): self.failUnless(result[i][1][0][0] == i, "Value not returned") # now delete everything ds = [] for i in range(self.count): sql = "delete from simple where x = %d" % i ds.append(self.dbpool.runOperation(sql)) dlist = defer.DeferredList(ds, fireOnOneErrback=1) deferredResult(dlist) # verify simple table is empty sql = "select count(1) from simple" row = deferredResult(self.dbpool.runQuery(sql)) self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back") def interaction(self, transaction): transaction.execute("select x from simple order by x") for i in range(self.count): row = transaction.fetchone() self.failUnless(len(row) == 1, "Wrong size row") self.failUnless(row[0] == i, "Value not returned.") # should test this, but gadfly throws an exception instead #self.failUnless(transaction.fetchone() is None, "Too many rows") return "done" def bad_interaction(self, transaction): if self.can_rollback: transaction.execute("insert into simple(x) values(0)") transaction.execute("select * from NOTABLE") def startDB(self): pass def stopDB(self): pass class NoSlashSQLReflector(SQLReflector): def escape_string(self, text): return text.replace("'", "''") class GadflyTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using Gadfly. """ count = 10 # gadfly is slow nullsOK = 0 DB_DIR = "./gadflyDB" reflectorClass = NoSlashSQLReflector can_rollback = 0 def startDB(self): if not os.path.exists(self.DB_DIR): os.mkdir(self.DB_DIR) conn = gadfly.gadfly() conn.startup(self.DB_NAME, self.DB_DIR) # gadfly seems to want us to create something to get the db going cursor = conn.cursor() cursor.execute("create table x (x integer)") conn.commit() conn.close() def makePool(self): return ConnectionPool('gadfly', self.DB_NAME, self.DB_DIR, cp_max=1) class SQLiteTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using SQLite. """ DB_DIR = "./sqliteDB" reflectorClass = NoSlashSQLReflector def startDB(self): if not os.path.exists(self.DB_DIR): os.mkdir(self.DB_DIR) self.database = os.path.join(self.DB_DIR, self.DB_NAME) if os.path.exists(self.database): os.unlink(self.database) def makePool(self): return ConnectionPool('sqlite', database=self.database, cp_max=1) class PostgresTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using Postgres. """ def makePool(self): return ConnectionPool('pyPgSQL.PgSQL', database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS, cp_min=0) class PsycopgTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using psycopg for Postgres. """ def makePool(self): return ConnectionPool('psycopg', database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS, cp_min=0) class MySQLTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using MySQL. """ trailingSpacesOK = 0 can_rollback = 0 def makePool(self): return ConnectionPool('MySQLdb', db=self.DB_NAME, user=self.DB_USER, passwd=self.DB_PASS) class FirebirdTestCase(SQLReflectorTestCase, unittest.TestCase): """Test cases for the SQL reflector using Firebird/Interbase.""" count = 2 # CHANGEME test_failures = 0 # failure testing causes problems reflectorClass = NoSlashSQLReflector DB_DIR = tempfile.mktemp() DB_NAME = os.path.join(DB_DIR, SQLReflectorTestCase.DB_NAME) def startDB(self): os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO) sql = 'create database "%s" user "%s" password "%s"' sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS); conn = kinterbasdb.create_database(sql) conn.close() os.chmod(self.DB_NAME, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO) def makePool(self): return ConnectionPool('kinterbasdb', database=self.DB_NAME, host='localhost', user=self.DB_USER, password=self.DB_PASS) def stopDB(self): conn = kinterbasdb.connect(database=self.DB_NAME, host='localhost', user=self.DB_USER, password=self.DB_PASS) conn.drop_database() conn.close() class QuotingTestCase(unittest.TestCase): def testQuoting(self): for value, typ, expected in [ (12, "integer", "12"), ("foo'd", "text", "'foo''d'"), ("\x00abc\\s\xFF", "bytea", "'\\\\000abc\\\\\\\\s\\377'"), ]: self.assertEquals(util.quote(value, typ), expected) if gadfly is None: GadflyTestCase.skip = "gadfly module not available" elif not getattr(gadfly, 'connect', None): gadfly.connect = gadfly.gadfly if sqlite is None: SQLiteTestCase.skip = "sqlite module not available" if PgSQL is None: PostgresTestCase.skip = "pyPgSQL module not available" else: try: conn = PgSQL.connect(database=PostgresTestCase.DB_NAME, user=PostgresTestCase.DB_USER, password=PostgresTestCase.DB_PASS) conn.close() except Exception, e: PostgresTestCase.skip = "Connection to PgSQL server failed: " + str(e) if psycopg is None: PsycopgTestCase.skip = "psycopg module not available" else: try: conn = psycopg.connect(database=PsycopgTestCase.DB_NAME, user=PsycopgTestCase.DB_USER, password=PsycopgTestCase.DB_PASS) conn.close() except Exception, e: PsycopgTestCase.skip = "Connection to PostgreSQL using psycopg failed: " + str(e) if MySQLdb is None: MySQLTestCase.skip = "MySQLdb module not available" else: try: conn = MySQLdb.connect(db=MySQLTestCase.DB_NAME, user=MySQLTestCase.DB_USER, passwd=MySQLTestCase.DB_PASS) conn.close() except Exception, e: MySQLTestCase.skip = "Connection to MySQL server failed: " + str(e) if kinterbasdb is None: FirebirdTestCase.skip = "kinterbasdb module not available" else: try: testcase = FirebirdTestCase() testcase.startDB() testcase.stopDB() except Exception, e: FirebirdTestCase.skip = "Connection to Firebase server failed: " + str(e)