1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
6 Tests for twisted.enterprise.adbapi.
9 from twisted.trial import unittest
14 from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost
15 from twisted.enterprise.adbapi import Connection, Transaction
16 from twisted.internet import reactor, defer, interfaces
17 from twisted.python.failure import Failure
20 simple_table_schema = """
28 """Test the asynchronous DB-API code."""
32 if interfaces.IReactorThreads(reactor, None) is None:
33 skip = "ADB-API requires threads, no way to test without them"
37 Set up the database and create a connection pool pointing at it.
40 self.dbpool = self.makePool(cp_openfun=self.openfun)
45 d = self.dbpool.runOperation('DROP TABLE simple')
46 d.addCallback(lambda res: self.dbpool.close())
47 d.addCallback(lambda res: self.stopDB())
def openfun(self, conn):
    """
    Hook handed to the pool as C{cp_openfun}; flags C{conn} as seen.

    @param conn: the object the hook is called with; it becomes a key in
        C{self.openfun_called}.
    """
    seen = self.openfun_called
    seen[conn] = True
53 def checkOpenfunCalled(self, conn=None):
55 self.failUnless(self.openfun_called)
57 self.failUnless(self.openfun_called.has_key(conn))
60 d = self.dbpool.runOperation(simple_table_schema)
61 if self.test_failures:
62 d.addCallback(self._testPool_1_1)
63 d.addCallback(self._testPool_1_2)
64 d.addCallback(self._testPool_1_3)
65 d.addCallback(self._testPool_1_4)
66 d.addCallback(lambda res: self.flushLoggedErrors())
67 d.addCallback(self._testPool_2)
68 d.addCallback(self._testPool_3)
69 d.addCallback(self._testPool_4)
70 d.addCallback(self._testPool_5)
71 d.addCallback(self._testPool_6)
72 d.addCallback(self._testPool_7)
73 d.addCallback(self._testPool_8)
74 d.addCallback(self._testPool_9)
77 def _testPool_1_1(self, res):
78 d = defer.maybeDeferred(self.dbpool.runQuery, "select * from NOTABLE")
79 d.addCallbacks(lambda res: self.fail('no exception'),
83 def _testPool_1_2(self, res):
84 d = defer.maybeDeferred(self.dbpool.runOperation,
85 "deletexxx from NOTABLE")
86 d.addCallbacks(lambda res: self.fail('no exception'),
90 def _testPool_1_3(self, res):
91 d = defer.maybeDeferred(self.dbpool.runInteraction,
93 d.addCallbacks(lambda res: self.fail('no exception'),
97 def _testPool_1_4(self, res):
98 d = defer.maybeDeferred(self.dbpool.runWithConnection,
99 self.bad_withConnection)
100 d.addCallbacks(lambda res: self.fail('no exception'),
104 def _testPool_2(self, res):
105 # verify simple table is empty
106 sql = "select count(1) from simple"
107 d = self.dbpool.runQuery(sql)
109 self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
110 self.checkOpenfunCalled()
111 d.addCallback(_check)
114 def _testPool_3(self, res):
115 sql = "select count(1) from simple"
117 # add some rows to simple table (runOperation)
118 for i in range(self.num_iterations):
119 sql = "insert into simple(x) values(%d)" % i
120 inserts.append(self.dbpool.runOperation(sql))
121 d = defer.gatherResults(inserts)
124 # make sure they were added (runQuery)
125 sql = "select x from simple order by x";
126 d = self.dbpool.runQuery(sql)
128 d.addCallback(_select)
131 self.failUnless(len(rows) == self.num_iterations,
132 "Wrong number of rows")
133 for i in range(self.num_iterations):
134 self.failUnless(len(rows[i]) == 1, "Wrong size row")
135 self.failUnless(rows[i][0] == i, "Values not returned.")
136 d.addCallback(_check)
140 def _testPool_4(self, res):
142 d = self.dbpool.runInteraction(self.interaction)
143 d.addCallback(lambda res: self.assertEqual(res, "done"))
146 def _testPool_5(self, res):
148 d = self.dbpool.runWithConnection(self.withConnection)
149 d.addCallback(lambda res: self.assertEqual(res, "done"))
152 def _testPool_6(self, res):
153 # Test a withConnection cannot be closed
154 d = self.dbpool.runWithConnection(self.close_withConnection)
157 def _testPool_7(self, res):
158 # give the pool a workout
160 for i in range(self.num_iterations):
161 sql = "select x from simple where x = %d" % i
162 ds.append(self.dbpool.runQuery(sql))
163 dlist = defer.DeferredList(ds, fireOnOneErrback=True)
165 for i in range(self.num_iterations):
166 self.failUnless(result[i][1][0][0] == i, "Value not returned")
167 dlist.addCallback(_check)
170 def _testPool_8(self, res):
171 # now delete everything
173 for i in range(self.num_iterations):
174 sql = "delete from simple where x = %d" % i
175 ds.append(self.dbpool.runOperation(sql))
176 dlist = defer.DeferredList(ds, fireOnOneErrback=True)
179 def _testPool_9(self, res):
180 # verify simple table is empty
181 sql = "select count(1) from simple"
182 d = self.dbpool.runQuery(sql)
184 self.failUnless(int(row[0][0]) == 0,
185 "Didn't successfully delete table contents")
187 d.addCallback(_check)
190 def checkConnect(self):
191 """Check the connect/disconnect synchronous calls."""
192 conn = self.dbpool.connect()
193 self.checkOpenfunCalled(conn)
195 curs.execute("insert into simple(x) values(1)")
196 curs.execute("select x from simple")
197 res = curs.fetchall()
198 self.assertEqual(len(res), 1)
199 self.assertEqual(len(res[0]), 1)
200 self.assertEqual(res[0][0], 1)
201 curs.execute("delete from simple")
202 curs.execute("select x from simple")
203 self.assertEqual(len(curs.fetchall()), 0)
205 self.dbpool.disconnect(conn)
207 def interaction(self, transaction):
208 transaction.execute("select x from simple order by x")
209 for i in range(self.num_iterations):
210 row = transaction.fetchone()
211 self.failUnless(len(row) == 1, "Wrong size row")
212 self.failUnless(row[0] == i, "Value not returned.")
213 # should test this, but gadfly throws an exception instead
214 #self.failUnless(transaction.fetchone() is None, "Too many rows")
def bad_interaction(self, transaction):
    """
    Interaction used by the failure tests.

    When C{self.can_rollback} is true a row is inserted into C{simple}
    first; in every case the final statement selects from C{NOTABLE}
    (presumably a table the tests never create, so the statement fails).

    @param transaction: object whose C{execute} method receives each SQL
        string in order.
    """
    statements = []
    if self.can_rollback:
        statements.append("insert into simple(x) values(0)")
    statements.append("select * from NOTABLE")
    for statement in statements:
        transaction.execute(statement)
223 def withConnection(self, conn):
226 curs.execute("select x from simple order by x")
227 for i in range(self.num_iterations):
228 row = curs.fetchone()
229 self.failUnless(len(row) == 1, "Wrong size row")
230 self.failUnless(row[0] == i, "Value not returned.")
231 # should test this, but gadfly throws an exception instead
232 #self.failUnless(transaction.fetchone() is None, "Too many rows")
237 def close_withConnection(self, conn):
240 def bad_withConnection(self, conn):
243 curs.execute("select * from NOTABLE")
248 class ReconnectTestBase:
249 """Test the asynchronous DB-API code with reconnect."""
251 if interfaces.IReactorThreads(reactor, None) is None:
252 skip = "ADB-API requires threads, no way to test without them"
254 def extraSetUp(self):
256 Skip the test if C{good_sql} is unavailable. Otherwise, set up the
257 database, create a connection pool pointed at it, and set up a simple
260 if self.good_sql is None:
261 raise unittest.SkipTest('no good sql for reconnect test')
263 self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
264 cp_good_sql=self.good_sql)
266 return self.dbpool.runOperation(simple_table_schema)
270 d = self.dbpool.runOperation('DROP TABLE simple')
271 d.addCallback(lambda res: self.dbpool.close())
272 d.addCallback(lambda res: self.stopDB())
276 d = defer.succeed(None)
277 d.addCallback(self._testPool_1)
278 d.addCallback(self._testPool_2)
279 if not self.early_reconnect:
280 d.addCallback(self._testPool_3)
281 d.addCallback(self._testPool_4)
282 d.addCallback(self._testPool_5)
285 def _testPool_1(self, res):
286 sql = "select count(1) from simple"
287 d = self.dbpool.runQuery(sql)
289 self.failUnless(int(row[0][0]) == 0, "Table not empty")
290 d.addCallback(_check)
def _testPool_2(self, res):
    """
    Close one of the pool's connections behind the pool's back.

    @param res: result of the previous step in the callback chain; unused.
    """
    # reach in and close the connection manually; list() keeps the
    # indexing valid when values() returns a view rather than a list
    list(self.dbpool.connections.values())[0].close()
297 def _testPool_3(self, res):
298 sql = "select count(1) from simple"
299 d = defer.maybeDeferred(self.dbpool.runQuery, sql)
300 d.addCallbacks(lambda res: self.fail('no exception'),
304 def _testPool_4(self, res):
305 sql = "select count(1) from simple"
306 d = self.dbpool.runQuery(sql)
308 self.failUnless(int(row[0][0]) == 0, "Table not empty")
309 d.addCallback(_check)
312 def _testPool_5(self, res):
313 self.flushLoggedErrors()
314 sql = "select * from NOTABLE" # bad sql
315 d = defer.maybeDeferred(self.dbpool.runQuery, sql)
316 d.addCallbacks(lambda res: self.fail('no exception'),
317 lambda f: self.failIf(f.check(ConnectionLost)))
321 class DBTestConnector:
322 """A class which knows how to test for the presence of
323 and establish a connection to a relational database.
325 To enable test cases which use a central, system database,
326 you must create a database named DB_NAME with a user DB_USER
327 and password DB_PASS with full access rights to database DB_NAME.
330 TEST_PREFIX = None # used for creating new test cases
332 DB_NAME = "twisted_test"
333 DB_USER = 'twisted_test'
334 DB_PASS = 'twisted_test'
336 DB_DIR = None # directory for database storage
338 nulls_ok = True # nulls supported
339 trailing_spaces_ok = True # trailing spaces in strings preserved
340 can_rollback = True # rollback supported
341 test_failures = True # test bad sql?
342 escape_slashes = True # escape \ in sql?
343 good_sql = ConnectionPool.good_sql
344 early_reconnect = True # cursor() will fail on closed connection
345 can_clear = True # can try to clear out tables when starting
347 num_iterations = 50 # number of iterations for test loops
348 # (lower this for slow db's)
351 self.DB_DIR = self.mktemp()
352 os.mkdir(self.DB_DIR)
353 if not self.can_connect():
354 raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)
355 return self.extraSetUp()
def can_connect(self):
    """
    Report whether this connector's database is present on the system
    and usable in a test.

    This base version has nothing to probe and always raises; the
    connector classes defined after it supply their own versions.

    @raise NotImplementedError: always.
    """
    unimplemented = NotImplementedError()
    raise unimplemented
