# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
"""Tests for twisted.enterprise."""
import os
import random
import stat
import sys
import tempfile

from twisted.trial import unittest
from twisted.enterprise.row import RowObject
from twisted.enterprise.reflector import *
from twisted.enterprise.xmlreflector import XMLReflector
from twisted.enterprise.sqlreflector import SQLReflector
from twisted.enterprise.adbapi import ConnectionPool
from twisted.enterprise import util
from twisted.internet import defer
from twisted.trial.util import deferredResult, deferredError
from twisted.python import log

# Optional database drivers: a driver that fails to import is bound to None
# and the corresponding test case is marked as skipped at the end of the file.
try: import gadfly
except: gadfly = None
try: import sqlite
except: sqlite = None
try: from pyPgSQL import PgSQL
except: PgSQL = None
try: import MySQLdb
except: MySQLdb = None
try: import psycopg
except: psycopg = None
try: import kinterbasdb
except: kinterbasdb = None
# Names of the parent and child tables that TestRow / ChildRow map onto.
tableName, childTableName = "testTable", "childTable"
class TestRow(RowObject):
    """Row class for the parent table (tableName), keyed by key_string.

    The mixed-case and underscore-suffixed column names presumably exist to
    exercise the reflector's identifier handling.
    """

    rowTableName = tableName
    rowKeyColumns = [("key_string", "varchar")]
    rowColumns = [
        ("key_string", "varchar"),
        ("col2", "int"),
        ("another_column", "varchar"),
        ("Column4", "varchar"),
        ("column_5_", "int"),
    ]
class ChildRow(RowObject):
    """Row class for the child table (childTableName), keyed by childId.

    Its test_key column refers to the key_string column of the parent table.
    """

    rowTableName = childTableName
    rowKeyColumns = [("childId", "int")]
    rowColumns = [
        ("childId", "int"),
        ("foo", "varchar"),
        ("test_key", "varchar"),
        ("stuff", "varchar"),
        ("gogogo", "int"),
        ("data", "varchar"),
    ]
    # Foreign key: (parent table, child columns, parent columns, ...).
    # NOTE(review): the trailing None and 1 are presumably a container
    # method name and an auto-load flag -- confirm against
    # twisted.enterprise.row.
    rowForeignKeys = [
        (tableName,
         [("test_key", "varchar")],
         [("key_string", "varchar")],
         None,
         1),
    ]
# DDL for the tables used by the tests.  The columns mirror
# TestRow.rowColumns and ChildRow.rowColumns respectively; "simple" is only
# used by the connection-pool tests.
main_table_schema = (
    "\n"
    "CREATE TABLE testTable (\n"
    "key_string varchar(64),\n"
    "col2 integer,\n"
    "another_column varchar(64),\n"
    "Column4 varchar(64),\n"
    "column_5_ integer\n"
    ")\n"
)

child_table_schema = (
    "\n"
    "CREATE TABLE childTable (\n"
    "childId integer,\n"
    "foo varchar(64),\n"
    "test_key varchar(64),\n"
    "stuff varchar(64),\n"
    "gogogo integer,\n"
    "data varchar(64)\n"
    ")\n"
)

simple_table_schema = (
    "\n"
    "CREATE TABLE simple (\n"
    "x integer\n"
    ")\n"
)
def randomizeRow(row, nullsOK=1, trailingSpacesOK=1):
    """Give every non-key column of *row* a random value.

    Key columns (those for which util.getKeyColumn is true) keep their
    current value.  Every other column is set, in this order of precedence,
    to:

      * None with probability 1/10, but only when nullsOK is true;
      * a random integer in [-10000, 10000] for "int" columns;
      * otherwise a string: '' with probability 1/10, else 1 to 64 random
        characters in the range chr(32)..chr(126), right-stripped when
        trailingSpacesOK is false.

    Returns a dict mapping every column name to the value now on the row.
    """
    result = {}
    for column, columnType in row.rowColumns:
        if util.getKeyColumn(row, column):
            result[column] = getattr(row, column)
            continue

        if nullsOK and random.randint(0, 9) == 0:
            newValue = None
        elif columnType == 'int':
            newValue = random.randint(-10000, 10000)
        elif random.randint(0, 9) == 0:
            newValue = ''
        else:
            length = random.randint(1, 64)
            newValue = ''.join([chr(random.randrange(32, 127))
                                for _ in range(length)])
            if not trailingSpacesOK:
                newValue = newValue.rstrip()

        setattr(row, column, newValue)
        result[column] = newValue
    return result
