Imported Upstream version 12.1.0
[contrib/python-twisted.git] / twisted / test / test_adbapi.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4
5 """
6 Tests for twisted.enterprise.adbapi.
7 """
8
9 from twisted.trial import unittest
10
11 import os, stat
12 import types
13
14 from twisted.enterprise.adbapi import ConnectionPool, ConnectionLost
15 from twisted.enterprise.adbapi import Connection, Transaction
16 from twisted.internet import reactor, defer, interfaces
17 from twisted.python.failure import Failure
18
19
# DDL for the single-column table used by the pool tests in this module; the
# tearDown methods remove it again with 'DROP TABLE simple'.
simple_table_schema = """
CREATE TABLE simple (
  x integer
)
"""
25
26
class ADBAPITestBase:
    """
    Test the asynchronous DB-API code.

    This is a mixin: it relies on attributes and methods supplied by a
    connector class (C{startDB}, C{stopDB}, C{makePool}, C{num_iterations},
    C{test_failures}, C{can_rollback}) and on the assertion methods of a
    test case class.
    """

    # Maps each connection handed to openfun() to True.  The class-level
    # dictionary is only a default; extraSetUp() replaces it per test.
    openfun_called = {}

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Set up the database and create a connection pool pointing at it.
        """
        # Give every test its own record of opened connections.  A dictionary
        # shared through the class would keep connections of earlier tests
        # alive and would let checkOpenfunCalled() succeed because of entries
        # that were recorded for a different pool.
        self.openfun_called = {}
        self.startDB()
        self.dbpool = self.makePool(cp_openfun=self.openfun)
        self.dbpool.start()


    def tearDown(self):
        """
        Drop the test table, then close the pool and stop the database.

        @return: a L{defer.Deferred} firing once cleanup is finished.  A
            failure of the C{DROP TABLE} statement is passed through.
        """
        d = self.dbpool.runOperation('DROP TABLE simple')

        def cleanup(result):
            # Runs for success and failure alike so that a failed DROP does
            # not leave the pool open and the database running.
            self.dbpool.close()
            self.stopDB()
            return result
        d.addBoth(cleanup)
        return d

    def openfun(self, conn):
        """
        Record that the pool invoked its C{cp_openfun} hook for C{conn}.
        """
        self.openfun_called[conn] = True

    def checkOpenfunCalled(self, conn=None):
        """
        Assert that L{openfun} has been called.

        @param conn: if given, the specific connection which must have been
            passed to L{openfun}; if C{None}, any recorded call suffices.
        """
        if conn is None:
            self.assertTrue(self.openfun_called)
        else:
            self.assertIn(conn, self.openfun_called)

    def testPool(self):
        """
        Create the test table and run every pool check in sequence.
        """
        d = self.dbpool.runOperation(simple_table_schema)
        if self.test_failures:
            d.addCallback(self._testPool_1_1)
            d.addCallback(self._testPool_1_2)
            d.addCallback(self._testPool_1_3)
            d.addCallback(self._testPool_1_4)
            # The failures provoked above are expected; discard their logs.
            d.addCallback(lambda res: self.flushLoggedErrors())
        d.addCallback(self._testPool_2)
        d.addCallback(self._testPool_3)
        d.addCallback(self._testPool_4)
        d.addCallback(self._testPool_5)
        d.addCallback(self._testPool_6)
        d.addCallback(self._testPool_7)
        d.addCallback(self._testPool_8)
        d.addCallback(self._testPool_9)
        return d

    def _assertFails(self, function, *args):
        """
        Call C{function(*args)} and require that it fails.

        @return: a L{defer.Deferred} which fires with C{None} if the call
            raised or errbacked, and fails the test if it succeeded.
        """
        d = defer.maybeDeferred(function, *args)
        d.addCallbacks(lambda res: self.fail('no exception'),
                       lambda f: None)
        return d

    def _testPool_1_1(self, res):
        # a query against a missing table must fail
        return self._assertFails(self.dbpool.runQuery,
                                 "select * from NOTABLE")

    def _testPool_1_2(self, res):
        # a malformed operation must fail
        return self._assertFails(self.dbpool.runOperation,
                                 "deletexxx from NOTABLE")

    def _testPool_1_3(self, res):
        # an interaction which raises must fail
        return self._assertFails(self.dbpool.runInteraction,
                                 self.bad_interaction)

    def _testPool_1_4(self, res):
        # a runWithConnection callable which raises must fail
        return self._assertFails(self.dbpool.runWithConnection,
                                 self.bad_withConnection)

    def _testPool_2(self, res):
        # verify simple table is empty
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertEqual(int(row[0][0]), 0, "Interaction not rolled back")
            self.checkOpenfunCalled()
        d.addCallback(_check)
        return d

    def _testPool_3(self, res):
        # add some rows to simple table (runOperation)
        inserts = []
        for i in range(self.num_iterations):
            sql = "insert into simple(x) values(%d)" % i
            inserts.append(self.dbpool.runOperation(sql))
        d = defer.gatherResults(inserts)

        def _select(res):
            # make sure they were added (runQuery)
            return self.dbpool.runQuery("select x from simple order by x")
        d.addCallback(_select)

        def _check(rows):
            self.assertEqual(len(rows), self.num_iterations,
                             "Wrong number of rows")
            for i, row in enumerate(rows):
                self.assertEqual(len(row), 1, "Wrong size row")
                self.assertEqual(row[0], i, "Values not returned.")
        d.addCallback(_check)

        return d

    def _testPool_4(self, res):
        # runInteraction
        d = self.dbpool.runInteraction(self.interaction)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_5(self, res):
        # withConnection
        d = self.dbpool.runWithConnection(self.withConnection)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_6(self, res):
        # Test a withConnection cannot be closed
        d = self.dbpool.runWithConnection(self.close_withConnection)
        return d

    def _testPool_7(self, res):
        # give the pool a workout
        queries = []
        for i in range(self.num_iterations):
            sql = "select x from simple where x = %d" % i
            queries.append(self.dbpool.runQuery(sql))
        d = defer.gatherResults(queries)

        def _check(results):
            # results[i] holds the rows returned by the i-th query
            for i in range(self.num_iterations):
                self.assertEqual(results[i][0][0], i, "Value not returned")
        d.addCallback(_check)
        return d

    def _testPool_8(self, res):
        # now delete everything
        deletes = []
        for i in range(self.num_iterations):
            sql = "delete from simple where x = %d" % i
            deletes.append(self.dbpool.runOperation(sql))
        return defer.gatherResults(deletes)

    def _testPool_9(self, res):
        # verify simple table is empty
        sql = "select count(1) from simple"
        d = self.dbpool.runQuery(sql)

        def _check(row):
            self.assertEqual(int(row[0][0]), 0,
                             "Didn't successfully delete table contents")
            self.checkConnect()
        d.addCallback(_check)
        return d

    def checkConnect(self):
        """Check the connect/disconnect synchronous calls."""
        conn = self.dbpool.connect()
        try:
            self.checkOpenfunCalled(conn)
            curs = conn.cursor()
            try:
                curs.execute("insert into simple(x) values(1)")
                curs.execute("select x from simple")
                res = curs.fetchall()
                self.assertEqual(len(res), 1)
                self.assertEqual(len(res[0]), 1)
                self.assertEqual(res[0][0], 1)
                curs.execute("delete from simple")
                curs.execute("select x from simple")
                self.assertEqual(len(curs.fetchall()), 0)
            finally:
                # release the cursor even when an assertion above fails
                curs.close()
        finally:
            # likewise hand the connection back on every path
            self.dbpool.disconnect(conn)

    def interaction(self, transaction):
        """
        Interaction which checks the rows of the test table in order.

        @return: the string C{"done"}.
        """
        transaction.execute("select x from simple order by x")
        for i in range(self.num_iterations):
            row = transaction.fetchone()
            self.assertEqual(len(row), 1, "Wrong size row")
            self.assertEqual(row[0], i, "Value not returned.")
        # should test this, but gadfly throws an exception instead
        #self.failUnless(transaction.fetchone() is None, "Too many rows")
        return "done"

    def bad_interaction(self, transaction):
        """
        Interaction which always fails by selecting from a missing table.

        When rollback is supported a row is inserted first, so that
        L{_testPool_2} can verify the insert did not survive.
        """
        if self.can_rollback:
            transaction.execute("insert into simple(x) values(0)")

        transaction.execute("select * from NOTABLE")

    def withConnection(self, conn):
        """
        Check the rows of the test table in order through a raw cursor.

        @return: the string C{"done"}.
        """
        curs = conn.cursor()
        try:
            curs.execute("select x from simple order by x")
            for i in range(self.num_iterations):
                row = curs.fetchone()
                self.assertEqual(len(row), 1, "Wrong size row")
                self.assertEqual(row[0], i, "Value not returned.")
            # should test this, but gadfly throws an exception instead
            #self.failUnless(transaction.fetchone() is None, "Too many rows")
        finally:
            curs.close()
        return "done"

    def close_withConnection(self, conn):
        """
        Call C{close} on the connection passed by C{runWithConnection}.
        """
        conn.close()

    def bad_withConnection(self, conn):
        """
        Fail by selecting from a missing table through a raw cursor.
        """
        curs = conn.cursor()
        try:
            curs.execute("select * from NOTABLE")
        finally:
            curs.close()
246
247
class ReconnectTestBase:
    """
    Test the asynchronous DB-API code with reconnect.

    This is a mixin: it relies on attributes and methods supplied by a
    connector class (C{startDB}, C{stopDB}, C{makePool}, C{good_sql},
    C{early_reconnect}) and on the assertion methods of a test case class.
    """

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Skip the test if C{good_sql} is unavailable.  Otherwise, set up the
        database, create a connection pool pointed at it, and set up a simple
        schema in it.
        """
        if self.good_sql is None:
            raise unittest.SkipTest('no good sql for reconnect test')
        self.startDB()
        self.dbpool = self.makePool(cp_max=1, cp_reconnect=True,
                                    cp_good_sql=self.good_sql)
        self.dbpool.start()
        return self.dbpool.runOperation(simple_table_schema)


    def tearDown(self):
        """
        Drop the test table, then close the pool and stop the database.

        @return: a L{defer.Deferred} firing once cleanup is finished.  A
            failure of the C{DROP TABLE} statement is passed through.
        """
        d = self.dbpool.runOperation('DROP TABLE simple')

        def cleanup(result):
            # Runs for success and failure alike so that a failed DROP does
            # not leave the pool open and the database running.
            self.dbpool.close()
            self.stopDB()
            return result
        d.addBoth(cleanup)
        return d

    def testPool(self):
        """
        Close the pooled connection behind the pool's back and check that
        later queries succeed again.
        """
        d = defer.succeed(None)
        d.addCallback(self._testPool_1)
        d.addCallback(self._testPool_2)
        if not self.early_reconnect:
            d.addCallback(self._testPool_3)
        d.addCallback(self._testPool_4)
        d.addCallback(self._testPool_5)
        return d

    def _assertTableEmpty(self):
        """
        Query the row count of the test table and assert that it is zero.

        @return: a L{defer.Deferred} firing when the check has run.
        """
        d = self.dbpool.runQuery("select count(1) from simple")

        def _check(row):
            self.assertEqual(int(row[0][0]), 0, "Table not empty")
        d.addCallback(_check)
        return d

    def _testPool_1(self, res):
        # the freshly created table is empty
        return self._assertTableEmpty()

    def _testPool_2(self, res):
        # reach in and close the connection manually; list() keeps the
        # indexing valid when values() does not return a list
        list(self.dbpool.connections.values())[0].close()

    def _testPool_3(self, res):
        # only run when early_reconnect is false: the first query after the
        # manual close is expected to fail
        sql = "select count(1) from simple"
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
        d.addCallbacks(lambda res: self.fail('no exception'),
                       lambda f: None)
        return d

    def _testPool_4(self, res):
        # the pool must be usable again
        return self._assertTableEmpty()

    def _testPool_5(self, res):
        self.flushLoggedErrors()
        sql = "select * from NOTABLE" # bad sql
        d = defer.maybeDeferred(self.dbpool.runQuery, sql)
        # bad sql must fail, but not by being reported as a lost connection
        d.addCallbacks(lambda res: self.fail('no exception'),
                       lambda f: self.assertFalse(f.check(ConnectionLost)))
        return d