363 """Take any steps needed to bring database up."""
367 """Bring database down, if needed."""
370 def makePool(self, **newkw):
371 """Create a connection pool with additional keyword arguments."""
372 args, kw = self.getPoolArgs()
375 return ConnectionPool(*args, **kw)
377 def getPoolArgs(self):
378 """Return a tuple (args, kw) of list and keyword arguments
379 that need to be passed to ConnectionPool to create a connection
381 raise NotImplementedError()
383 class GadflyConnector(DBTestConnector):
384 TEST_PREFIX = 'Gadfly'
388 escape_slashes = False
389 good_sql = 'select * from simple where 1=0'
391 num_iterations = 1 # slow
393 def can_connect(self):
396 if not getattr(gadfly, 'connect', None):
397 gadfly.connect = gadfly.gadfly
402 conn = gadfly.gadfly()
403 conn.startup(self.DB_NAME, self.DB_DIR)
405 # gadfly seems to want us to create something to get the db going
406 cursor = conn.cursor()
407 cursor.execute("create table x (x integer)")
411 def getPoolArgs(self):
412 args = ('gadfly', self.DB_NAME, self.DB_DIR)
416 class SQLiteConnector(DBTestConnector):
417 TEST_PREFIX = 'SQLite'
419 escape_slashes = False
421 num_iterations = 1 # slow
423 def can_connect(self):
429 self.database = os.path.join(self.DB_DIR, self.DB_NAME)
430 if os.path.exists(self.database):
431 os.unlink(self.database)
433 def getPoolArgs(self):
435 kw = {'database': self.database, 'cp_max': 1}
438 class PyPgSQLConnector(DBTestConnector):
439 TEST_PREFIX = "PyPgSQL"
441 def can_connect(self):
442 try: from pyPgSQL import PgSQL
445 conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
446 password=self.DB_PASS)
452 def getPoolArgs(self):
453 args = ('pyPgSQL.PgSQL',)
454 kw = {'database': self.DB_NAME, 'user': self.DB_USER,
455 'password': self.DB_PASS, 'cp_min': 0}
458 class PsycopgConnector(DBTestConnector):
459 TEST_PREFIX = 'Psycopg'
461 def can_connect(self):
465 conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
466 password=self.DB_PASS)
472 def getPoolArgs(self):
474 kw = {'database': self.DB_NAME, 'user': self.DB_USER,
475 'password': self.DB_PASS, 'cp_min': 0}
478 class MySQLConnector(DBTestConnector):
479 TEST_PREFIX = 'MySQL'
481 trailing_spaces_ok = False
483 early_reconnect = False
485 def can_connect(self):
489 conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
496 def getPoolArgs(self):
498 kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
501 class FirebirdConnector(DBTestConnector):
502 TEST_PREFIX = 'Firebird'
504 test_failures = False # failure testing causes problems
505 escape_slashes = False
506 good_sql = None # firebird doesn't handle failed sql well
507 can_clear = False # firebird is not so good
509 num_iterations = 5 # slow
511 def can_connect(self):
512 try: import kinterbasdb
524 self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
525 os.chmod(self.DB_DIR, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
526 sql = 'create database "%s" user "%s" password "%s"'
527 sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS);
528 conn = kinterbasdb.create_database(sql)
532 def getPoolArgs(self):
533 args = ('kinterbasdb',)
534 kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
535 'user': self.DB_USER, 'password': self.DB_PASS}
540 conn = kinterbasdb.connect(database=self.DB_NAME,
541 host='127.0.0.1', user=self.DB_USER,
542 password=self.DB_PASS)
545 def makeSQLTests(base, suffix, globals):
547 Make a test case for every db connector which can connect.
549 @param base: Base class for test case. Additional base classes
550 will be a DBConnector subclass and unittest.TestCase
551 @param suffix: A suffix used to create test case names. Prefixes
552 are defined in the DBConnector subclasses.
554 connectors = [GadflyConnector, SQLiteConnector, PyPgSQLConnector,
555 PsycopgConnector, MySQLConnector, FirebirdConnector]
556 for connclass in connectors:
557 name = connclass.TEST_PREFIX + suffix
558 klass = types.ClassType(name, (connclass, base, unittest.TestCase),
560 globals[name] = klass
562 # GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
563 # PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
564 makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())
566 # GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
567 # PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
568 makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
572 class FakePool(object):
574 A fake L{ConnectionPool} for tests.
576 @ivar connectionFactory: factory for making connections returned by the
578 @type connectionFactory: any callable
def __init__(self, connectionFactory):
    """
    Remember the factory this fake pool uses to build connections.

    @param connectionFactory: stored unchanged on the instance; the class
        later calls it with no arguments to produce a connection. Tests
        that never ask for a connection pass C{None}.
    """
    self.connectionFactory = connectionFactory
589 Return an instance of C{self.connectionFactory}.
591 return self.connectionFactory()
594 def disconnect(self, connection):
601 class ConnectionTestCase(unittest.TestCase):
603 Tests for the L{Connection} class.
606 def test_rollbackErrorLogged(self):
608 If an error happens during rollback, L{ConnectionLost} is raised but
609 the original error is logged.
611 class ConnectionRollbackRaise(object):
613 raise RuntimeError("problem!")
615 pool = FakePool(ConnectionRollbackRaise)
616 connection = Connection(pool)
617 self.assertRaises(ConnectionLost, connection.rollback)
618 errors = self.flushLoggedErrors(RuntimeError)
619 self.assertEqual(len(errors), 1)
620 self.assertEqual(errors[0].value.args[0], "problem!")
624 class TransactionTestCase(unittest.TestCase):
626 Tests for the L{Transaction} class.
629 def test_reopenLogErrorIfReconnect(self):
631 If the cursor creation raises an error in L{Transaction.reopen}, it
632 reconnects but logs the error that occurred.
634 class ConnectionCursorRaise(object):
643 raise RuntimeError("problem!")
645 pool = FakePool(None)
646 transaction = Transaction(pool, ConnectionCursorRaise())
648 errors = self.flushLoggedErrors(RuntimeError)
649 self.assertEqual(len(errors), 1)
650 self.assertEqual(errors[0].value.args[0], "problem!")
654 class NonThreadPool(object):
655 def callInThreadWithCallback(self, onResult, f, *a, **kw):
662 onResult(success, result)
666 class DummyConnectionPool(ConnectionPool):
668 A testable L{ConnectionPool};
670 threadpool = NonThreadPool()
674 Don't forward init call.
676 self.reactor = reactor
680 class EventReactor(object):
682 Partial L{IReactorCore} implementation with simple event-related
685 @ivar _running: A C{bool} indicating whether the reactor is pretending
686 to have been started already or not.
688 @ivar triggers: A C{list} of pending system event triggers.
690 def __init__(self, running):
691 self._running = running
695 def callWhenRunning(self, function):
699 return self.addSystemEventTrigger('after', 'startup', function)
702 def addSystemEventTrigger(self, phase, event, trigger):
703 handle = (phase, event, trigger)
704 self.triggers.append(handle)
def removeSystemEventTrigger(self, handle):
    """
    Drop a pending trigger from C{self.triggers}.

    @param handle: the entry to remove, compared by equality against the
        items in C{self.triggers}.

    @raise ValueError: if C{handle} is not in C{self.triggers}.
    """
    pending = self.triggers
    pending.remove(handle)
713 class ConnectionPoolTestCase(unittest.TestCase):
715 Unit tests for L{ConnectionPool}.
718 def test_runWithConnectionRaiseOriginalError(self):
720 If rollback fails, L{ConnectionPool.runWithConnection} raises the
721 original exception and logs the error of the rollback.
723 class ConnectionRollbackRaise(object):
724 def __init__(self, pool):
728 raise RuntimeError("problem!")
730 def raisingFunction(connection):
731 raise ValueError("foo")
733 pool = DummyConnectionPool()
734 pool.connectionFactory = ConnectionRollbackRaise
735 d = pool.runWithConnection(raisingFunction)
736 d = self.assertFailure(d, ValueError)
737 def cbFailed(ignored):
738 errors = self.flushLoggedErrors(RuntimeError)
739 self.assertEqual(len(errors), 1)
740 self.assertEqual(errors[0].value.args[0], "problem!")
741 d.addCallback(cbFailed)
745 def test_closeLogError(self):
747 L{ConnectionPool._close} logs exceptions.
749 class ConnectionCloseRaise(object):
751 raise RuntimeError("problem!")
753 pool = DummyConnectionPool()
754 pool._close(ConnectionCloseRaise())
756 errors = self.flushLoggedErrors(RuntimeError)
757 self.assertEqual(len(errors), 1)
758 self.assertEqual(errors[0].value.args[0], "problem!")
761 def test_runWithInteractionRaiseOriginalError(self):
763 If rollback fails, L{ConnectionPool.runInteraction} raises the
764 original exception and logs the error of the rollback.
766 class ConnectionRollbackRaise(object):
767 def __init__(self, pool):
771 raise RuntimeError("problem!")
773 class DummyTransaction(object):
774 def __init__(self, pool, connection):
777 def raisingFunction(transaction):
778 raise ValueError("foo")
780 pool = DummyConnectionPool()
781 pool.connectionFactory = ConnectionRollbackRaise
782 pool.transactionFactory = DummyTransaction
784 d = pool.runInteraction(raisingFunction)
785 d = self.assertFailure(d, ValueError)
786 def cbFailed(ignored):
787 errors = self.flushLoggedErrors(RuntimeError)
788 self.assertEqual(len(errors), 1)
789 self.assertEqual(errors[0].value.args[0], "problem!")
790 d.addCallback(cbFailed)
794 def test_unstartedClose(self):
796 If L{ConnectionPool.close} is called without L{ConnectionPool.start}
797 having been called, the pool's startup event is cancelled.
799 reactor = EventReactor(False)
800 pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
801 # There should be a startup trigger waiting.
802 self.assertEqual(reactor.triggers, [('after', 'startup', pool._start)])
805 self.assertFalse(reactor.triggers)
808 def test_startedClose(self):
810 If L{ConnectionPool.close} is called after it has been started, but
811 not by its shutdown trigger, the shutdown trigger is cancelled.
813 reactor = EventReactor(True)
814 pool = ConnectionPool('twisted.test.test_adbapi', cp_reactor=reactor)
815 # There should be a shutdown trigger waiting.
816 self.assertEqual(reactor.triggers, [('during', 'shutdown', pool.finalClose)])
819 self.assertFalse(reactor.triggers)