def rowMatches(row, values):
    """Return 1 if every column of *row* equals values[column], else None.

    On the first mismatching column a diagnostic line is printed before
    returning.
    """
    for column in [spec[0] for spec in row.rowColumns]:
        actual = getattr(row, column)
        expected = values[column]
        if actual != expected:
            print ("Mismatch on column %s: |%s| (row) |%s| (values)" %
                   (column, actual, expected))
            return
    return 1
class ReflectorTestCase:
    """Base class for testing reflectors.

    Subclass and implement createReflector for the style and db you
    want to test. This may involve creating a new database, starting a
    server, etc.  setUp stores whatever createReflector returns in
    self.reflector without checking it, so createReflector must return a
    usable reflector.  In this module, backends whose libraries or servers
    are unavailable are skipped by setting a ``skip`` attribute on the
    concrete subclass.

    Implement destroyReflector if your database needs to be shutdown
    afterwards.

    Class attributes subclasses may override:
      count            -- number of rows used by the iterative tests
      nullsOK          -- true if the database can store NULLs
      trailingSpacesOK -- true if strings keep their trailing spaces
    """

    count = 100             # a parameter used for running iterative tests
    nullsOK = 1             # we can put nulls into the db
    trailingSpacesOK = 1    # we can put strings with trailing spaces into the db

    def randomizeRow(self, row):
        """Randomize *row* within the limits of this backend.

        Delegates to the module-level randomizeRow using this class's
        nullsOK and trailingSpacesOK, and returns its dict of column values.
        """
        return randomizeRow(row, self.nullsOK, self.trailingSpacesOK)

    def setUp(self):
        self.reflector = self.createReflector()

    def tearDown(self):
        self.destroyReflector()

    def destroyReflector(self):
        """Release whatever createReflector acquired; nothing by default."""
        pass

    def _loadObjects(self, table, **kwargs):
        """Synchronously load rows from *table*, leaving them in self.data.

        Keyword arguments (whereClause, parentRow, ...) are passed straight
        through to the reflector's loadObjectsFrom.
        """
        d = self.reflector.loadObjectsFrom(table, **kwargs)
        d.addCallback(self.gotData)
        deferredResult(d)

    def testReflector(self):
        # NOTE(review): references to row objects (row, parent, self.data)
        # are dropped at the same points as before, presumably so reloaded
        # rows cannot be served from a reflector-side cache of live objects
        # -- confirm against twisted.enterprise.reflector.

        # create one row to work with
        row = TestRow()
        row.assignKeyAttr("key_string", "first")
        values = self.randomizeRow(row)

        # save it
        deferredResult(self.reflector.insertRow(row))

        # now load it back in
        whereClause = [("key_string", EQUAL, "first")]
        self._loadObjects(tableName, whereClause=whereClause)

        # make sure it came back as what we saved
        self.failUnless(len(self.data) == 1, "no row")
        parent = self.data[0]
        self.failUnless(rowMatches(parent, values), "no match")

        # create some child rows
        child_values = {}
        for i in range(0, self.count):
            row = ChildRow()
            row.assignKeyAttr("childId", i)
            values = self.randomizeRow(row)
            values['test_key'] = row.test_key = "first"
            child_values[i] = values
            deferredResult(self.reflector.insertRow(row))
            row = None

        self._loadObjects(childTableName, parentRow=parent)
        self.failUnless(len(self.data) == self.count, "no rows on query")
        self.failUnless(len(parent.childRows) == self.count,
                        "did not load child rows: %d" % len(parent.childRows))
        for child in parent.childRows:
            self.failUnless(rowMatches(child, child_values[child.childId]),
                            "child %d does not match" % child.childId)

        # loading these objects a second time should not re-add them
        # to the parentRow.
        self._loadObjects(childTableName, parentRow=parent)
        self.failUnless(len(self.data) == self.count, "no rows on query")
        self.failUnless(len(parent.childRows) == self.count,
                        "child rows added twice!: %d" % len(parent.childRows))

        # now change the parent
        values = self.randomizeRow(parent)
        deferredResult(self.reflector.updateRow(parent))
        parent = None

        # now load it back in
        whereClause = [("key_string", EQUAL, "first")]
        self._loadObjects(tableName, whereClause=whereClause)

        # make sure it came back as what we saved
        self.failUnless(len(self.data) == 1, "no row")
        parent = self.data[0]
        self.failUnless(rowMatches(parent, values), "no match")

        # save parent
        test_values = {parent.key_string: values}
        parent = None

        # save some more test rows
        for i in range(0, self.count):
            row = TestRow()
            row.assignKeyAttr("key_string", "bulk%d" % i)
            test_values[row.key_string] = self.randomizeRow(row)
            deferredResult(self.reflector.insertRow(row))
            row = None

        # now load them all back in
        self._loadObjects(tableName)

        # make sure they are the same
        self.failUnless(len(self.data) == self.count + 1,
                        "query did not get rows")
        for row in self.data:
            self.failUnless(rowMatches(row, test_values[row.key_string]),
                            "row %s does not match" % row.key_string)

        # now change them all
        for row in self.data:
            test_values[row.key_string] = self.randomizeRow(row)
            deferredResult(self.reflector.updateRow(row))
        self.data = None

        # load'em back
        self._loadObjects(tableName)

        # make sure they are the same
        self.failUnless(len(self.data) == self.count + 1,
                        "query did not get rows")
        for row in self.data:
            self.failUnless(rowMatches(row, test_values[row.key_string]),
                            "row %s does not match" % row.key_string)

        # now delete them
        for row in self.data:
            deferredResult(self.reflector.deleteRow(row))
        self.data = None

        # load'em back
        self._loadObjects(tableName)
        self.failUnless(len(self.data) == 0, "rows were not deleted")

        # create one row to work with
        row = TestRow()
        row.assignKeyAttr("key_string", "first")
        self.randomizeRow(row)

        # save it
        deferredResult(self.reflector.insertRow(row))

        # delete it, and make sure it really is gone
        deferredResult(self.reflector.deleteRow(row))
        row = None
        self._loadObjects(tableName)
        self.failUnless(len(self.data) == 0, "row was not deleted")

    def gotData(self, data):
        """Callback: remember the rows returned by loadObjectsFrom."""
        self.data = data