319
320
class DBTestConnector:
    """
    Base for classes which know how to detect a particular relational
    database and how to connect to it.

    To enable test cases which use a central, system database, you must
    create a database named DB_NAME with a user DB_USER and password DB_PASS
    with full access rights to database DB_NAME.
    """

    # prefix of the names of the generated test cases
    TEST_PREFIX = None

    DB_NAME = "twisted_test"
    DB_USER = 'twisted_test'
    DB_PASS = 'twisted_test'

    # directory for database storage, assigned in setUp
    DB_DIR = None

    # capabilities of the database under test
    nulls_ok = True             # nulls supported
    trailing_spaces_ok = True   # trailing spaces in strings preserved
    can_rollback = True         # rollback supported
    test_failures = True        # test bad sql?
    escape_slashes = True       # escape \ in sql?
    good_sql = ConnectionPool.good_sql
    early_reconnect = True      # cursor() will fail on closed connection
    can_clear = True            # can try to clear out tables when starting

    # number of iterations for test loops (lower this for slow db's)
    num_iterations = 50

    def setUp(self):
        """
        Create a storage directory, then hand over to C{extraSetUp} if the
        database is reachable; otherwise skip the test.
        """
        self.DB_DIR = self.mktemp()
        os.mkdir(self.DB_DIR)
        if self.can_connect():
            return self.extraSetUp()
        raise unittest.SkipTest('%s: Cannot access db' % self.TEST_PREFIX)

    def can_connect(self):
        """
        Return true if this database is present on the system and can be
        used in a test.  Must be overridden.
        """
        raise NotImplementedError()

    def startDB(self):
        """Take any steps needed to bring database up."""

    def stopDB(self):
        """Bring database down, if needed."""

    def makePool(self, **newkw):
        """
        Create a connection pool with additional keyword arguments.

        @param newkw: keyword arguments which are added to, and take
            precedence over, those returned by L{getPoolArgs}.
        """
        args, kw = self.getPoolArgs()
        # merge into a copy so the mapping from getPoolArgs is not modified
        merged = dict(kw)
        merged.update(newkw)
        return ConnectionPool(*args, **merged)

    def getPoolArgs(self):
        """
        Return a tuple (args, kw) of list and keyword arguments that need to
        be passed to ConnectionPool to create a connection to this database.
        Must be overridden.
        """
        raise NotImplementedError()
382
class GadflyConnector(DBTestConnector):
    """
    Connector for the gadfly database module.
    """
    TEST_PREFIX = 'Gadfly'

    nulls_ok = False
    can_rollback = False
    escape_slashes = False
    good_sql = 'select * from simple where 1=0'

    num_iterations = 1 # slow

    def can_connect(self):
        """
        Return C{True} if the C{gadfly} module can be imported.

        As a side effect, C{gadfly.connect} is aliased to C{gadfly.gadfly}
        when the module has no C{connect} attribute.
        """
        # Only a missing module means "not available"; a bare except would
        # also hide KeyboardInterrupt and unrelated bugs.
        try:
            import gadfly
        except ImportError:
            return False
        if not getattr(gadfly, 'connect', None):
            gadfly.connect = gadfly.gadfly
        return True

    def startDB(self):
        """
        Create a gadfly database named C{DB_NAME} inside C{DB_DIR}.
        """
        import gadfly
        conn = gadfly.gadfly()
        conn.startup(self.DB_NAME, self.DB_DIR)

        # gadfly seems to want us to create something to get the db going
        cursor = conn.cursor()
        cursor.execute("create table x (x integer)")
        conn.commit()
        conn.close()

    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('gadfly', self.DB_NAME, self.DB_DIR)
        kw = {'cp_max': 1}
        return args, kw