class XMLReflectorTestCase(ReflectorTestCase, unittest.TestCase):
    """Run the generic reflector tests against XMLReflector.

    XMLReflector is given the path in DB (relative to the current working
    directory) as its storage location.
    """

    DB = "./xmlDB"
    count = 10  # fewer iterations: xmlreflector is slow

    def createReflector(self):
        rowClasses = [TestRow, ChildRow]
        return XMLReflector(self.DB, rowClasses)
class SQLReflectorTestCase(ReflectorTestCase):
    """Test cases for the SQL reflector.

    To enable this test for databases which use a central, system database,
    you must create a database named DB_NAME with a user DB_USER and password
    DB_PASS with full access rights to the database DB_NAME.

    Concrete subclasses also inherit from unittest.TestCase and provide
    makePool(); startDB() and stopDB() are optional hooks called before the
    pool is created and after it is closed.
    """

    DB_NAME = "twisted_test"
    DB_USER = 'twisted_test'
    DB_PASS = 'twisted_test'

    # If true, bad_interaction inserts a row before failing; testPool then
    # expects the simple table to still be empty.
    can_rollback = 1
    # If true, testPool checks that SQL errors are reported as errbacks.
    test_failures = 1

    reflectorClass = SQLReflector

    def createReflector(self):
        """Start the database, open a pool, create the test tables and
        return a reflectorClass instance bound to the pool."""
        self.startDB()
        self.dbpool = self.makePool()
        self.dbpool.start()
        for schema in (main_table_schema, child_table_schema,
                       simple_table_schema):
            deferredResult(self.dbpool.runOperation(schema))
        return self.reflectorClass(self.dbpool, [TestRow, ChildRow])

    def destroyReflector(self):
        """Drop the test tables, then close the pool and stop the database.

        The pool is closed and stopDB() is called even when a DROP fails, so
        a failing teardown does not leave the pool open; the DROP error still
        propagates.
        """
        try:
            for table in ('testTable', 'childTable', 'simple'):
                deferredResult(self.dbpool.runOperation('DROP TABLE ' + table))
        finally:
            self.dbpool.close()
            self.stopDB()

    def testPool(self):
        if self.test_failures:
            # make sure failures are raised correctly
            deferredError(self.dbpool.runQuery("select * from NOTABLE"))
            deferredError(self.dbpool.runOperation("deletexxx from NOTABLE"))
            deferredError(self.dbpool.runInteraction(self.bad_interaction))
            log.flushErrors()

        # verify simple table is empty
        sql = "select count(1) from simple"
        rows = deferredResult(self.dbpool.runQuery(sql))
        self.failUnless(int(rows[0][0]) == 0, "Interaction not rolled back")

        # add some rows to simple table (runOperation)
        for i in range(self.count):
            sql = "insert into simple(x) values(%d)" % i
            deferredResult(self.dbpool.runOperation(sql))

        # make sure they were added (runQuery)
        sql = "select x from simple order by x"
        rows = deferredResult(self.dbpool.runQuery(sql))
        self.failUnless(len(rows) == self.count, "Wrong number of rows")
        for i in range(self.count):
            self.failUnless(len(rows[i]) == 1, "Wrong size row")
            self.failUnless(rows[i][0] == i, "Values not returned.")

        # runInteraction
        self.assertEquals(
            deferredResult(self.dbpool.runInteraction(self.interaction)),
            "done")

        # give the pool a workout
        ds = [self.dbpool.runQuery("select x from simple where x = %d" % i)
              for i in range(self.count)]
        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
        result = deferredResult(dlist)
        for i in range(self.count):
            # each DeferredList entry is a (success, rows) pair
            self.failUnless(result[i][1][0][0] == i, "Value not returned")

        # now delete everything
        ds = [self.dbpool.runOperation("delete from simple where x = %d" % i)
              for i in range(self.count)]
        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
        deferredResult(dlist)

        # verify simple table is empty
        sql = "select count(1) from simple"
        rows = deferredResult(self.dbpool.runQuery(sql))
        self.failUnless(int(rows[0][0]) == 0, "Rows not deleted")

    def interaction(self, transaction):
        """runInteraction callback: read back the rows testPool inserted."""
        transaction.execute("select x from simple order by x")
        for i in range(self.count):
            row = transaction.fetchone()
            self.failUnless(len(row) == 1, "Wrong size row")
            self.failUnless(row[0] == i, "Value not returned.")
        # Ideally we would also check that fetchone() now returns None
        # ("Too many rows"), but gadfly throws an exception instead.
        return "done"

    def bad_interaction(self, transaction):
        """runInteraction callback that always fails.

        When can_rollback is set it first inserts a row, which the failure
        is expected to roll back.
        """
        if self.can_rollback:
            transaction.execute("insert into simple(x) values(0)")
        transaction.execute("select * from NOTABLE")

    def startDB(self):
        """Hook run before the pool is created; does nothing by default."""
        pass

    def stopDB(self):
        """Hook run after the pool is closed; does nothing by default."""
        pass