415
class SQLiteConnector(DBTestConnector):
    """
    Connector for the C{sqlite} database module.
    """
    TEST_PREFIX = 'SQLite'

    escape_slashes = False

    num_iterations = 1 # slow

    def can_connect(self):
        """
        Return C{True} if the C{sqlite} module can be imported.
        """
        # Only a missing module means "not available"; a bare except would
        # also hide KeyboardInterrupt and unrelated bugs.
        try:
            import sqlite
        except ImportError:
            return False
        return True

    def startDB(self):
        """
        Choose the database file inside C{DB_DIR} and remove any file
        already present at that path.
        """
        self.database = os.path.join(self.DB_DIR, self.DB_NAME)
        if os.path.exists(self.database):
            os.unlink(self.database)

    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('sqlite',)
        kw = {'database': self.database, 'cp_max': 1}
        return args, kw
437
class PyPgSQLConnector(DBTestConnector):
    """
    Connector for the C{pyPgSQL.PgSQL} database module.
    """
    TEST_PREFIX = "PyPgSQL"

    def can_connect(self):
        """
        Return C{True} if C{pyPgSQL} can be imported and a connection to the
        test database can be opened and closed.
        """
        try:
            from pyPgSQL import PgSQL
        except ImportError:
            return False
        # Any error while probing means the database is unusable for the
        # tests.  Catching Exception rather than everything lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            conn = PgSQL.connect(database=self.DB_NAME, user=self.DB_USER,
                                 password=self.DB_PASS)
            conn.close()
            return True
        except Exception:
            return False

    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('pyPgSQL.PgSQL',)
        kw = {'database': self.DB_NAME, 'user': self.DB_USER,
              'password': self.DB_PASS, 'cp_min': 0}
        return args, kw
457
class PsycopgConnector(DBTestConnector):
    """
    Connector for the C{psycopg} database module.
    """
    TEST_PREFIX = 'Psycopg'

    def can_connect(self):
        """
        Return C{True} if C{psycopg} can be imported and a connection to the
        test database can be opened and closed.
        """
        try:
            import psycopg
        except ImportError:
            return False
        # Any error while probing means the database is unusable for the
        # tests.  Catching Exception rather than everything lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            conn = psycopg.connect(database=self.DB_NAME, user=self.DB_USER,
                                   password=self.DB_PASS)
            conn.close()
            return True
        except Exception:
            return False

    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('psycopg',)
        kw = {'database': self.DB_NAME, 'user': self.DB_USER,
              'password': self.DB_PASS, 'cp_min': 0}
        return args, kw
477
class MySQLConnector(DBTestConnector):
    """
    Connector for the C{MySQLdb} database module.
    """
    TEST_PREFIX = 'MySQL'

    trailing_spaces_ok = False
    can_rollback = False
    early_reconnect = False

    def can_connect(self):
        """
        Return C{True} if C{MySQLdb} can be imported and a connection to the
        test database can be opened and closed.
        """
        try:
            import MySQLdb
        except ImportError:
            return False
        # Any error while probing means the database is unusable for the
        # tests.  Catching Exception rather than everything lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            conn = MySQLdb.connect(db=self.DB_NAME, user=self.DB_USER,
                                   passwd=self.DB_PASS)
            conn.close()
            return True
        except Exception:
            return False

    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('MySQLdb',)
        kw = {'db': self.DB_NAME, 'user': self.DB_USER, 'passwd': self.DB_PASS}
        return args, kw
500
class FirebirdConnector(DBTestConnector):
    """
    Connector for the C{kinterbasdb} database module.
    """
    TEST_PREFIX = 'Firebird'

    test_failures = False # failure testing causes problems
    escape_slashes = False
    good_sql = None # firebird doesn't handle failed sql well
    can_clear = False # firebird is not so good

    num_iterations = 5 # slow

    def can_connect(self):
        """
        Return C{True} if C{kinterbasdb} can be imported and a database can
        be created and dropped again.
        """
        # The import is only an availability probe; the name is not used
        # here.
        try:
            import kinterbasdb
        except ImportError:
            return False
        # Any error while probing means the database is unusable for the
        # tests.  Catching Exception rather than everything lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            self.startDB()
            self.stopDB()
            return True
        except Exception:
            return False


    def startDB(self):
        """
        Create the database file inside C{DB_DIR}.

        C{DB_NAME} is replaced on the instance by the full path of the
        database file.
        """
        import kinterbasdb
        self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
        # open the directory up to user, group and others
        os.chmod(self.DB_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
        sql = 'create database "%s" user "%s" password "%s"'
        sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS)
        conn = kinterbasdb.create_database(sql)
        conn.close()


    def getPoolArgs(self):
        """
        Return the C{(args, kw)} used to build a pool for this database.
        """
        args = ('kinterbasdb',)
        kw = {'database': self.DB_NAME, 'host': '127.0.0.1',
              'user': self.DB_USER, 'password': self.DB_PASS}
        return args, kw

    def stopDB(self):
        """
        Connect to the database and drop it.
        """
        import kinterbasdb
        conn = kinterbasdb.connect(database=self.DB_NAME,
                                   host='127.0.0.1', user=self.DB_USER,
                                   password=self.DB_PASS)
        conn.drop_database()