class NoSlashSQLReflector(SQLReflector):
    """SQLReflector whose string escaping only doubles single quotes.

    Backslashes are passed through untouched.  NOTE(review): presumably for
    backends that do not treat backslash as an escape character; the base
    class's own escaping is not visible here.
    """

    def escape_string(self, text):
        quote = "'"
        return (quote + quote).join(text.split(quote))
class GadflyTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Test cases for the SQL reflector using Gadfly.
    """

    count = 10          # gadfly is slow
    nullsOK = 0         # don't put NULLs into gadfly
    DB_DIR = "./gadflyDB"
    reflectorClass = NoSlashSQLReflector
    can_rollback = 0    # bad_interaction must not rely on a rollback here

    def startDB(self):
        """Start a gadfly database named DB_NAME in DB_DIR."""
        if not os.path.exists(self.DB_DIR):
            os.mkdir(self.DB_DIR)
        conn = gadfly.gadfly()
        conn.startup(self.DB_NAME, self.DB_DIR)
        try:
            # gadfly seems to want us to create something to get the db going
            cursor = conn.cursor()
            cursor.execute("create table x (x integer)")
            conn.commit()
        finally:
            # close even if the bootstrap statement fails, so the
            # connection is not leaked
            conn.close()

    def makePool(self):
        # cp_max=1: at most one connection in the pool (presumably because
        # gadfly cannot be shared between threads -- confirm)
        return ConnectionPool('gadfly', self.DB_NAME, self.DB_DIR, cp_max=1)
class SQLiteTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Run the SQL reflector tests against an SQLite database file.
    """

    reflectorClass = NoSlashSQLReflector
    DB_DIR = "./sqliteDB"

    def startDB(self):
        """Ensure DB_DIR exists and delete any database file left there by
        a previous run, so each createReflector starts from scratch."""
        dbDir = self.DB_DIR
        if not os.path.exists(dbDir):
            os.mkdir(dbDir)
        path = self.database = os.path.join(dbDir, self.DB_NAME)
        if os.path.exists(path):
            os.unlink(path)

    def makePool(self):
        options = {'database': self.database, 'cp_max': 1}
        return ConnectionPool('sqlite', **options)
class PostgresTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Run the SQL reflector tests against PostgreSQL via pyPgSQL.
    """

    def makePool(self):
        """Return a pyPgSQL-backed ConnectionPool for DB_NAME."""
        credentials = dict(database=self.DB_NAME,
                           user=self.DB_USER,
                           password=self.DB_PASS)
        return ConnectionPool('pyPgSQL.PgSQL', cp_min=0, **credentials)
class PsycopgTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Run the SQL reflector tests against PostgreSQL via psycopg.
    """

    def makePool(self):
        """Return a psycopg-backed ConnectionPool for DB_NAME."""
        credentials = dict(database=self.DB_NAME,
                           user=self.DB_USER,
                           password=self.DB_PASS)
        return ConnectionPool('psycopg', cp_min=0, **credentials)
class MySQLTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Run the SQL reflector tests against MySQL via MySQLdb.
    """

    # NOTE(review): presumably MySQL strips trailing spaces from varchar
    # values and cannot roll back bad_interaction's insert -- confirm.
    can_rollback = 0
    trailingSpacesOK = 0

    def makePool(self):
        """Return a MySQLdb-backed ConnectionPool for DB_NAME."""
        # MySQLdb spells its keyword arguments db/passwd.
        credentials = dict(db=self.DB_NAME,
                           user=self.DB_USER,
                           passwd=self.DB_PASS)
        return ConnectionPool('MySQLdb', **credentials)
class FirebirdTestCase(SQLReflectorTestCase, unittest.TestCase):
    """Test cases for the SQL reflector using Firebird/Interbase."""

    count = 2 # CHANGEME
    test_failures = 0 # failure testing causes problems
    reflectorClass = NoSlashSQLReflector

    # tempfile.mktemp() only returns an unused path; it creates nothing.
    # startDB creates the directory.
    DB_DIR = tempfile.mktemp()
    DB_NAME = os.path.join(DB_DIR, SQLReflectorTestCase.DB_NAME)

    # read/write/execute for user, group and other
    _DB_MODE = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO

    def _connectArgs(self):
        """Keyword arguments shared by kinterbasdb.connect and the pool."""
        return {'database': self.DB_NAME, 'host': 'localhost',
                'user': self.DB_USER, 'password': self.DB_PASS}

    def startDB(self):
        """Create a fresh Firebird database file at DB_NAME."""
        # Previously DB_DIR was never created, so the chmod below always
        # failed and this backend was always skipped.  startDB runs more
        # than once (module-level probe, then every createReflector), hence
        # the existence check.
        if not os.path.isdir(self.DB_DIR):
            os.mkdir(self.DB_DIR)
        # NOTE(review): permissions are opened up presumably so that the
        # Firebird server, which may run as another user, can use the
        # directory and file -- confirm.
        os.chmod(self.DB_DIR, self._DB_MODE)
        sql = 'create database "%s" user "%s" password "%s"' % (
            self.DB_NAME, self.DB_USER, self.DB_PASS)
        conn = kinterbasdb.create_database(sql)
        conn.close()
        os.chmod(self.DB_NAME, self._DB_MODE)

    def makePool(self):
        return ConnectionPool('kinterbasdb', **self._connectArgs())

    def stopDB(self):
        """Drop the database created by startDB."""
        conn = kinterbasdb.connect(**self._connectArgs())
        conn.drop_database()
        conn.close()
class QuotingTestCase(unittest.TestCase):
    """Tests for twisted.enterprise.util.quote."""

    def testQuoting(self):
        # (value, SQL type name, expected quoted literal)
        cases = (
            (12, "integer", "12"),
            ("foo'd", "text", "'foo''d'"),
            ("\x00abc\\s\xFF", "bytea", "'\\\\000abc\\\\\\\\s\\377'"),
        )
        for value, sqlType, expected in cases:
            quoted = util.quote(value, sqlType)
            self.assertEquals(quoted, expected)
# Decide which backend test cases can run here.  A backend whose driver
# module failed to import (bound to None), or whose server cannot be reached
# with the test credentials, gets a ``skip`` attribute explaining why.
# sys.exc_info() is used instead of "except Exception, e" so this code is
# valid on every Python version.

if gadfly is None:
    GadflyTestCase.skip = "gadfly module not available"
elif not getattr(gadfly, 'connect', None):
    # Give gadfly a module-level connect() alias for gadfly.gadfly.
    # NOTE(review): presumably ConnectionPool('gadfly', ...) calls
    # gadfly.connect -- confirm against adbapi.
    gadfly.connect = gadfly.gadfly

if sqlite is None:
    SQLiteTestCase.skip = "sqlite module not available"

if PgSQL is None:
    PostgresTestCase.skip = "pyPgSQL module not available"
else:
    try:
        conn = PgSQL.connect(database=PostgresTestCase.DB_NAME,
                             user=PostgresTestCase.DB_USER,
                             password=PostgresTestCase.DB_PASS)
        conn.close()
    except Exception:
        PostgresTestCase.skip = ("Connection to PgSQL server failed: "
                                 + str(sys.exc_info()[1]))

if psycopg is None:
    PsycopgTestCase.skip = "psycopg module not available"
else:
    try:
        conn = psycopg.connect(database=PsycopgTestCase.DB_NAME,
                               user=PsycopgTestCase.DB_USER,
                               password=PsycopgTestCase.DB_PASS)
        conn.close()
    except Exception:
        PsycopgTestCase.skip = ("Connection to PostgreSQL using psycopg failed: "
                                + str(sys.exc_info()[1]))

if MySQLdb is None:
    MySQLTestCase.skip = "MySQLdb module not available"
else:
    try:
        conn = MySQLdb.connect(db=MySQLTestCase.DB_NAME,
                               user=MySQLTestCase.DB_USER,
                               passwd=MySQLTestCase.DB_PASS)
        conn.close()
    except Exception:
        MySQLTestCase.skip = ("Connection to MySQL server failed: "
                              + str(sys.exc_info()[1]))

if kinterbasdb is None:
    FirebirdTestCase.skip = "kinterbasdb module not available"
else:
    # Probe by creating and dropping a throwaway database.
    try:
        testcase = FirebirdTestCase()
        testcase.startDB()
        testcase.stopDB()
    except Exception:
        FirebirdTestCase.skip = ("Connection to Firebird server failed: "
                                 + str(sys.exc_info()[1]))
# syntax highlighted by Code2HTML, v. 0.9.1