544
def makeSQLTests(base, suffix, globals):
    """
    Build one test case class per database connector and store it in a
    namespace.

    @param base: Base class for test case. Additional base classes
                 will be a DBTestConnector subclass and unittest.TestCase
    @param suffix: A suffix used to create test case names. Prefixes
                   are defined in the DBTestConnector subclasses.
    @param globals: The mapping in which each generated class is stored
                    under its name.
    """
    connectors = (GadflyConnector, SQLiteConnector, PyPgSQLConnector,
                  PsycopgConnector, MySQLConnector, FirebirdConnector)
    for connector in connectors:
        caseName = connector.TEST_PREFIX + suffix
        bases = (connector, base, unittest.TestCase)
        # the attributes of base are supplied again as the class namespace
        globals[caseName] = types.ClassType(caseName, bases, base.__dict__)
561
# Defines, in this module's namespace, the test cases:
# GadflyADBAPITestCase SQLiteADBAPITestCase PyPgSQLADBAPITestCase
# PsycopgADBAPITestCase MySQLADBAPITestCase FirebirdADBAPITestCase
makeSQLTests(ADBAPITestBase, 'ADBAPITestCase', globals())

# Defines, in this module's namespace, the test cases:
# GadflyReconnectTestCase SQLiteReconnectTestCase PyPgSQLReconnectTestCase
# PsycopgReconnectTestCase MySQLReconnectTestCase FirebirdReconnectTestCase
makeSQLTests(ReconnectTestBase, 'ReconnectTestCase', globals())
569
570
571
class FakePool(object):
    """
    A stand-in for L{ConnectionPool} used by the tests.

    @ivar connectionFactory: callable invoked by C{connect} to produce the
        connection it returns.
    @type connectionFactory: any callable
    """
    noisy = True
    reconnect = True

    def __init__(self, connectionFactory):
        self.connectionFactory = connectionFactory


    def connect(self):
        """
        Call C{self.connectionFactory} and return its result.
        """
        factory = self.connectionFactory
        return factory()


    def disconnect(self, connection):
        """
        Ignore the connection; nothing is released.
        """
598
599
600
class ConnectionTestCase(unittest.TestCase):
    """
    Tests for the L{Connection} class.
    """

    def test_rollbackErrorLogged(self):
        """
        When the underlying connection fails to roll back,
        L{Connection.rollback} raises L{ConnectionLost} and the original
        error ends up in the log.
        """
        class ConnectionRollbackRaise(object):
            def rollback(self):
                raise RuntimeError("problem!")

        connection = Connection(FakePool(ConnectionRollbackRaise))
        self.assertRaises(ConnectionLost, connection.rollback)

        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        self.assertEqual(logged[0].value.args[0], "problem!")
621
622
623
class TransactionTestCase(unittest.TestCase):
    """
    Tests for the L{Transaction} class.
    """

    def test_reopenLogErrorIfReconnect(self):
        """
        When creating the cursor fails inside L{Transaction.reopen}, the
        transaction reconnects and the error is logged.
        """
        class ConnectionCursorRaise(object):
            # number of times cursor() has raised so far
            count = 0

            def reconnect(self):
                pass

            def cursor(self):
                # fail on the first call only
                if self.count != 0:
                    return
                self.count += 1
                raise RuntimeError("problem!")

        transaction = Transaction(FakePool(None), ConnectionCursorRaise())
        transaction.reopen()

        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        self.assertEqual(logged[0].value.args[0], "problem!")
651
652
653
class NonThreadPool(object):
    """
    A thread pool replacement which runs the submitted function
    synchronously in the calling thread.
    """
    def callInThreadWithCallback(self, onResult, f, *a, **kw):
        """
        Call C{f(*a, **kw)} immediately and report the outcome.

        @param onResult: callable invoked as C{onResult(success, result)};
            C{result} is the return value of C{f} on success and a
            L{Failure} wrapping the raised exception otherwise.
        @param f: the function to run.
        """
        try:
            result = f(*a, **kw)
        except Exception:
            success = False
            # Failure() with no arguments captures the exception currently
            # being handled, so it must be built inside the except block.
            result = Failure()
        else:
            success = True
        onResult(success, result)
663
664
665
class DummyConnectionPool(ConnectionPool):
    """
    A L{ConnectionPool} for tests which uses a L{NonThreadPool} as its
    thread pool.
    """
    threadpool = NonThreadPool()

    def __init__(self):
        """
        Set the reactor only; the base class initializer is deliberately
        not called.
        """
        self.reactor = reactor
677
678
679
class EventReactor(object):
    """
    A minimal stand-in for the event-related part of L{IReactorCore}.

    @ivar _running: A C{bool} telling whether the reactor pretends to have
        been started already.

    @ivar triggers: A C{list} of the system event triggers currently
        registered.
    """
    def __init__(self, running):
        self.triggers = []
        self._running = running


    def callWhenRunning(self, function):
        """
        Call C{function} right away when pretending to run; otherwise
        register it as an after-startup trigger and return its handle.
        """
        if not self._running:
            return self.addSystemEventTrigger('after', 'startup', function)
        function()


    def addSystemEventTrigger(self, phase, event, trigger):
        """
        Record the trigger and return the C{(phase, event, trigger)} tuple
        as its handle.
        """
        entry = (phase, event, trigger)
        self.triggers.append(entry)
        return entry


    def removeSystemEventTrigger(self, handle):
        """
        Forget a trigger previously returned by L{addSystemEventTrigger}.
        """
        self.triggers.remove(handle)
710
711
712
class ConnectionPoolTestCase(unittest.TestCase):
    """
    Unit tests for L{ConnectionPool}.
    """

    def test_runWithConnectionRaiseOriginalError(self):
        """
        When the rollback itself fails, L{ConnectionPool.runWithConnection}
        still fails with the original exception, and the rollback error is
        logged.
        """
        class ConnectionRollbackRaise(object):
            def __init__(self, pool):
                pass

            def rollback(self):
                raise RuntimeError("problem!")

        def raisingFunction(connection):
            raise ValueError("foo")

        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise
        d = self.assertFailure(pool.runWithConnection(raisingFunction),
                               ValueError)

        def checkLogged(ignored):
            logged = self.flushLoggedErrors(RuntimeError)
            self.assertEqual(len(logged), 1)
            self.assertEqual(logged[0].value.args[0], "problem!")
        d.addCallback(checkLogged)
        return d


    def test_closeLogError(self):
        """
        An exception raised while closing a connection is logged by
        L{ConnectionPool._close}.
        """
        class ConnectionCloseRaise(object):
            def close(self):
                raise RuntimeError("problem!")

        DummyConnectionPool()._close(ConnectionCloseRaise())

        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        self.assertEqual(logged[0].value.args[0], "problem!")


    def test_runWithInteractionRaiseOriginalError(self):
        """
        When the rollback itself fails, L{ConnectionPool.runInteraction}
        still fails with the original exception, and the rollback error is
        logged.
        """
        class ConnectionRollbackRaise(object):
            def __init__(self, pool):
                pass

            def rollback(self):
                raise RuntimeError("problem!")

        class DummyTransaction(object):
            def __init__(self, pool, connection):
                pass

        def raisingFunction(transaction):
            raise ValueError("foo")

        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise
        pool.transactionFactory = DummyTransaction

        d = self.assertFailure(pool.runInteraction(raisingFunction),
                               ValueError)

        def checkLogged(ignored):
            logged = self.flushLoggedErrors(RuntimeError)
            self.assertEqual(len(logged), 1)
            self.assertEqual(logged[0].value.args[0], "problem!")
        d.addCallback(checkLogged)
        return d


    def test_unstartedClose(self):
        """
        Calling L{ConnectionPool.close} on a pool whose
        L{ConnectionPool.start} was never called cancels the pool's startup
        event.
        """
        eventReactor = EventReactor(False)
        pool = ConnectionPool('twisted.test.test_adbapi',
                              cp_reactor=eventReactor)
        # There should be a startup trigger waiting.
        expected = [('after', 'startup', pool._start)]
        self.assertEqual(eventReactor.triggers, expected)
        pool.close()
        # But not anymore.
        self.assertFalse(eventReactor.triggers)


    def test_startedClose(self):
        """
        Calling L{ConnectionPool.close} on a started pool, other than from
        its shutdown trigger, cancels that shutdown trigger.
        """
        eventReactor = EventReactor(True)
        pool = ConnectionPool('twisted.test.test_adbapi',
                              cp_reactor=eventReactor)
        # There should be a shutdown trigger waiting.
        expected = [('during', 'shutdown', pool.finalClose)]
        self.assertEqual(eventReactor.triggers, expected)
        pool.close()
        # But not anymore.
        self.assertFalse(eventReactor.triggers)