Functional unit tests after renaming bsdddb3 -> rpmdb, _db -> _rpmdb.
authorjbj <devnull@localhost>
Mon, 3 Jun 2002 20:44:08 +0000 (20:44 +0000)
committerjbj <devnull@localhost>
Mon, 3 Jun 2002 20:44:08 +0000 (20:44 +0000)
CVS patchset: 5459
CVS date: 2002/06/03 20:44:08

24 files changed:
python/_rpmdb.c
python/rpmdb/Makefile.am
python/rpmdb/__init__.py
python/rpmdb/db.py [moved from python/rpmdb/rpmdb.py with 97% similarity]
python/rpmdb/dbobj.py [moved from python/rpmdb/rpmdbobj.py with 100% similarity]
python/rpmdb/dbrecio.py [moved from python/rpmdb/rpmdbrecio.py with 98% similarity]
python/rpmdb/dbshelve.py [moved from python/rpmdb/rpmdbshelve.py with 97% similarity]
python/rpmdb/dbtables.py [moved from python/rpmdb/rpmdbtables.py with 99% similarity]
python/rpmdb/dbutils.py [moved from python/rpmdb/rpmdbutils.py with 98% similarity]
python/test/test_all.py [new file with mode: 0644]
python/test/test_associate.py [new file with mode: 0644]
python/test/test_basics.py [new file with mode: 0644]
python/test/test_compat.py [new file with mode: 0644]
python/test/test_dbobj.py [new file with mode: 0644]
python/test/test_dbshelve.py [new file with mode: 0644]
python/test/test_dbtables.py [new file with mode: 0644]
python/test/test_get_none.py [new file with mode: 0644]
python/test/test_join.py [new file with mode: 0644]
python/test/test_lock.py [new file with mode: 0644]
python/test/test_misc.py [new file with mode: 0644]
python/test/test_queue.py [new file with mode: 0644]
python/test/test_recno.py [new file with mode: 0644]
python/test/test_thread.py [new file with mode: 0644]
python/test/unittest.py [new file with mode: 0644]

index 8df27b9..b91b7fc 100644 (file)
@@ -88,7 +88,7 @@
 
 #define PY_BSDDB_VERSION "3.3.1"
 
-static char *rcs_id = "$Id: _rpmdb.c,v 1.1 2002/06/02 20:50:49 jbj Exp $";
+static char *rcs_id = "$Id: _rpmdb.c,v 1.2 2002/06/03 20:44:08 jbj Exp $";
 
 
 #ifdef WITH_THREAD
@@ -1996,7 +1996,7 @@ DB_set_get_returns_none(DBObject* self, PyObject* args)
 /*-------------------------------------------------------------- */
 /* Mapping and Dictionary-like access routines */
 
-int DB_length(DBObject* self)
+static int DB_length(DBObject* self)
 {
     int err;
     long size = 0;
@@ -2034,7 +2034,7 @@ int DB_length(DBObject* self)
 }
 
 
-PyObject* DB_subscript(DBObject* self, PyObject* keyobj)
+static PyObject* DB_subscript(DBObject* self, PyObject* keyobj)
 {
     int err;
     PyObject* retval;
@@ -3927,6 +3927,8 @@ static PyMethodDef bsddb_methods[] = {
 
 
 
+void init_rpmdb(void);         /* XXX remove compiler warning */
+
 DL_EXPORT(void) init_rpmdb(void)
 {
     PyObject* m;
@@ -3948,7 +3950,7 @@ DL_EXPORT(void) init_rpmdb(void)
 #endif
 
     /* Create the module and add the functions */
-    m = Py_InitModule("_db", bsddb_methods);
+    m = Py_InitModule("_rpmdb", bsddb_methods);
 
     /* Add some symbolic constants to the module */
     d = PyModule_GetDict(m);
index 543a08b..9812312 100644 (file)
@@ -6,5 +6,4 @@ PYVER= @WITH_PYTHON_VERSION@
 
 rpmdbdir = $(prefix)/lib/python${PYVER}/site-packages/rpmdb
 rpmdb_SCRIPTS = __init__.py \
-       rpmdbobj.py rpmdb.py rpmdbrecio.py rpmdbshelve.py rpmdbtables.py \
-       rpmdbutils.py
+       dbobj.py db.py dbrecio.py dbshelve.py dbtables.py dbutils.py
index 563e18f..6dd102d 100644 (file)
 
 """
 This package initialization module provides a compatibility interface
-that should enable bsddb3 to be a near drop-in replacement for the original
+that should enable rpmdb to be a near drop-in replacement for the original
 old bsddb module.  The functions and classes provided here are all
-wrappers around the new functionality provided in the bsddb3.db module.
+wrappers around the new functionality provided in the rpmdb.db module.
 
 People interested in the more advanced capabilites of Berkeley DB 3.x
-should use the bsddb3.db module directly.
+should use the rpmdb.db module directly.
 """
 
-import _db
+import _rpmdb as _db
 __version__ = _db.__version__
 
-error = _db.DBError  # So bsddb3.error will mean something...
+error = _db.DBError  # So rpmdb.error will mean something...
 
 #----------------------------------------------------------------------
 
similarity index 97%
rename from python/rpmdb/rpmdb.py
rename to python/rpmdb/db.py
index b4365d0..3bf5f8e 100644 (file)
@@ -37,8 +37,8 @@
 # case we ever want to augment the stuff in _db in any way.  For now
 # it just simply imports everything from _db.
 
-from _db import *
-from _db import __version__
+from _rpmdb import *
+from _rpmdb import __version__
 
 if version() < (3, 1, 0):
     raise ImportError, "BerkeleyDB 3.x symbols not found.  Perhaps python was statically linked with an older version?"
similarity index 98%
rename from python/rpmdb/rpmdbrecio.py
rename to python/rpmdb/dbrecio.py
index 995dad7..4ef6f6b 100644 (file)
@@ -1,6 +1,6 @@
 
 """
-File-like objects that read from or write to a bsddb3 record.
+File-like objects that read from or write to a rpmdb record.
 
 This implements (nearly) all stdio methods.
 
similarity index 97%
rename from python/rpmdb/rpmdbshelve.py
rename to python/rpmdb/dbshelve.py
index dab8caa..e4ca933 100644 (file)
 #------------------------------------------------------------------------
 
 """
-Manage shelves of pickled objects using bsddb3 database files for the
+Manage shelves of pickled objects using rpmdb database files for the
 storage.
 """
 
 #------------------------------------------------------------------------
 
 import cPickle
-from bsddb3 import db
+from rpmdb import db
 
 #------------------------------------------------------------------------
 
@@ -43,7 +43,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH,
     shleve.py module.  It can be used like this, where key is a string
     and data is a pickleable object:
 
-        from bsddb3 import dbshelve
+        from rpmdb import dbshelve
         db = dbshelve.open(filename)
 
         db[key] = data
@@ -63,7 +63,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH,
         elif sflag == 'n':
             flags = db.DB_TRUNCATE | db.DB_CREATE
         else:
-            raise error, "flags should be one of 'r', 'w', 'c' or 'n' or use the bsddb3.db.DB_* flags"
+            raise error, "flags should be one of 'r', 'w', 'c' or 'n' or use the rpmdb.db.DB_* flags"
 
     d = DBShelf(dbenv)
     d.open(filename, dbname, filetype, flags, mode)
@@ -73,7 +73,7 @@ def open(filename, flags=db.DB_CREATE, mode=0660, filetype=db.DB_HASH,
 
 class DBShelf:
     """
-    A shelf to hold pickled objects, built upon a bsddb3 DB object.  It
+    A shelf to hold pickled objects, built upon a rpmdb DB object.  It
     automatically pickles/unpickles data objects going to/from the DB.
     """
     def __init__(self, dbenv=None):
similarity index 99%
rename from python/rpmdb/rpmdbtables.py
rename to python/rpmdb/dbtables.py
index 8ffed91..05bf2ef 100644 (file)
@@ -28,7 +28,7 @@ import xdrlib
 import re
 import copy
 
-from bsddb3.db import *
+from rpmdb.db import *
 
 
 class TableDBError(StandardError): pass
similarity index 98%
rename from python/rpmdb/rpmdbutils.py
rename to python/rpmdb/dbutils.py
index fe08407..d1dbd5c 100644 (file)
 
 #
 # import the time.sleep function in a namespace safe way to allow
-# "from bsddb3.db import *"
+# "from rpmdb.db import *"
 #
 from time import sleep
 _sleep = sleep
 del sleep
 
-import _db
+import _rpmdb as _db
 
 _deadlock_MinSleepTime = 1.0/64  # always sleep at least N seconds between retrys
 _deadlock_MaxSleepTime = 1.0     # never sleep more than N seconds between retrys
diff --git a/python/test/test_all.py b/python/test/test_all.py
new file mode 100644 (file)
index 0000000..6fc011d
--- /dev/null
@@ -0,0 +1,58 @@
+"""
+Run all test cases.
+"""
+
+import sys
+import unittest
+
+verbose = 0
+if 'verbose' in sys.argv:
+    verbose = 1
+    sys.argv.remove('verbose')
+
+if 'silent' in sys.argv:  # take care of old flag, just in case
+    verbose = 0
+    sys.argv.remove('silent')
+
+
+# This little hack is for when this module is run as main and all the
+# other modules import it so they will still be able to get the right
+# verbose setting.  It's confusing but it works.
+import test_all
+test_all.verbose = verbose
+
+
+def suite():
+    test_modules = [ 'test_compat',
+                     'test_basics',
+                     'test_misc',
+                     'test_dbobj',
+                     'test_recno',
+                     'test_queue',
+                     'test_get_none',
+                     'test_dbshelve',
+                     'test_dbtables',
+                     'test_thread',
+                     'test_lock',
+                     'test_associate',
+                   ]
+
+    alltests = unittest.TestSuite()
+    for name in test_modules:
+        module = __import__(name)
+        alltests.addTest(module.suite())
+    return alltests
+
+
+if __name__ == '__main__':
+    from rpmdb import db
+    print '-=' * 38
+    print db.DB_VERSION_STRING
+    print 'rpmdb.db.version():   %s' % (db.version(), )
+    print 'rpmdb.db.__version__: %s' % db.__version__
+    print 'rpmdb.db.cvsid:       %s' % db.cvsid
+    print 'python version:        %s' % sys.version
+    print '-=' * 38
+
+    unittest.main( defaultTest='suite' )
+
diff --git a/python/test/test_associate.py b/python/test/test_associate.py
new file mode 100644 (file)
index 0000000..0db04ec
--- /dev/null
@@ -0,0 +1,323 @@
+"""
+TestCases for multi-threaded access to a DB.
+"""
+
+import sys, os, string
+import tempfile
+import time
+from pprint import pprint
+
+try:
+    from threading import Thread, currentThread
+    have_threads = 1
+except ImportError:
+    have_threads = 0
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db, dbshelve
+
+
+#----------------------------------------------------------------------
+
+
+musicdata = {
+1 : ("Bad English", "The Price Of Love", "Rock"),
+2 : ("DNA featuring Suzanne Vega", "Tom's Diner", "Rock"),
+3 : ("George Michael", "Praying For Time", "Rock"),
+4 : ("Gloria Estefan", "Here We Are", "Rock"),
+5 : ("Linda Ronstadt", "Don't Know Much", "Rock"),
+6 : ("Michael Bolton", "How Am I Supposed To Live Without You", "Blues"),
+7 : ("Paul Young", "Oh Girl", "Rock"),
+8 : ("Paula Abdul", "Opposites Attract", "Rock"),
+9 : ("Richard Marx", "Should've Known Better", "Rock"),
+10: ("Rod Stewart", "Forever Young", "Rock"),
+11: ("Roxette", "Dangerous", "Rock"),
+12: ("Sheena Easton", "The Lover In Me", "Rock"),
+13: ("Sinead O'Connor", "Nothing Compares 2 U", "Rock"),
+14: ("Stevie B.", "Because I Love You", "Rock"),
+15: ("Taylor Dayne", "Love Will Lead You Back", "Rock"),
+16: ("The Bangles", "Eternal Flame", "Rock"),
+17: ("Wilson Phillips", "Release Me", "Rock"),
+18: ("Billy Joel", "Blonde Over Blue", "Rock"),
+19: ("Billy Joel", "Famous Last Words", "Rock"),
+20: ("Billy Joel", "Lullabye (Goodnight, My Angel)", "Rock"),
+21: ("Billy Joel", "The River Of Dreams", "Rock"),
+22: ("Billy Joel", "Two Thousand Years", "Rock"),
+23: ("Janet Jackson", "Alright", "Rock"),
+24: ("Janet Jackson", "Black Cat", "Rock"),
+25: ("Janet Jackson", "Come Back To Me", "Rock"),
+26: ("Janet Jackson", "Escapade", "Rock"),
+27: ("Janet Jackson", "Love Will Never Do (Without You)", "Rock"),
+28: ("Janet Jackson", "Miss You Much", "Rock"),
+29: ("Janet Jackson", "Rhythm Nation", "Rock"),
+30: ("Janet Jackson", "State Of The World", "Rock"),
+31: ("Janet Jackson", "The Knowledge", "Rock"),
+32: ("Spyro Gyra", "End of Romanticism", "Jazz"),
+33: ("Spyro Gyra", "Heliopolis", "Jazz"),
+34: ("Spyro Gyra", "Jubilee", "Jazz"),
+35: ("Spyro Gyra", "Little Linda", "Jazz"),
+36: ("Spyro Gyra", "Morning Dance", "Jazz"),
+37: ("Spyro Gyra", "Song for Lorraine", "Jazz"),
+38: ("Yes", "Owner Of A Lonely Heart", "Rock"),
+39: ("Yes", "Rhythm Of Love", "Rock"),
+40: ("Cusco", "Dream Catcher", "New Age"),
+41: ("Cusco", "Geronimos Laughter", "New Age"),
+42: ("Cusco", "Ghost Dance", "New Age"),
+43: ("Blue Man Group", "Drumbone", "New Age"),
+44: ("Blue Man Group", "Endless Column", "New Age"),
+45: ("Blue Man Group", "Klein Mandelbrot", "New Age"),
+46: ("Kenny G", "Silhouette", "Jazz"),
+47: ("Sade", "Smooth Operator", "Jazz"),
+48: ("David Arkenstone", "Papillon (On The Wings Of The Butterfly)", "New Age"),
+49: ("David Arkenstone", "Stepping Stars", "New Age"),
+50: ("David Arkenstone", "Carnation Lily Lily Rose", "New Age"),
+51: ("David Lanz", "Behind The Waterfall", "New Age"),
+52: ("David Lanz", "Cristofori's Dream", "New Age"),
+53: ("David Lanz", "Heartsounds", "New Age"),
+54: ("David Lanz", "Leaves on the Seine", "New Age"),
+}
+
+#----------------------------------------------------------------------
+
+
+class AssociateTestCase(unittest.TestCase):
+    keytype = ''
+
+    def setUp(self):
+        self.filename = self.__class__.__name__ + '.db'
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+        self.env = db.DBEnv()
+        self.env.open(homeDir, db.DB_CREATE | db.DB_INIT_MPOOL |
+                               db.DB_INIT_LOCK | db.DB_THREAD)
+
+    def tearDown(self):
+        self.closeDB()
+        self.env.close()
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+    def addDataToDB(self, d):
+        for key, value in musicdata.items():
+            if type(self.keytype) == type(''):
+                key = "%02d" % key
+            d.put(key, string.join(value, '|'))
+
+
+
+    def createDB(self):
+        self.primary = db.DB(self.env)
+        self.primary.open(self.filename, "primary", self.dbtype,
+                          db.DB_CREATE | db.DB_THREAD)
+
+    def closeDB(self):
+        self.primary.close()
+
+    def getDB(self):
+        return self.primary
+
+
+
+    def test01_associateWithDB(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_associateWithDB..." % self.__class__.__name__
+
+        self.createDB()
+
+        secDB = db.DB(self.env)
+        secDB.set_flags(db.DB_DUP)
+        secDB.open(self.filename, "secondary", db.DB_BTREE, db.DB_CREATE | db.DB_THREAD)
+        self.getDB().associate(secDB, self.getGenre)
+
+        self.addDataToDB(self.getDB())
+
+        self.finish_test(secDB)
+
+
+    def test02_associateAfterDB(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_associateAfterDB..." % self.__class__.__name__
+
+        self.createDB()
+        self.addDataToDB(self.getDB())
+
+        secDB = db.DB(self.env)
+        secDB.set_flags(db.DB_DUP)
+        secDB.open(self.filename, "secondary", db.DB_BTREE, db.DB_CREATE | db.DB_THREAD)
+
+        # adding the DB_CREATE flag will cause it to index existing records
+        self.getDB().associate(secDB, self.getGenre, db.DB_CREATE)
+
+        self.finish_test(secDB)
+
+
+
+
+    def finish_test(self, secDB):
+        if verbose:
+            print "Primary key traversal:"
+        c = self.getDB().cursor()
+        count = 0
+        rec = c.first()
+        while rec is not None:
+            if type(self.keytype) == type(''):
+                assert string.atoi(rec[0])  # for primary db, key is a number
+            else:
+                assert rec[0] and type(rec[0]) == type(0)
+            count = count + 1
+            if verbose:
+                print rec
+            rec = c.next()
+        assert count == len(musicdata) # all items accounted for
+
+
+        if verbose:
+            print "Secondary key traversal:"
+        c = secDB.cursor()
+        count = 0
+        rec = c.first()
+        assert rec[0] == "Jazz"
+        while rec is not None:
+            count = count + 1
+            if verbose:
+                print rec
+            rec = c.next()
+        assert count == len(musicdata)-1   # all items accounted for EXCEPT for 1 with "Blues" genre
+
+
+
+    def getGenre(self, priKey, priData):
+        assert type(priData) == type("")
+        if verbose:
+            print 'getGenre key:', `priKey`, 'data:', `priData`
+        genre = string.split(priData, '|')[2]
+        if genre == 'Blues':
+            return db.DB_DONOTINDEX
+        else:
+            return genre
+
+
+#----------------------------------------------------------------------
+
+
+class AssociateHashTestCase(AssociateTestCase):
+    dbtype = db.DB_HASH
+
+class AssociateBTreeTestCase(AssociateTestCase):
+    dbtype = db.DB_BTREE
+
+class AssociateRecnoTestCase(AssociateTestCase):
+    dbtype = db.DB_RECNO
+    keytype = 0
+
+
+#----------------------------------------------------------------------
+
+class ShelveAssociateTestCase(AssociateTestCase):
+
+    def createDB(self):
+        self.primary = dbshelve.open(self.filename,
+                                     dbname="primary",
+                                     dbenv=self.env,
+                                     filetype=self.dbtype)
+
+    def addDataToDB(self, d):
+        for key, value in musicdata.items():
+            if type(self.keytype) == type(''):
+                key = "%02d" % key
+            d.put(key, value)    # save the value as is this time
+
+
+    def getGenre(self, priKey, priData):
+        assert type(priData) == type(())
+        if verbose:
+            print 'getGenre key:', `priKey`, 'data:', `priData`
+        genre = priData[2]
+        if genre == 'Blues':
+            return db.DB_DONOTINDEX
+        else:
+            return genre
+
+
+class ShelveAssociateHashTestCase(ShelveAssociateTestCase):
+      dbtype = db.DB_HASH
+
+class ShelveAssociateBTreeTestCase(ShelveAssociateTestCase):
+      dbtype = db.DB_BTREE
+
+class ShelveAssociateRecnoTestCase(ShelveAssociateTestCase):
+    dbtype = db.DB_RECNO
+    keytype = 0
+
+
+#----------------------------------------------------------------------
+
+class ThreadedAssociateTestCase(AssociateTestCase):
+
+    def addDataToDB(self, d):
+        t1 = Thread(target = self.writer1,
+                    args = (d, ))
+        t2 = Thread(target = self.writer2,
+                    args = (d, ))
+
+        t1.start()
+        t2.start()
+        t1.join()
+        t2.join()
+
+    def writer1(self, d):
+        for key, value in musicdata.items():
+            if type(self.keytype) == type(''):
+                key = "%02d" % key
+            d.put(key, string.join(value, '|'))
+
+    def writer2(self, d):
+        for x in range(100, 600):
+            key = 'z%2d' % x
+            value = [key] * 4
+            d.put(key, string.join(value, '|'))
+
+
+class ThreadedAssociateHashTestCase(ShelveAssociateTestCase):
+      dbtype = db.DB_HASH
+
+class ThreadedAssociateBTreeTestCase(ShelveAssociateTestCase):
+      dbtype = db.DB_BTREE
+
+class ThreadedAssociateRecnoTestCase(ShelveAssociateTestCase):
+    dbtype = db.DB_RECNO
+    keytype = 0
+
+
+#----------------------------------------------------------------------
+
+def suite():
+    theSuite = unittest.TestSuite()
+
+    if db.version() >= (3, 3, 11):
+        theSuite.addTest(unittest.makeSuite(AssociateHashTestCase))
+        theSuite.addTest(unittest.makeSuite(AssociateBTreeTestCase))
+        theSuite.addTest(unittest.makeSuite(AssociateRecnoTestCase))
+
+        theSuite.addTest(unittest.makeSuite(ShelveAssociateHashTestCase))
+        theSuite.addTest(unittest.makeSuite(ShelveAssociateBTreeTestCase))
+        theSuite.addTest(unittest.makeSuite(ShelveAssociateRecnoTestCase))
+
+        if have_threads:
+            theSuite.addTest(unittest.makeSuite(ThreadedAssociateHashTestCase))
+            theSuite.addTest(unittest.makeSuite(ThreadedAssociateBTreeTestCase))
+            theSuite.addTest(unittest.makeSuite(ThreadedAssociateRecnoTestCase))
+
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/test_basics.py b/python/test/test_basics.py
new file mode 100644 (file)
index 0000000..1b62bd8
--- /dev/null
@@ -0,0 +1,776 @@
+"""
+Basic TestCases for BTree and hash DBs, with and without a DBEnv, with
+various DB flags, etc.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+class VersionTestCase(unittest.TestCase):
+    def test00_version(self):
+        info = db.version()
+        if verbose:
+            print '\n', '-=' * 20
+            print 'rpmdb.db.version(): %s' % (info, )
+            print db.DB_VERSION_STRING
+            print '-=' * 20
+        assert info == (db.DB_VERSION_MAJOR, db.DB_VERSION_MINOR, db.DB_VERSION_PATCH)
+
+#----------------------------------------------------------------------
+
+class BasicTestCase(unittest.TestCase):
+    dbtype       = db.DB_UNKNOWN  # must be set in derived class
+    dbopenflags  = 0
+    dbsetflags   = 0
+    dbmode       = 0660
+    dbname       = None
+    useEnv       = 0
+    envflags     = 0
+
+    def setUp(self):
+        if self.useEnv:
+            homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+            try: os.mkdir(homeDir)
+            except os.error: pass
+            self.env = db.DBEnv()
+            self.env.set_lg_max(1024*1024)
+            self.env.open(homeDir, self.envflags | db.DB_CREATE)
+            tempfile.tempdir = homeDir
+            self.filename = os.path.split(tempfile.mktemp())[1]
+            tempfile.tempdir = None
+            self.homeDir = homeDir
+        else:
+            self.env = None
+            self.filename = tempfile.mktemp()
+
+        # create and open the DB
+        self.d = db.DB(self.env)
+        self.d.set_flags(self.dbsetflags)
+        if self.dbname:
+            self.d.open(self.filename, self.dbname, self.dbtype,
+                        self.dbopenflags|db.DB_CREATE, self.dbmode)
+        else:
+            self.d.open(self.filename,   # try out keyword args
+                        mode = self.dbmode,
+                        dbtype = self.dbtype, flags = self.dbopenflags|db.DB_CREATE)
+
+        self.populateDB()
+
+
+    def tearDown(self):
+        self.d.close()
+        if self.env is not None:
+            self.env.close()
+
+            import glob
+            files = glob.glob(os.path.join(self.homeDir, '*'))
+            for file in files:
+                os.remove(file)
+
+            ## Make a new DBEnv to remove the env files from the home dir.
+            ## (It can't be done while the env is open, nor after it has been
+            ## closed, so we make a new one to do it.)
+            #e = db.DBEnv()
+            #e.remove(self.homeDir)
+            #os.remove(os.path.join(self.homeDir, self.filename))
+
+        else:
+            os.remove(self.filename)
+
+
+
+    def populateDB(self):
+        d = self.d
+        for x in range(500):
+            key = '%04d' % (1000 - x)  # insert keys in reverse order
+            data = self.makeData(key)
+            d.put(key, data)
+
+        for x in range(500):
+            key = '%04d' % x  # and now some in forward order
+            data = self.makeData(key)
+            d.put(key, data)
+
+        num = len(d)
+        if verbose:
+            print "created %d records" % num
+
+
+    def makeData(self, key):
+        return string.join([key] * 5, '-')
+
+
+
+    #----------------------------------------
+
+    def test01_GetsAndPuts(self):
+        d = self.d
+
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_GetsAndPuts..." % self.__class__.__name__
+
+        for key in ['0001', '0100', '0400', '0700', '0999']:
+            data = d.get(key)
+            if verbose:
+                print data
+
+        assert d.get('0321') == '0321-0321-0321-0321-0321'
+
+        # By default non-existant keys return None...
+        assert d.get('abcd') == None
+
+        # ...but they raise exceptions in other situations.  Call
+        # set_get_returns_none() to change it.
+        try:
+            d.delete('abcd')
+        except db.DBNotFoundError, val:
+            assert val[0] == db.DB_NOTFOUND
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+
+        d.put('abcd', 'a new record')
+        assert d.get('abcd') == 'a new record'
+
+        d.put('abcd', 'same key')
+        if self.dbsetflags & db.DB_DUP:
+            assert d.get('abcd') == 'a new record'
+        else:
+            assert d.get('abcd') == 'same key'
+
+
+        try:
+            d.put('abcd', 'this should fail', flags=db.DB_NOOVERWRITE)
+        except db.DBKeyExistError, val:
+            assert val[0] == db.DB_KEYEXIST
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        if self.dbsetflags & db.DB_DUP:
+            assert d.get('abcd') == 'a new record'
+        else:
+            assert d.get('abcd') == 'same key'
+
+
+        d.sync()
+        d.close()
+        del d
+
+        self.d = db.DB(self.env)
+        if self.dbname:
+            self.d.open(self.filename, self.dbname)
+        else:
+            self.d.open(self.filename)
+        d = self.d
+
+        assert d.get('0321') == '0321-0321-0321-0321-0321'
+        if self.dbsetflags & db.DB_DUP:
+            assert d.get('abcd') == 'a new record'
+        else:
+            assert d.get('abcd') == 'same key'
+
+        rec = d.get_both('0555', '0555-0555-0555-0555-0555')
+        if verbose:
+            print rec
+
+        assert d.get_both('0555', 'bad data') == None
+
+        # test default value
+        data = d.get('bad key', 'bad data')
+        assert data == 'bad data'
+
+        # any object can pass through
+        data = d.get('bad key', self)
+        assert data == self
+
+        s = d.stat()
+        assert type(s) == type({})
+        if verbose:
+            print 'd.stat() returned this dictionary:'
+            pprint(s)
+
+
+    #----------------------------------------
+
+    def test02_DictionaryMethods(self):
+        d = self.d
+
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_DictionaryMethods..." % self.__class__.__name__
+
+        for key in ['0002', '0101', '0401', '0701', '0998']:
+            data = d[key]
+            assert data == self.makeData(key)
+            if verbose:
+                print data
+
+        assert len(d) == 1000
+        keys = d.keys()
+        assert len(keys) == 1000
+        assert type(keys) == type([])
+
+        d['new record'] = 'a new record'
+        assert len(d) == 1001
+        keys = d.keys()
+        assert len(keys) == 1001
+
+        d['new record'] = 'a replacement record'
+        assert len(d) == 1001
+        keys = d.keys()
+        assert len(keys) == 1001
+
+        if verbose:
+            print "the first 10 keys are:"
+            pprint(keys[:10])
+
+        assert d['new record'] == 'a replacement record'
+
+        assert d.has_key('0001') == 1
+        assert d.has_key('spam') == 0
+
+        items = d.items()
+        assert len(items) == 1001
+        assert type(items) == type([])
+        assert type(items[0]) == type(())
+        assert len(items[0]) == 2
+
+        if verbose:
+            print "the first 10 items are:"
+            pprint(items[:10])
+
+        values = d.values()
+        assert len(values) == 1001
+        assert type(values) == type([])
+
+        if verbose:
+            print "the first 10 values are:"
+            pprint(values[:10])
+
+
+
+    #----------------------------------------
+
+    def test03_SimpleCursorStuff(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test03_SimpleCursorStuff..." % self.__class__.__name__
+
+        c = self.d.cursor()
+
+
+        rec = c.first()
+        count = 0
+        while rec is not None:
+            count = count + 1
+            if verbose and count % 100 == 0:
+                print rec
+            rec = c.next()
+
+        assert count == 1000
+
+
+        rec = c.last()
+        count = 0
+        while rec is not None:
+            count = count + 1
+            if verbose and count % 100 == 0:
+                print rec
+            rec = c.prev()
+
+        assert count == 1000
+
+        rec = c.set('0505')
+        rec2 = c.current()
+        assert rec == rec2
+        assert rec[0] == '0505'
+        assert rec[1] == self.makeData('0505')
+
+        try:
+            c.set('bad key')
+        except db.DBNotFoundError, val:
+            assert val[0] == db.DB_NOTFOUND
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        rec = c.get_both('0404', self.makeData('0404'))
+        assert rec == ('0404', self.makeData('0404'))
+
+        try:
+            c.get_both('0404', 'bad data')
+        except db.DBNotFoundError, val:
+            assert val[0] == db.DB_NOTFOUND
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        if self.d.get_type() == db.DB_BTREE:
+            rec = c.set_range('011')
+            if verbose:
+                print "searched for '011', found: ", rec
+
+        c.set('0499')
+        c.delete()
+        try:
+            rec = c.current()
+        except db.DBKeyEmptyError, val:
+            assert val[0] == db.DB_KEYEMPTY
+            if verbose: print val
+        else:
+            self.fail('exception expected')
+
+        c.next()
+        c2 = c.dup(db.DB_POSITION)
+        assert c.current() == c2.current()
+
+        c2.put('', 'a new value', db.DB_CURRENT)
+        assert c.current() == c2.current()
+        assert c.current()[1] == 'a new value'
+
+        c.close()
+        c2.close()
+
+        # time to abuse the closed cursors and hope we don't crash
+        methods_to_test = {
+            'current': (),
+            'delete': (),
+            'dup': (db.DB_POSITION,),
+            'first': (),
+            'get': (0,),
+            'next': (),
+            'prev': (),
+            'last': (),
+            'put':('', 'spam', db.DB_CURRENT),
+            'set': ("0505",),
+        }
+        for method, args in methods_to_test.items():
+            try:
+                if verbose:
+                    print "attempting to use a closed cursor's %s method" % method
+                # a bug may cause a NULL pointer dereference...
+                apply(getattr(c, method), args)
+            except db.DBError, val:
+                assert val[0] == 0
+                if verbose: print val
+            else:
+                self.fail("no exception raised when using a buggy cursor's %s method" % method)
+
+    #----------------------------------------
+
+    def test04_PartialGetAndPut(self):
+        d = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test04_PartialGetAndPut..." % self.__class__.__name__
+
+        key = "partialTest"
+        data = "1" * 1000 + "2" * 1000
+        d.put(key, data)
+        assert d.get(key) == data
+        assert d.get(key, dlen=20, doff=990) == ("1" * 10) + ("2" * 10)
+
+        d.put("partialtest2", ("1" * 30000) + "robin" )
+        assert d.get("partialtest2", dlen=5, doff=30000) == "robin"
+
+        # There seems to be a bug in DB here...  Commented out the test for now.
+        ##assert d.get("partialtest2", dlen=5, doff=30010) == ""
+
+        if self.dbsetflags != db.DB_DUP:
+            # Partial put with duplicate records requires a cursor
+            d.put(key, "0000", dlen=2000, doff=0)
+            assert d.get(key) == "0000"
+
+            d.put(key, "1111", dlen=1, doff=2)
+            assert d.get(key) == "0011110"
+
+    #----------------------------------------
+
+    def test05_GetSize(self):
+        d = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test05_GetSize..." % self.__class__.__name__
+
+        for i in range(1, 50000, 500):
+            key = "size%s" % i
+            #print "before ", i,
+            d.put(key, "1" * i)
+            #print "after",
+            assert d.get_size(key) == i
+            #print "done"
+
+#----------------------------------------------------------------------
+
+
+class BasicBTreeTestCase(BasicTestCase):
+    dbtype = db.DB_BTREE
+
+
+class BasicHashTestCase(BasicTestCase):
+    dbtype = db.DB_HASH
+
+
+class BasicBTreeWithThreadFlagTestCase(BasicTestCase):
+    dbtype = db.DB_BTREE
+    dbopenflags = db.DB_THREAD
+
+
+class BasicHashWithThreadFlagTestCase(BasicTestCase):
+    dbtype = db.DB_HASH
+    dbopenflags = db.DB_THREAD
+
+
+class BasicBTreeWithEnvTestCase(BasicTestCase):
+    dbtype = db.DB_BTREE
+    dbopenflags = db.DB_THREAD
+    useEnv = 1
+    envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+class BasicHashWithEnvTestCase(BasicTestCase):
+    dbtype = db.DB_HASH
+    dbopenflags = db.DB_THREAD
+    useEnv = 1
+    envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+#----------------------------------------------------------------------
+
+class BasicTransactionTestCase(BasicTestCase):
+    dbopenflags = db.DB_THREAD
+    useEnv = 1
+    envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK | db.DB_INIT_TXN
+
+
+    def tearDown(self):
+        self.txn.commit()
+        BasicTestCase.tearDown(self)
+
+
+    def populateDB(self):
+        d = self.d
+        txn = self.env.txn_begin()
+        for x in range(500):
+            key = '%04d' % (1000 - x)  # insert keys in reverse order
+            data = self.makeData(key)
+            d.put(key, data, txn)
+
+        for x in range(500):
+            key = '%04d' % x  # and now some in forward order
+            data = self.makeData(key)
+            d.put(key, data, txn)
+
+        txn.commit()
+
+        num = len(d)
+        if verbose:
+            print "created %d records" % num
+
+        self.txn = self.env.txn_begin()
+
+
+
+    def test06_Transactions(self):
+        d = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test06_Transactions..." % self.__class__.__name__
+
+        assert d.get('new rec', txn=self.txn) == None
+        d.put('new rec', 'this is a new record', self.txn)
+        assert d.get('new rec', txn=self.txn) == 'this is a new record'
+        self.txn.abort()
+        assert d.get('new rec') == None
+
+        self.txn = self.env.txn_begin()
+
+        assert d.get('new rec', txn=self.txn) == None
+        d.put('new rec', 'this is a new record', self.txn)
+        assert d.get('new rec', txn=self.txn) == 'this is a new record'
+        self.txn.commit()
+        assert d.get('new rec') == 'this is a new record'
+
+        self.txn = self.env.txn_begin()
+        c = d.cursor(self.txn)
+        rec = c.first()
+        count = 0
+        while rec is not None:
+            count = count + 1
+            if verbose and count % 100 == 0:
+                print rec
+            rec = c.next()
+        assert count == 1001
+
+        c.close()                # Cursors *MUST* be closed before commit!
+        self.txn.commit()
+
+        # flush pending updates
+        try:
+            self.env.txn_checkpoint (0, 0, 0)
+        except db.DBIncompleteError:
+            pass
+
+        # must have at least one log file present:
+        logs = self.env.log_archive(db.DB_ARCH_ABS | db.DB_ARCH_LOG)
+        assert logs != None
+        for log in logs:
+            if verbose:
+                print 'log file: ' + log
+
+        self.txn = self.env.txn_begin()
+
+
+
+class BTreeTransactionTestCase(BasicTransactionTestCase):
+    dbtype = db.DB_BTREE
+
+class HashTransactionTestCase(BasicTransactionTestCase):
+    dbtype = db.DB_HASH
+
+
+
+#----------------------------------------------------------------------
+
+class BTreeRecnoTestCase(BasicTestCase):
+    dbtype     = db.DB_BTREE
+    dbsetflags = db.DB_RECNUM
+
+    def test07_RecnoInBTree(self):
+        d = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test07_RecnoInBTree..." % self.__class__.__name__
+
+        rec = d.get(200)
+        assert type(rec) == type(())
+        assert len(rec) == 2
+        if verbose:
+            print "Record #200 is ", rec
+
+        c = d.cursor()
+        c.set('0200')
+        num = c.get_recno()
+        assert type(num) == type(1)
+        if verbose:
+            print "recno of d['0200'] is ", num
+
+        rec = c.current()
+        assert c.set_recno(num) == rec
+
+        c.close()
+
+
+
+class BTreeRecnoWithThreadFlagTestCase(BTreeRecnoTestCase):
+    dbopenflags = db.DB_THREAD
+
+#----------------------------------------------------------------------
+
+class BasicDUPTestCase(BasicTestCase):
+    dbsetflags = db.DB_DUP
+
+    def test08_DuplicateKeys(self):
+        d = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test08_DuplicateKeys..." % self.__class__.__name__
+
+        d.put("dup0", "before")
+        for x in string.split("The quick brown fox jumped over the lazy dog."):
+            d.put("dup1", x)
+        d.put("dup2", "after")
+
+        data = d.get("dup1")
+        assert data == "The"
+        if verbose:
+            print data
+
+        c = d.cursor()
+        rec = c.set("dup1")
+        assert rec == ('dup1', 'The')
+
+        next = c.next()
+        assert next == ('dup1', 'quick')
+
+        rec = c.set("dup1")
+        count = c.count()
+        assert count == 9
+
+        next_dup = c.next_dup()
+        assert next_dup == ('dup1', 'quick')
+
+        rec = c.set('dup1')
+        while rec is not None:
+            if verbose:
+                print rec
+            rec = c.next_dup()
+
+        c.set('dup1')
+        rec = c.next_nodup()
+        assert rec[0] != 'dup1'
+        if verbose:
+            print rec
+
+        c.close()
+
+
+
+class BTreeDUPTestCase(BasicDUPTestCase):
+    dbtype = db.DB_BTREE
+
+class HashDUPTestCase(BasicDUPTestCase):
+    dbtype = db.DB_HASH
+
+class BTreeDUPWithThreadTestCase(BasicDUPTestCase):
+    dbtype = db.DB_BTREE
+    dbopenflags = db.DB_THREAD
+
+class HashDUPWithThreadTestCase(BasicDUPTestCase):
+    dbtype = db.DB_HASH
+    dbopenflags = db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+
+class BasicMultiDBTestCase(BasicTestCase):
+    dbname = 'first'
+
+    def otherType(self):
+        if self.dbtype == db.DB_BTREE:
+            return db.DB_HASH
+        else:
+            return db.DB_BTREE
+
+    def test09_MultiDB(self):
+        d1 = self.d
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test09_MultiDB..." % self.__class__.__name__
+
+        d2 = db.DB(self.env)
+        d2.open(self.filename, "second", self.dbtype, self.dbopenflags|db.DB_CREATE)
+        d3 = db.DB(self.env)
+        d3.open(self.filename, "third", self.otherType(), self.dbopenflags|db.DB_CREATE)
+
+        for x in string.split("The quick brown fox jumped over the lazy dog"):
+            d2.put(x, self.makeData(x))
+
+        for x in string.letters:
+            d3.put(x, x*70)
+
+        d1.sync()
+        d2.sync()
+        d3.sync()
+        d1.close()
+        d2.close()
+        d3.close()
+
+        self.d = d1 = d2 = d3 = None
+
+        self.d = d1 = db.DB(self.env)
+        d1.open(self.filename, self.dbname, flags = self.dbopenflags)
+        d2 = db.DB(self.env)
+        d2.open(self.filename, "second",  flags = self.dbopenflags)
+        d3 = db.DB(self.env)
+        d3.open(self.filename, "third", flags = self.dbopenflags)
+
+        c1 = d1.cursor()
+        c2 = d2.cursor()
+        c3 = d3.cursor()
+
+        count = 0
+        rec = c1.first()
+        while rec is not None:
+            count = count + 1
+            if verbose and (count % 50) == 0:
+                print rec
+            rec = c1.next()
+        assert count == 1000
+
+        count = 0
+        rec = c2.first()
+        while rec is not None:
+            count = count + 1
+            if verbose:
+                print rec
+            rec = c2.next()
+        assert count == 9
+
+        count = 0
+        rec = c3.first()
+        while rec is not None:
+            count = count + 1
+            if verbose:
+                print rec
+            rec = c3.next()
+        assert count == 52
+
+
+        c1.close()
+        c2.close()
+        c3.close()
+
+        d2.close()
+        d3.close()
+
+
+
+# Strange things happen if you try to use Multiple DBs per file without a
+# DBEnv with MPOOL and LOCKing...
+
+class BTreeMultiDBTestCase(BasicMultiDBTestCase):
+    dbtype = db.DB_BTREE
+    dbopenflags = db.DB_THREAD
+    useEnv = 1
+    envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+class HashMultiDBTestCase(BasicMultiDBTestCase):
+    dbtype = db.DB_HASH
+    dbopenflags = db.DB_THREAD
+    useEnv = 1
+    envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+
+
+#----------------------------------------------------------------------
+#----------------------------------------------------------------------
+
+def suite():
+    theSuite = unittest.TestSuite()
+
+    theSuite.addTest(unittest.makeSuite(VersionTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicBTreeTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicHashTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicBTreeWithThreadFlagTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicHashWithThreadFlagTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicBTreeWithEnvTestCase))
+    theSuite.addTest(unittest.makeSuite(BasicHashWithEnvTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeTransactionTestCase))
+    theSuite.addTest(unittest.makeSuite(HashTransactionTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeRecnoTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeRecnoWithThreadFlagTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeDUPTestCase))
+    theSuite.addTest(unittest.makeSuite(HashDUPTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeDUPWithThreadTestCase))
+    theSuite.addTest(unittest.makeSuite(HashDUPWithThreadTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeMultiDBTestCase))
+    theSuite.addTest(unittest.makeSuite(HashMultiDBTestCase))
+
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
diff --git a/python/test/test_compat.py b/python/test/test_compat.py
new file mode 100644 (file)
index 0000000..35f281b
--- /dev/null
@@ -0,0 +1,169 @@
+"""
+Test cases adapted from the test_bsddb.py module in Python's
+regression test suite.
+"""
+
+import sys, os, string
+from rpmdb import hashopen, btopen, rnopen
+import rpmdb
+import unittest
+import tempfile
+
+from test_all import verbose
+
+
+
+class CompatibilityTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = tempfile.mktemp()
+
+    def tearDown(self):
+        try:
+            os.remove(self.filename)
+        except os.error:
+            pass
+
+
+    def test01_btopen(self):
+        self.do_bthash_test(btopen, 'btopen')
+
+    def test02_hashopen(self):
+        self.do_bthash_test(hashopen, 'hashopen')
+
+    def test03_rnopen(self):
+        data = string.split("The quick brown fox jumped over the lazy dog.")
+        if verbose:
+            print "\nTesting: rnopen"
+
+        f = rnopen(self.filename, 'c')
+        for x in range(len(data)):
+            f[x+1] = data[x]
+
+        getTest = (f[1], f[2], f[3])
+        if verbose:
+            print '%s %s %s' % getTest
+
+        assert getTest[1] == 'quick', 'data mismatch!'
+
+        f[25] = 'twenty-five'
+        f.close()
+        del f
+
+        f = rnopen(self.filename, 'w')
+        f[20] = 'twenty'
+
+        def noRec(f):
+            rec = f[15]
+        self.assertRaises(KeyError, noRec, f)
+
+        def badKey(f):
+            rec = f['a string']
+        self.assertRaises(TypeError, badKey, f)
+
+        del f[3]
+
+        rec = f.first()
+        while rec:
+            if verbose:
+                print rec
+            try:
+                rec = f.next()
+            except KeyError:
+                break
+
+        f.close()
+
+
+    def test04_n_flag(self):
+        f = hashopen(self.filename, 'n')
+        f.close()
+
+
+
+    def do_bthash_test(self, factory, what):
+        if verbose:
+            print '\nTesting: ', what
+
+        f = factory(self.filename, 'c')
+        if verbose:
+            print 'creation...'
+
+        # truth test
+        if f:
+            if verbose: print "truth test: true"
+        else:
+            if verbose: print "truth test: false"
+
+        f['0'] = ''
+        f['a'] = 'Guido'
+        f['b'] = 'van'
+        f['c'] = 'Rossum'
+        f['d'] = 'invented'
+        f['f'] = 'Python'
+        if verbose:
+            print '%s %s %s' % (f['a'], f['b'], f['c'])
+
+        if verbose:
+            print 'key ordering...'
+        f.set_location(f.first()[0])
+        while 1:
+            try:
+                rec = f.next()
+            except KeyError:
+                assert rec == f.last(), 'Error, last <> last!'
+                f.previous()
+                break
+            if verbose:
+                print rec
+
+        assert f.has_key('f'), 'Error, missing key!'
+
+        f.sync()
+        f.close()
+        # truth test
+        try:
+            if f:
+                if verbose: print "truth test: true"
+            else:
+                if verbose: print "truth test: false"
+        except rpmdb.error:
+            pass
+        else:
+            self.fail("Exception expected")
+
+        del f
+
+        if verbose:
+            print 'modification...'
+        f = factory(self.filename, 'w')
+        f['d'] = 'discovered'
+
+        if verbose:
+            print 'access...'
+        for key in f.keys():
+            word = f[key]
+            if verbose:
+                print word
+
+        def noRec(f):
+            rec = f['no such key']
+        self.assertRaises(KeyError, noRec, f)
+
+        def badKey(f):
+            rec = f[15]
+        self.assertRaises(TypeError, badKey, f)
+
+        f.close()
+
+
+#----------------------------------------------------------------------
+
+
+def suite():
+    return unittest.makeSuite(CompatibilityTestCase)
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
+
diff --git a/python/test/test_dbobj.py b/python/test/test_dbobj.py
new file mode 100644 (file)
index 0000000..2bd7784
--- /dev/null
@@ -0,0 +1,72 @@
+
+import sys, os, string
+import unittest
+import glob
+
+from rpmdb import db, dbobj
+
+
+#----------------------------------------------------------------------
+
+class dbobjTestCase(unittest.TestCase):
+    """Verify that dbobj.DB and dbobj.DBEnv work properly"""
+    db_home = 'db_home'
+    db_name = 'test-dbobj.db'
+
+    def setUp(self):
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+
+    def tearDown(self):
+        if hasattr(self, 'db'):
+            del self.db
+        if hasattr(self, 'env'):
+            del self.env
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+    def test01_both(self):
+        class TestDBEnv(dbobj.DBEnv): pass
+        class TestDB(dbobj.DB):
+            def put(self, key, *args, **kwargs):
+                key = string.upper(key)
+                # call our parent classes put method with an upper case key
+                return apply(dbobj.DB.put, (self, key) + args, kwargs)
+        self.env = TestDBEnv()
+        self.env.open(self.db_home, db.DB_CREATE | db.DB_INIT_MPOOL)
+        self.db = TestDB(self.env)
+        self.db.open(self.db_name, db.DB_HASH, db.DB_CREATE)
+        self.db.put('spam', 'eggs')
+        assert self.db.get('spam') == None, "overridden dbobj.DB.put() method failed [1]"
+        assert self.db.get('SPAM') == 'eggs', "overridden dbobj.DB.put() method failed [2]"
+        self.db.close()
+        self.env.close()
+
+    def test02_dbobj_dict_interface(self):
+        self.env = dbobj.DBEnv()
+        self.env.open(self.db_home, db.DB_CREATE | db.DB_INIT_MPOOL)
+        self.db = dbobj.DB(self.env)
+        self.db.open(self.db_name+'02', db.DB_HASH, db.DB_CREATE)
+        # __setitem__
+        self.db['spam'] = 'eggs'
+        # __len__
+        assert len(self.db) == 1
+        # __getitem__
+        assert self.db['spam'] == 'eggs'
+        # __del__
+        del self.db['spam']
+        assert self.db.get('spam') == None, "dbobj __del__ failed"
+        self.db.close()
+        self.env.close()
+
+#----------------------------------------------------------------------
+
+def suite():
+    return unittest.makeSuite(dbobjTestCase)
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
diff --git a/python/test/test_dbshelve.py b/python/test/test_dbshelve.py
new file mode 100644 (file)
index 0000000..d5b974a
--- /dev/null
@@ -0,0 +1,305 @@
+"""
+TestCases for checking dbShelve objects.
+"""
+
+import sys, os, string
+import tempfile, random
+from pprint import pprint
+from types import *
+import unittest
+
+from rpmdb import dbshelve, db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+# We want the objects to be comparable so we can test dbshelve.values
+# later on.
+class DataClass:
+    def __init__(self):
+        self.value = random.random()
+
+    def __cmp__(self, other):
+        return cmp(self.value, other)
+
+class DBShelveTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = tempfile.mktemp()
+        self.do_open()
+
+    def tearDown(self):
+        self.do_close()
+        try:
+            os.remove(self.filename)
+        except os.error:
+            pass
+
+    def populateDB(self, d):
+        for x in string.letters:
+            d['S' + x] = 10 * x           # add a string
+            d['I' + x] = ord(x)           # add an integer
+            d['L' + x] = [x] * 10         # add a list
+
+            inst = DataClass()            # add an instance
+            inst.S = 10 * x
+            inst.I = ord(x)
+            inst.L = [x] * 10
+            d['O' + x] = inst
+
+
+    # overridable in derived classes to affect how the shelf is created/opened
+    def do_open(self):
+        self.d = dbshelve.open(self.filename)
+
+    # and closed...
+    def do_close(self):
+        self.d.close()
+
+
+
+    def test01_basics(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_basics..." % self.__class__.__name__
+
+        self.populateDB(self.d)
+        self.d.sync()
+        self.do_close()
+        self.do_open()
+        d = self.d
+
+        l = len(d)
+        k = d.keys()
+        s = d.stat()
+        f = d.fd()
+
+        if verbose:
+            print "length:", l
+            print "keys:", k
+            print "stats:", s
+
+        assert 0 == d.has_key('bad key')
+        assert 1 == d.has_key('IA')
+        assert 1 == d.has_key('OA')
+
+        d.delete('IA')
+        del d['OA']
+        assert 0 == d.has_key('IA')
+        assert 0 == d.has_key('OA')
+        assert len(d) == l-2
+
+        values = []
+        for key in d.keys():
+            value = d[key]
+            values.append(value)
+            if verbose:
+                print "%s: %s" % (key, value)
+            self.checkrec(key, value)
+
+        dbvalues = d.values()
+        assert len(dbvalues) == len(d.keys())
+        values.sort()
+        dbvalues.sort()
+        assert values == dbvalues
+
+        items = d.items()
+        assert len(items) == len(values)
+
+        for key, value in items:
+            self.checkrec(key, value)
+
+        assert d.get('bad key') == None
+        assert d.get('bad key', None) == None
+        assert d.get('bad key', 'a string') == 'a string'
+        assert d.get('bad key', [1, 2, 3]) == [1, 2, 3]
+
+        d.set_get_returns_none(0)
+        self.assertRaises(db.DBNotFoundError, d.get, 'bad key')
+        d.set_get_returns_none(1)
+
+        d.put('new key', 'new data')
+        assert d.get('new key') == 'new data'
+        assert d['new key'] == 'new data'
+
+
+
+    def test02_cursors(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_cursors..." % self.__class__.__name__
+
+        self.populateDB(self.d)
+        d = self.d
+
+        count = 0
+        c = d.cursor()
+        rec = c.first()
+        while rec is not None:
+            count = count + 1
+            if verbose:
+                print rec
+            key, value = rec
+            self.checkrec(key, value)
+            rec = c.next()
+
+        assert count == len(d)
+
+        count = 0
+        c = d.cursor()
+        rec = c.last()
+        while rec is not None:
+            count = count + 1
+            if verbose:
+                print rec
+            key, value = rec
+            self.checkrec(key, value)
+            rec = c.prev()
+
+        assert count == len(d)
+
+        c.set('SS')
+        key, value = c.current()
+        self.checkrec(key, value)
+
+        c.close()
+
+
+
+
+    def checkrec(self, key, value):
+        x = key[1]
+        if key[0] == 'S':
+            assert type(value) == StringType
+            assert value == 10 * x
+
+        elif key[0] == 'I':
+            assert type(value) == IntType
+            assert value == ord(x)
+
+        elif key[0] == 'L':
+            assert type(value) == ListType
+            assert value == [x] * 10
+
+        elif key[0] == 'O':
+            assert type(value) == InstanceType
+            assert value.S == 10 * x
+            assert value.I == ord(x)
+            assert value.L == [x] * 10
+
+        else:
+            raise AssertionError, 'Unknown key type, fix the test'
+
+#----------------------------------------------------------------------
+
+class BasicShelveTestCase(DBShelveTestCase):
+    def do_open(self):
+        self.d = dbshelve.DBShelf()
+        self.d.open(self.filename, self.dbtype, self.dbflags)
+
+    def do_close(self):
+        self.d.close()
+
+
+
+
+class BTreeShelveTestCase(BasicShelveTestCase):
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE
+
+
+class HashShelveTestCase(BasicShelveTestCase):
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE
+
+
+class ThreadBTreeShelveTestCase(BasicShelveTestCase):
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+class ThreadHashShelveTestCase(BasicShelveTestCase):
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+
+class BasicEnvShelveTestCase(DBShelveTestCase):
+    def do_open(self):
+        self.homeDir = homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        try: os.mkdir(homeDir)
+        except os.error: pass
+        self.env = db.DBEnv()
+        self.env.open(homeDir, self.envflags | db.DB_INIT_MPOOL | db.DB_CREATE)
+
+        self.filename = os.path.split(self.filename)[1]
+        self.d = dbshelve.DBShelf(self.env)
+        self.d.open(self.filename, self.dbtype, self.dbflags)
+
+
+    def do_close(self):
+        self.d.close()
+        self.env.close()
+
+
+    def tearDown(self):
+        self.do_close()
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+
+
+class EnvBTreeShelveTestCase(BasicEnvShelveTestCase):
+    envflags = 0
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE
+
+
+class EnvHashShelveTestCase(BasicEnvShelveTestCase):
+    envflags = 0
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE
+
+
+class EnvThreadBTreeShelveTestCase(BasicEnvShelveTestCase):
+    envflags = db.DB_THREAD
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+class EnvThreadHashShelveTestCase(BasicEnvShelveTestCase):
+    envflags = db.DB_THREAD
+    dbtype = db.DB_BTREE
+    dbflags = db.DB_CREATE | db.DB_THREAD
+
+
+#----------------------------------------------------------------------
+# TODO:  Add test cases for a DBShelf in a RECNO DB.
+
+
+#----------------------------------------------------------------------
+
+def suite():
+    theSuite = unittest.TestSuite()
+
+    theSuite.addTest(unittest.makeSuite(DBShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(BTreeShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(HashShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(ThreadBTreeShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(ThreadHashShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(EnvBTreeShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(EnvHashShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(EnvThreadBTreeShelveTestCase))
+    theSuite.addTest(unittest.makeSuite(EnvThreadHashShelveTestCase))
+
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
+
diff --git a/python/test/test_dbtables.py b/python/test/test_dbtables.py
new file mode 100644 (file)
index 0000000..75fc2ad
--- /dev/null
@@ -0,0 +1,265 @@
+#!/usr/bin/env python
+#
+#-----------------------------------------------------------------------
+# A test suite for the table interface built on rpmdb.db
+#-----------------------------------------------------------------------
+#
+# Copyright (C) 2000, 2001 by Autonomous Zone Industries
+#
+# March 20, 2000
+#
+# License:      This is free software.  You may use this software for any
+#               purpose including modification/redistribution, so long as
+#               this header remains intact and that you do not claim any
+#               rights of ownership or authorship of this software.  This
+#               software has been tested, but no warranty is expressed or
+#               implied.
+#
+#   --  Gregory P. Smith <greg@electricrain.com>
+#
+# Id: test_dbtables.py,v 1.6 2001/05/14 20:48:18 greg Exp 
+
+import sys, os, re
+try:
+    import cPickle
+    pickle = cPickle
+except ImportError:
+    import pickle
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db, dbtables
+
+
+
+#----------------------------------------------------------------------
+
+class TableDBTestCase(unittest.TestCase):
+    db_home = 'db_home'
+    db_name = 'test-table.db'
+
+    def setUp(self):
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+        self.tdb = dbtables.bsdTableDB(filename='tabletest.db', dbhome='db_home', create=1)
+
+    def tearDown(self):
+        self.tdb.close()
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+    def test01(self):
+        tabname = "test01"
+        colname = 'cool numbers'
+        try:
+            self.tdb.Drop(tabname)
+        except dbtables.TableDBError:
+            pass
+        self.tdb.CreateTable(tabname, [colname])
+        self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159, 1)})
+
+        if verbose:
+            self.tdb._db_print()
+
+        values = self.tdb.Select(tabname, [colname], conditions={colname: None})
+
+        colval = pickle.loads(values[0][colname])
+        assert(colval > 3.141 and colval < 3.142)
+
+
+    def test02(self):
+        tabname = "test02"
+        col0 = 'coolness factor'
+        col1 = 'but can it fly?'
+        col2 = 'Species'
+        testinfo = [
+            {col0: pickle.dumps(8, 1), col1: 'no', col2: 'Penguin'},
+            {col0: pickle.dumps(-1, 1), col1: 'no', col2: 'Turkey'},
+            {col0: pickle.dumps(9, 1), col1: 'yes', col2: 'SR-71A Blackbird'}
+        ]
+
+        try:
+            self.tdb.Drop(tabname)
+        except dbtables.TableDBError:
+            pass
+        self.tdb.CreateTable(tabname, [col0, col1, col2])
+        for row in testinfo :
+            self.tdb.Insert(tabname, row)
+
+        values = self.tdb.Select(tabname, [col2],
+            conditions={col0: lambda x: pickle.loads(x) >= 8})
+
+        assert len(values) == 2
+        if values[0]['Species'] == 'Penguin' :
+            assert values[1]['Species'] == 'SR-71A Blackbird'
+        elif values[0]['Species'] == 'SR-71A Blackbird' :
+            assert values[1]['Species'] == 'Penguin'
+        else :
+            if verbose:
+                print "values=", `values`
+            raise "Wrong values returned!"
+
+    def test03(self):
+        tabname = "test03"
+        try:
+            self.tdb.Drop(tabname)
+        except dbtables.TableDBError:
+            pass
+        if verbose:
+            print '...before CreateTable...'
+            self.tdb._db_print()
+        self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e'])
+        if verbose:
+            print '...after CreateTable...'
+            self.tdb._db_print()
+        self.tdb.Drop(tabname)
+        if verbose:
+            print '...after Drop...'
+            self.tdb._db_print()
+        self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e'])
+
+        try:
+            self.tdb.Insert(tabname, {'a': "", 'e': pickle.dumps([{4:5, 6:7}, 'foo'], 1), 'f': "Zero"})
+            assert 0
+        except dbtables.TableDBError:
+            pass
+
+        try:
+            self.tdb.Select(tabname, [], conditions={'foo': '123'})
+            assert 0
+        except dbtables.TableDBError:
+            pass
+
+        self.tdb.Insert(tabname, {'a': '42', 'b': "bad", 'c': "meep", 'e': 'Fuzzy wuzzy was a bear'})
+        self.tdb.Insert(tabname, {'a': '581750', 'b': "good", 'd': "bla", 'c': "black", 'e': 'fuzzy was here'})
+        self.tdb.Insert(tabname, {'a': '800000', 'b': "good", 'd': "bla", 'c': "black", 'e': 'Fuzzy wuzzy is a bear'})
+
+        if verbose:
+            self.tdb._db_print()
+
+        # this should return two rows
+        values = self.tdb.Select(tabname, ['b', 'a', 'd'],
+            conditions={'e': re.compile('wuzzy').search, 'a': re.compile('^[0-9]+$').match})
+        assert len(values) == 2
+
+        # now lets delete one of them and try again
+        self.tdb.Delete(tabname, conditions={'b': dbtables.ExactCond('good')})
+        values = self.tdb.Select(tabname, ['a', 'd', 'b'], conditions={'e': dbtables.PrefixCond('Fuzzy')})
+        assert len(values) == 1
+        assert values[0]['d'] == None
+
+        values = self.tdb.Select(tabname, ['b'],
+            conditions={'c': lambda c: c == 'meep'})
+        assert len(values) == 1
+        assert values[0]['b'] == "bad"
+
+
+    def test_CreateOrExtend(self):
+        tabname = "test_CreateOrExtend"
+
+        self.tdb.CreateOrExtendTable(tabname, ['name', 'taste', 'filling', 'alcohol content', 'price'])
+        try:
+            self.tdb.Insert(tabname, {'taste': 'crap', 'filling': 'no', 'is it Guinness?': 'no'})
+            assert 0, "Insert should've failed due to bad column name"
+        except:
+            pass
+        self.tdb.CreateOrExtendTable(tabname, ['name', 'taste', 'is it Guinness?'])
+
+        # these should both succeed as the table should contain the union of both sets of columns.
+        self.tdb.Insert(tabname, {'taste': 'crap', 'filling': 'no', 'is it Guinness?': 'no'})
+        self.tdb.Insert(tabname, {'taste': 'great', 'filling': 'yes', 'is it Guinness?': 'yes', 'name': 'Guinness'})
+
+
+    def test_CondObjs(self):
+        tabname = "test_CondObjs"
+
+        self.tdb.CreateTable(tabname, ['a', 'b', 'c', 'd', 'e', 'p'])
+
+        self.tdb.Insert(tabname, {'a': "the letter A", 'b': "the letter B", 'c': "is for cookie"})
+        self.tdb.Insert(tabname, {'a': "is for aardvark", 'e': "the letter E", 'c': "is for cookie", 'd': "is for dog"})
+        self.tdb.Insert(tabname, {'a': "the letter A", 'e': "the letter E", 'c': "is for cookie", 'p': "is for Python"})
+
+        values = self.tdb.Select(tabname, ['p', 'e'], conditions={'e': dbtables.PrefixCond('the l')})
+        assert len(values) == 2, values
+        assert values[0]['e'] == values[1]['e'], values
+        assert values[0]['p'] != values[1]['p'], values
+
+        values = self.tdb.Select(tabname, ['d', 'a'], conditions={'a': dbtables.LikeCond('%aardvark%')})
+        assert len(values) == 1, values
+        assert values[0]['d'] == "is for dog", values
+        assert values[0]['a'] == "is for aardvark", values
+
+        values = self.tdb.Select(tabname, None, {'b': dbtables.Cond(), 'e':dbtables.LikeCond('%letter%'), 'a':dbtables.PrefixCond('is'), 'd':dbtables.ExactCond('is for dog'), 'c':dbtables.PrefixCond('is for'), 'p':lambda s: not s})
+        assert len(values) == 1, values
+        assert values[0]['d'] == "is for dog", values
+        assert values[0]['a'] == "is for aardvark", values
+    
+    def test_Delete(self):
+        tabname = "test_Delete"
+        self.tdb.CreateTable(tabname, ['x', 'y', 'z'])
+
+        # prior to 2001-05-09 there was a bug where Delete() would
+        # fail if it encountered any rows that did not have values in
+        # every column.
+        # Hunted and Squashed by <Donwulff> (Jukka Santala - donwulff@nic.fi)
+        self.tdb.Insert(tabname, {'x': 'X1', 'y':'Y1'})
+        self.tdb.Insert(tabname, {'x': 'X2', 'y':'Y2', 'z': 'Z2'})
+
+        self.tdb.Delete(tabname, conditions={'x': dbtables.PrefixCond('X')})
+        values = self.tdb.Select(tabname, ['y'], conditions={'x': dbtables.PrefixCond('X')})
+        assert len(values) == 0
+
+    def test_Modify(self):
+        tabname = "test_Modify"
+        self.tdb.CreateTable(tabname, ['Name', 'Type', 'Access'])
+
+        self.tdb.Insert(tabname, {'Name': 'Index to MP3 files.doc', 'Type': 'Word', 'Access': '8'})
+        self.tdb.Insert(tabname, {'Name': 'Nifty.MP3', 'Access': '1'})
+        self.tdb.Insert(tabname, {'Type': 'Unknown', 'Access': '0'})
+
+        def set_type(type):
+            if type == None:
+                return 'MP3'
+            return type
+
+        def increment_access(count):
+            return str(int(count)+1)
+
+        def remove_value(value):
+            return None
+
+        self.tdb.Modify(tabname, conditions={'Access': dbtables.ExactCond('0')}, mappings={'Access': remove_value})
+        self.tdb.Modify(tabname, conditions={'Name': dbtables.LikeCond('%MP3%')}, mappings={'Type': set_type})
+        self.tdb.Modify(tabname, conditions={'Name': dbtables.LikeCond('%')}, mappings={'Access': increment_access})
+
+        # Delete key in select conditions
+        values = self.tdb.Select(tabname, None, conditions={'Type': dbtables.ExactCond('Unknown')})
+        assert len(values) == 1, values
+        assert values[0]['Name'] == None, values
+        assert values[0]['Access'] == None, values
+
+        # Modify value by select conditions
+        values = self.tdb.Select(tabname, None, conditions={'Name': dbtables.ExactCond('Nifty.MP3')})
+        assert len(values) == 1, values
+        assert values[0]['Type'] == "MP3", values
+        assert values[0]['Access'] == "2", values
+
+        # Make sure change applied only to select conditions
+        values = self.tdb.Select(tabname, None, conditions={'Name': dbtables.LikeCond('%doc%')})
+        assert len(values) == 1, values
+        assert values[0]['Type'] == "Word", values
+        assert values[0]['Access'] == "9", values
+
+def suite():
+    theSuite = unittest.TestSuite()
+    theSuite.addTest(unittest.makeSuite(TableDBTestCase))
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/test_get_none.py b/python/test/test_get_none.py
new file mode 100644 (file)
index 0000000..c39cc51
--- /dev/null
@@ -0,0 +1,98 @@
+"""
+TestCases for checking set_get_returns_none.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+class GetReturnsNoneTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = tempfile.mktemp()
+
+    def tearDown(self):
+        try:
+            os.remove(self.filename)
+        except os.error:
+            pass
+
+
+    def test01_get_returns_none(self):
+        d = db.DB()
+        d.open(self.filename, db.DB_BTREE, db.DB_CREATE)
+        d.set_get_returns_none(1)
+
+        for x in string.letters:
+            d.put(x, x * 40)
+
+        data = d.get('bad key')
+        assert data == None
+
+        data = d.get('a')
+        assert data == 'a'*40
+
+        count = 0
+        c = d.cursor()
+        rec = c.first()
+        while rec:
+            count = count + 1
+            rec = c.next()
+
+        assert rec == None
+        assert count == 52
+
+        c.close()
+        d.close()
+
+
+    def test02_get_raises_exception(self):
+        d = db.DB()
+        d.open(self.filename, db.DB_BTREE, db.DB_CREATE)
+        d.set_get_returns_none(0)
+
+        for x in string.letters:
+            d.put(x, x * 40)
+
+        self.assertRaises(db.DBNotFoundError, d.get, 'bad key')
+        self.assertRaises(KeyError, d.get, 'bad key')
+
+        data = d.get('a')
+        assert data == 'a'*40
+
+        count = 0
+        exceptionHappened = 0
+        c = d.cursor()
+        rec = c.first()
+        while rec:
+            count = count + 1
+            try:
+                rec = c.next()
+            except db.DBNotFoundError:  # end of the records
+                exceptionHappened = 1
+                break
+
+        assert rec != None
+        assert exceptionHappened
+        assert count == 52
+
+        c.close()
+        d.close()
+
+#----------------------------------------------------------------------
+
+def suite():
+    return unittest.makeSuite(GetReturnsNoneTestCase)
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
+
diff --git a/python/test/test_join.py b/python/test/test_join.py
new file mode 100644 (file)
index 0000000..a2f172e
--- /dev/null
@@ -0,0 +1,14 @@
+"""
+TestCases for using the DB.join and DBCursor.join_item methods.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
diff --git a/python/test/test_lock.py b/python/test/test_lock.py
new file mode 100644 (file)
index 0000000..f78ad41
--- /dev/null
@@ -0,0 +1,124 @@
+"""
+TestCases for testing the locking sub-system.
+"""
+
+import sys, os, string
+import tempfile
+import time
+from pprint import pprint
+from whrandom import random
+
+try:
+    from threading import Thread, currentThread
+    have_threads = 1
+except ImportError:
+    have_threads = 0
+
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db
+
+
+#----------------------------------------------------------------------
+
+class LockingTestCase(unittest.TestCase):
+
+    def setUp(self):
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+        self.env = db.DBEnv()
+        self.env.open(homeDir, db.DB_THREAD | db.DB_INIT_MPOOL |
+                      db.DB_INIT_LOCK | db.DB_CREATE)
+
+
+    def tearDown(self):
+        self.env.close()
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+
+    def test01_simple(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_simple..." % self.__class__.__name__
+
+        anID = self.env.lock_id()
+        if verbose:
+            print "locker ID: %s" % anID
+        lock = self.env.lock_get(anID, "some locked thing", db.DB_LOCK_WRITE)
+        if verbose:
+            print "Aquired lock: %s" % lock
+        time.sleep(1)
+        self.env.lock_put(lock)
+        if verbose:
+            print "Released lock: %s" % lock
+
+
+
+
+    def test02_threaded(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_threaded..." % self.__class__.__name__
+
+        threads = []
+        threads.append(Thread(target = self.theThread, args=(5, db.DB_LOCK_WRITE)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_READ)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE)))
+        threads.append(Thread(target = self.theThread, args=(1, db.DB_LOCK_WRITE)))
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+
+
+    def theThread(self, sleepTime, lockType):
+        name = currentThread().getName()
+        if lockType ==  db.DB_LOCK_WRITE:
+            lt = "write"
+        else:
+            lt = "read"
+
+        anID = self.env.lock_id()
+        if verbose:
+            print "%s: locker ID: %s" % (name, anID)
+
+        lock = self.env.lock_get(anID, "some locked thing", lockType)
+        if verbose:
+            print "%s: Aquired %s lock: %s" % (name, lt, lock)
+
+        time.sleep(sleepTime)
+
+        self.env.lock_put(lock)
+        if verbose:
+            print "%s: Released %s lock: %s" % (name, lt, lock)
+
+
+#----------------------------------------------------------------------
+
+def suite():
+    theSuite = unittest.TestSuite()
+
+    if have_threads:
+        theSuite.addTest(unittest.makeSuite(LockingTestCase))
+    else:
+        theSuite.addTest(unittest.makeSuite(LockingTestCase, 'test01'))
+
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/test_misc.py b/python/test/test_misc.py
new file mode 100644 (file)
index 0000000..07e2e84
--- /dev/null
@@ -0,0 +1,56 @@
+"""
+Misc TestCases
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+from rpmdb import dbshelve
+
+from test_all import verbose
+
+#----------------------------------------------------------------------
+
+class MiscTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = self.__class__.__name__ + '.db'
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+
+    def tearDown(self):
+        try:   os.remove(self.filename)
+        except os.error: pass
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+
+
+    def test01_badpointer(self):
+        dbs = dbshelve.open(self.filename)
+        dbs.close()
+        self.assertRaises(db.DBError, dbs.get, "foo")
+
+
+    def test02_db_home(self):
+        env = db.DBEnv()
+        # check for crash fixed when db_home is used before open()
+        assert env.db_home is None
+        env.open(self.homeDir, db.DB_CREATE)
+        assert self.homeDir == env.db_home
+
+#----------------------------------------------------------------------
+
+
+def suite():
+    return unittest.makeSuite(MiscTestCase)
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/test_queue.py b/python/test/test_queue.py
new file mode 100644 (file)
index 0000000..6a55845
--- /dev/null
@@ -0,0 +1,168 @@
+"""
+TestCases for exercising a Queue DB.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+
+#----------------------------------------------------------------------
+
+class SimpleQueueTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = tempfile.mktemp()
+
+    def tearDown(self):
+        try:
+            os.remove(self.filename)
+        except os.error:
+            pass
+
+
+    def test01_basic(self):
+        # Basic Queue tests using the deprecated DBCursor.consume method.
+
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_basic..." % self.__class__.__name__
+
+        d = db.DB()
+        d.set_re_len(40)  # Queues must be fixed length
+        d.open(self.filename, db.DB_QUEUE, db.DB_CREATE)
+
+        if verbose:
+            print "before appends" + '-' * 30
+            pprint(d.stat())
+
+        for x in string.letters:
+            d.append(x * 40)
+
+        assert len(d) == 52
+
+        d.put(100, "some more data")
+        d.put(101, "and some more ")
+        d.put(75,  "out of order")
+        d.put(1,   "replacement data")
+
+        assert len(d) == 55
+
+        if verbose:
+            print "before close" + '-' * 30
+            pprint(d.stat())
+
+        d.close()
+        del d
+        d = db.DB()
+        d.open(self.filename)
+
+        if verbose:
+            print "after open" + '-' * 30
+            pprint(d.stat())
+
+        d.append("one more")
+        c = d.cursor()
+
+        if verbose:
+            print "after append" + '-' * 30
+            pprint(d.stat())
+
+        rec = c.consume()
+        while rec:
+            if verbose:
+                print rec
+            rec = c.consume()
+        c.close()
+
+        if verbose:
+            print "after consume loop" + '-' * 30
+            pprint(d.stat())
+
+        assert len(d) == 0, \
+               "if you see this message then you need to rebuild BerkeleyDB 3.1.17 "\
+               "with the patch in patches/qam_stat.diff"
+
+        d.close()
+
+
+
+    def test02_basicPost32(self):
+        # Basic Queue tests using the new DB.consume method in DB 3.2+
+        # (No cursor needed)
+
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_basicPost32..." % self.__class__.__name__
+
+        if db.version() < (3, 2, 0):
+            if verbose:
+                print "Test not run, DB not new enough..."
+            return
+
+        d = db.DB()
+        d.set_re_len(40)  # Queues must be fixed length
+        d.open(self.filename, db.DB_QUEUE, db.DB_CREATE)
+
+        if verbose:
+            print "before appends" + '-' * 30
+            pprint(d.stat())
+
+        for x in string.letters:
+            d.append(x * 40)
+
+        assert len(d) == 52
+
+        d.put(100, "some more data")
+        d.put(101, "and some more ")
+        d.put(75,  "out of order")
+        d.put(1,   "replacement data")
+
+        assert len(d) == 55
+
+        if verbose:
+            print "before close" + '-' * 30
+            pprint(d.stat())
+
+        d.close()
+        del d
+        d = db.DB()
+        d.open(self.filename)
+        #d.set_get_returns_none(true)
+
+        if verbose:
+            print "after open" + '-' * 30
+            pprint(d.stat())
+
+        d.append("one more")
+
+        if verbose:
+            print "after append" + '-' * 30
+            pprint(d.stat())
+
+        rec = d.consume()
+        while rec:
+            if verbose:
+                print rec
+            rec = d.consume()
+
+        if verbose:
+            print "after consume loop" + '-' * 30
+            pprint(d.stat())
+
+        d.close()
+
+
+
+#----------------------------------------------------------------------
+
+def suite():
+    return unittest.makeSuite(SimpleQueueTestCase)
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/test_recno.py b/python/test/test_recno.py
new file mode 100644 (file)
index 0000000..ccd2d75
--- /dev/null
@@ -0,0 +1,258 @@
+"""
+TestCases for exercising a Recno DB.
+"""
+
+import sys, os, string
+import tempfile
+from pprint import pprint
+import unittest
+
+from rpmdb import db
+
+from test_all import verbose
+
+#----------------------------------------------------------------------
+
+class SimpleRecnoTestCase(unittest.TestCase):
+    def setUp(self):
+        self.filename = tempfile.mktemp()
+
+    def tearDown(self):
+        try:
+            os.remove(self.filename)
+        except os.error:
+            pass
+
+
+
+    def test01_basic(self):
+        d = db.DB()
+        d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+        for x in string.letters:
+            recno = d.append(x * 60)
+            assert type(recno) == type(0)
+            assert recno >= 1
+            if verbose:
+                print recno,
+
+        if verbose: print
+
+        stat = d.stat()
+        if verbose:
+            pprint(stat)
+
+        for recno in range(1, len(d)+1):
+            data = d[recno]
+            if verbose:
+                print data
+
+            assert type(data) == type("")
+            assert data == d.get(recno)
+
+        try:
+            data = d[0]  # This should raise a KeyError!?!?!
+        except db.DBInvalidArgError, val:
+            assert val[0] == db.EINVAL
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        try:
+            data = d[100]
+        except KeyError:
+            pass
+        else:
+            self.fail("expected exception")
+
+        data = d.get(100)
+        assert data == None
+
+        keys = d.keys()
+        if verbose:
+            print keys
+        assert type(keys) == type([])
+        assert type(keys[0]) == type(123)
+        assert len(keys) == len(d)
+
+
+        items = d.items()
+        if verbose:
+            pprint(items)
+        assert type(items) == type([])
+        assert type(items[0]) == type(())
+        assert len(items[0]) == 2
+        assert type(items[0][0]) == type(123)
+        assert type(items[0][1]) == type("")
+        assert len(items) == len(d)
+
+        assert d.has_key(25)
+
+        del d[25]
+        assert not d.has_key(25)
+
+        d.delete(13)
+        assert not d.has_key(13)
+
+        data = d.get_both(26, "z" * 60)
+        assert data == "z" * 60
+        if verbose:
+            print data
+
+        fd = d.fd()
+        if verbose:
+            print fd
+
+        c = d.cursor()
+        rec = c.first()
+        while rec:
+            if verbose:
+                print rec
+            rec = c.next()
+
+        c.set(50)
+        rec = c.current()
+        if verbose:
+            print rec
+
+        c.put(-1, "a replacement record", db.DB_CURRENT)
+
+        c.set(50)
+        rec = c.current()
+        assert rec == (50, "a replacement record")
+        if verbose:
+            print rec
+
+        rec = c.set_range(30)
+        if verbose:
+            print rec
+
+        c.close()
+        d.close()
+
+        d = db.DB()
+        d.open(self.filename)
+        c = d.cursor()
+
+        # put a record beyond the consecutive end of the recno's
+        d[100] = "way out there"
+        assert d[100] == "way out there"
+
+        try:
+            data = d[99]
+        except KeyError:
+            pass
+        else:
+            self.fail("expected exception")
+
+        try:
+            d.get(99)
+        except db.DBKeyEmptyError, val:
+            assert val[0] == db.DB_KEYEMPTY
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        rec = c.set(40)
+        while rec:
+            if verbose:
+                print rec
+            rec = c.next()
+
+        c.close()
+        d.close()
+
+
+    def test02_WithSource(self):
+        """
+        A Recno file that is given a "backing source file" is essentially a simple ASCII
+        file.  Normally each record is delimited by \n and so is just a line in the file,
+        but you can set a different record delimiter if needed.
+        """
+        source = os.path.join(os.path.dirname(sys.argv[0]), 'db_home/test_recno.txt')
+        f = open(source, 'w') # create the file
+        f.close()
+
+        d = db.DB()
+        d.set_re_delim(0x0A)  # This is the default value, just checking if both int
+        d.set_re_delim('\n')  # and char can be used...
+        d.set_re_source(source)
+        d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+        data = string.split("The quick brown fox jumped over the lazy dog")
+        for datum in data:
+            d.append(datum)
+        d.sync()
+        d.close()
+
+        # get the text from the backing source
+        text = open(source, 'r').read()
+        text = string.strip(text)
+        if verbose:
+            print text
+            print data
+            print string.split(text, '\n')
+
+        assert string.split(text, '\n') == data
+
+        # open as a DB again
+        d = db.DB()
+        d.set_re_source(source)
+        d.open(self.filename, db.DB_RECNO)
+
+        d[3] = 'reddish-brown'
+        d[8] = 'comatose'
+
+        d.sync()
+        d.close()
+
+        text = open(source, 'r').read()
+        text = string.strip(text)
+        if verbose:
+            print text
+            print string.split(text, '\n')
+
+        assert string.split(text, '\n') == string.split("The quick reddish-brown fox jumped over the comatose dog")
+
+
+    def test03_FixedLength(self):
+        d = db.DB()
+        d.set_re_len(40)  # fixed length records, 40 bytes long
+        d.set_re_pad('-') # sets the pad character...
+        d.set_re_pad(45)  # ...test both int and char
+        d.open(self.filename, db.DB_RECNO, db.DB_CREATE)
+
+        for x in string.letters:
+            d.append(x * 35)    # These will be padded
+
+        d.append('.' * 40)      # this one will be exact
+
+        try:                    # this one will fail
+            d.append('bad' * 20)
+        except db.DBInvalidArgError, val:
+            assert val[0] == db.EINVAL
+            if verbose: print val
+        else:
+            self.fail("expected exception")
+
+        c = d.cursor()
+        rec = c.first()
+        while rec:
+            if verbose:
+                print rec
+            rec = c.next()
+
+        c.close()
+        d.close()
+
+#----------------------------------------------------------------------
+
+
+def suite():
+    return unittest.makeSuite(SimpleRecnoTestCase)
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
+
+
diff --git a/python/test/test_thread.py b/python/test/test_thread.py
new file mode 100644 (file)
index 0000000..f8722e2
--- /dev/null
@@ -0,0 +1,487 @@
+"""
+TestCases for multi-threaded access to a DB.
+"""
+
+import sys, os, string
+import tempfile
+import time
+from pprint import pprint
+from whrandom import random
+
+try:
+    from threading import Thread, currentThread
+    have_threads = 1
+except ImportError:
+    have_threads = 0
+
+
+import unittest
+from test_all import verbose
+
+from rpmdb import db
+
+
+#----------------------------------------------------------------------
+
+class BaseThreadedTestCase(unittest.TestCase):
+    dbtype       = db.DB_UNKNOWN  # must be set in derived class
+    dbopenflags  = 0
+    dbsetflags   = 0
+    envflags     = 0
+
+
+    def setUp(self):
+        homeDir = os.path.join(os.path.dirname(sys.argv[0]), 'db_home')
+        self.homeDir = homeDir
+        try: os.mkdir(homeDir)
+        except os.error: pass
+        self.env = db.DBEnv()
+        self.setEnvOpts()
+        self.env.open(homeDir, self.envflags | db.DB_CREATE)
+
+        self.filename = self.__class__.__name__ + '.db'
+        self.d = db.DB(self.env)
+        if self.dbsetflags:
+            self.d.set_flags(self.dbsetflags)
+        self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE)
+
+
+    def tearDown(self):
+        self.d.close()
+        self.env.close()
+        import glob
+        files = glob.glob(os.path.join(self.homeDir, '*'))
+        for file in files:
+            os.remove(file)
+
+
+    def setEnvOpts(self):
+        pass
+
+
+    def makeData(self, key):
+        return string.join([key] * 5, '-')
+
+
+#----------------------------------------------------------------------
+
+
+class ConcurrentDataStoreBase(BaseThreadedTestCase):
+    dbopenflags = db.DB_THREAD
+    envflags    = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL
+    readers     = 0 # derived class should set
+    writers     = 0
+    records     = 1000
+
+
+    def test01_1WriterMultiReaders(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test01_1WriterMultiReaders..." % self.__class__.__name__
+
+        threads = []
+        for x in range(self.writers):
+            wt = Thread(target = self.writerThread,
+                        args = (self.d, self.records, x),
+                        name = 'writer %d' % x,
+                        )#verbose = verbose)
+            threads.append(wt)
+
+        for x in range(self.readers):
+            rt = Thread(target = self.readerThread,
+                        args = (self.d, x),
+                        name = 'reader %d' % x,
+                        )#verbose = verbose)
+            threads.append(rt)
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+
+    def writerThread(self, d, howMany, writerNum):
+        #time.sleep(0.01 * writerNum + 0.01)
+        name = currentThread().getName()
+        start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1
+        if verbose:
+            print "%s: creating records %d - %d" % (name, start, stop)
+
+        for x in range(start, stop):
+            key = '%04d' % x
+            d.put(key, self.makeData(key))
+            if verbose and x % 100 == 0:
+                print "%s: records %d - %d finished" % (name, start, x)
+
+        if verbose: print "%s: finished creating records" % name
+
+##         # Each write-cursor will be exclusive, the only one that can update the DB...
+##         if verbose: print "%s: deleting a few records" % name
+##         c = d.cursor(flags = db.DB_WRITECURSOR)
+##         for x in range(10):
+##             key = int(random() * howMany) + start
+##             key = '%04d' % key
+##             if d.has_key(key):
+##                 c.set(key)
+##                 c.delete()
+
+##         c.close()
+        if verbose: print "%s: thread finished" % name
+
+
+    def readerThread(self, d, readerNum):
+        time.sleep(0.01 * readerNum)
+        name = currentThread().getName()
+
+        for loop in range(5):
+            c = d.cursor()
+            count = 0
+            rec = c.first()
+            while rec:
+                count = count + 1
+                key, data = rec
+                assert self.makeData(key) == data
+                rec = c.next()
+            if verbose: print "%s: found %d records" % (name, count)
+            c.close()
+            time.sleep(0.05)
+
+        if verbose: print "%s: thread finished" % name
+
+
+
+class BTreeConcurrentDataStore(ConcurrentDataStoreBase):
+    dbtype  = db.DB_BTREE
+    writers = 2
+    readers = 10
+    records = 1000
+
+
+class HashConcurrentDataStore(ConcurrentDataStoreBase):
+    dbtype  = db.DB_HASH
+    writers = 2
+    readers = 10
+    records = 1000
+
+#----------------------------------------------------------------------
+
+class SimpleThreadedBase(BaseThreadedTestCase):
+    dbopenflags = db.DB_THREAD
+    envflags    = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
+    readers = 5
+    writers = 3
+    records = 1000
+
+
+    def setEnvOpts(self):
+        self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
+
+
+    def test02_SimpleLocks(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test02_SimpleLocks..." % self.__class__.__name__
+
+        threads = []
+        for x in range(self.writers):
+            wt = Thread(target = self.writerThread,
+                        args = (self.d, self.records, x),
+                        name = 'writer %d' % x,
+                        )#verbose = verbose)
+            threads.append(wt)
+        for x in range(self.readers):
+            rt = Thread(target = self.readerThread,
+                        args = (self.d, x),
+                        name = 'reader %d' % x,
+                        )#verbose = verbose)
+            threads.append(rt)
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+
+
+    def writerThread(self, d, howMany, writerNum):
+        name = currentThread().getName()
+        start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1
+        if verbose:
+            print "%s: creating records %d - %d" % (name, start, stop)
+
+        # create a bunch of records
+        for x in xrange(start, stop):
+            key = '%04d' % x
+            d.put(key, self.makeData(key))
+
+            if verbose and x % 100 == 0:
+                print "%s: records %d - %d finished" % (name, start, x)
+
+            # do a bit or reading too
+            if random() <= 0.05:
+                for y in xrange(start, x):
+                    key = '%04d' % x
+                    data = d.get(key)
+                    assert data == self.makeData(key)
+
+        # flush them
+        try:
+            d.sync()
+        except db.DBIncompleteError, val:
+            if verbose:
+                print "could not complete sync()..."
+
+        # read them back, deleting a few
+        for x in xrange(start, stop):
+            key = '%04d' % x
+            data = d.get(key)
+            if verbose and x % 100 == 0:
+                print "%s: fetched record (%s, %s)" % (name, key, data)
+            assert data == self.makeData(key)
+            if random() <= 0.10:
+                d.delete(key)
+                if verbose:
+                    print "%s: deleted record %s" % (name, key)
+
+        if verbose: print "%s: thread finished" % name
+
+
+    def readerThread(self, d, readerNum):
+        time.sleep(0.01 * readerNum)
+        name = currentThread().getName()
+
+        for loop in range(5):
+            c = d.cursor()
+            count = 0
+            rec = c.first()
+            while rec:
+                count = count + 1
+                key, data = rec
+                assert self.makeData(key) == data
+                rec = c.next()
+            if verbose: print "%s: found %d records" % (name, count)
+            c.close()
+            time.sleep(0.05)
+
+        if verbose: print "%s: thread finished" % name
+
+
+
+
+class BTreeSimpleThreaded(SimpleThreadedBase):
+    dbtype = db.DB_BTREE
+
+
+class HashSimpleThreaded(SimpleThreadedBase):
+    dbtype = db.DB_BTREE
+
+
+#----------------------------------------------------------------------
+
+
+
+class ThreadedTransactionsBase(BaseThreadedTestCase):
+    dbopenflags = db.DB_THREAD
+    envflags    = (db.DB_THREAD |
+                   db.DB_INIT_MPOOL |
+                   db.DB_INIT_LOCK |
+                   db.DB_INIT_LOG |
+                   db.DB_INIT_TXN
+                   )
+    readers = 0
+    writers = 0
+    records = 2000
+
+    txnFlag = 0
+
+
+    def setEnvOpts(self):
+        #self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
+        pass
+
+
+    def test03_ThreadedTransactions(self):
+        if verbose:
+            print '\n', '-=' * 30
+            print "Running %s.test03_ThreadedTransactions..." % self.__class__.__name__
+
+        threads = []
+        for x in range(self.writers):
+            wt = Thread(target = self.writerThread,
+                        args = (self.d, self.records, x),
+                        name = 'writer %d' % x,
+                        )#verbose = verbose)
+            threads.append(wt)
+
+        for x in range(self.readers):
+            rt = Thread(target = self.readerThread,
+                        args = (self.d, x),
+                        name = 'reader %d' % x,
+                        )#verbose = verbose)
+            threads.append(rt)
+
+        dt = Thread(target = self.deadlockThread)
+        dt.start()
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        self.doLockDetect = 0
+        dt.join()
+
+
+    def doWrite(self, d, name, start, stop):
+        finished = 0
+        while not finished:
+            try:
+                txn = self.env.txn_begin(None, self.txnFlag)
+                for x in range(start, stop):
+                    key = '%04d' % x
+                    d.put(key, self.makeData(key), txn)
+                    if verbose and x % 100 == 0:
+                        print "%s: records %d - %d finished" % (name, start, x)
+                txn.commit()
+                finished = 1
+            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+                if verbose:
+                    print "%s: Aborting transaction (%s)" % (name, val[1])
+                txn.abort()
+                time.sleep(0.05)
+
+
+
+    def writerThread(self, d, howMany, writerNum):
+        name = currentThread().getName()
+        start, stop = howMany * writerNum, howMany * (writerNum + 1) - 1
+        if verbose:
+            print "%s: creating records %d - %d" % (name, start, stop)
+
+        step = 100
+        for x in range(start, stop, step):
+            self.doWrite(d, name, x, min(stop, x+step))
+
+        if verbose: print "%s: finished creating records" % name
+        if verbose: print "%s: deleting a few records" % name
+
+        finished = 0
+        while not finished:
+            try:
+                recs = []
+                txn = self.env.txn_begin(None, self.txnFlag)
+                for x in range(10):
+                    key = int(random() * howMany) + start
+                    key = '%04d' % key
+                    data = d.get(key, None, txn, db.DB_RMW)
+                    if data is not None:
+                        d.delete(key, txn)
+                        recs.append(key)
+                txn.commit()
+                finished = 1
+                if verbose: print "%s: deleted records %s" % (name, recs)
+            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+                if verbose:
+                    print "%s: Aborting transaction (%s)" % (name, val[1])
+                txn.abort()
+                time.sleep(0.05)
+
+        if verbose: print "%s: thread finished" % name
+
+
+    def readerThread(self, d, readerNum):
+        time.sleep(0.01 * readerNum + 0.05)
+        name = currentThread().getName()
+
+        for loop in range(5):
+            finished = 0
+            while not finished:
+                try:
+                    txn = self.env.txn_begin(None, self.txnFlag)
+                    c = d.cursor(txn)
+                    count = 0
+                    rec = c.first()
+                    while rec:
+                        count = count + 1
+                        key, data = rec
+                        assert self.makeData(key) == data
+                        rec = c.next()
+                    if verbose: print "%s: found %d records" % (name, count)
+                    c.close()
+                    txn.commit()
+                    finished = 1
+                except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
+                    if verbose:
+                        print "%s: Aborting transaction (%s)" % (name, val[1])
+                    c.close()
+                    txn.abort()
+                    time.sleep(0.05)
+
+            time.sleep(0.05)
+
+        if verbose: print "%s: thread finished" % name
+
+
+    def deadlockThread(self):
+        self.doLockDetect = 1
+        while self.doLockDetect:
+            time.sleep(0.5)
+            try:
+                aborted = self.env.lock_detect(db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT)
+                if verbose and aborted:
+                    print "deadlock: Aborted %d deadlocked transaction(s)" % aborted
+            except db.DBError:
+                pass
+
+
+
+class BTreeThreadedTransactions(ThreadedTransactionsBase):
+    dbtype = db.DB_BTREE
+    writers = 3
+    readers = 5
+    records = 2000
+
+class HashThreadedTransactions(ThreadedTransactionsBase):
+    dbtype = db.DB_HASH
+    writers = 1
+    readers = 5
+    records = 2000
+
+class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase):
+    dbtype = db.DB_BTREE
+    writers = 3
+    readers = 5
+    records = 2000
+    txnFlag = db.DB_TXN_NOWAIT
+
+class HashThreadedNoWaitTransactions(ThreadedTransactionsBase):
+    dbtype = db.DB_HASH
+    writers = 1
+    readers = 5
+    records = 2000
+    txnFlag = db.DB_TXN_NOWAIT
+
+
+#----------------------------------------------------------------------
+
+def suite():
+    theSuite = unittest.TestSuite()
+
+    if have_threads:
+        theSuite.addTest(unittest.makeSuite(BTreeConcurrentDataStore))
+        theSuite.addTest(unittest.makeSuite(HashConcurrentDataStore))
+        theSuite.addTest(unittest.makeSuite(BTreeSimpleThreaded))
+        theSuite.addTest(unittest.makeSuite(HashSimpleThreaded))
+        theSuite.addTest(unittest.makeSuite(BTreeThreadedTransactions))
+        theSuite.addTest(unittest.makeSuite(HashThreadedTransactions))
+        theSuite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions))
+        theSuite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions))
+
+    else:
+        print "Threads not available, skipping thread tests."
+
+    return theSuite
+
+
+if __name__ == '__main__':
+    unittest.main( defaultTest='suite' )
diff --git a/python/test/unittest.py b/python/test/unittest.py
new file mode 100644 (file)
index 0000000..44b05c9
--- /dev/null
@@ -0,0 +1,693 @@
+#!/usr/bin/env python
+"""
+Python unit testing framework, based on Erich Gamma's JUnit and Kent Beck's
+Smalltalk testing framework.
+
+Further information is available in the bundled documentation, and from
+
+  http://pyunit.sourceforge.net/
+
+This module contains the core framework classes that form the basis of
+specific test cases and suites (TestCase, TestSuite etc.), and also a
+text-based utility class for running the tests and reporting the results
+(TextTestRunner).
+
+Copyright (c) 1999, 2000, 2001 Steve Purcell
+This module is free software, and you may redistribute it and/or modify
+it under the same terms as Python itself, so long as this copyright message
+and disclaimer are retained in their original form.
+
+IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
+THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
+
+THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+PARTICULAR PURPOSE.  THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
+AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
+SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
+"""
+
+__author__ = "Steve Purcell (stephen_purcell@yahoo.com)"
+__version__ = "Revision: 1.3 "[11:-2]
+
+import time
+import sys
+import traceback
+import string
+import os
+
+##############################################################################
+# A platform-specific concession to help the code work for JPython users
+##############################################################################
+
+plat = string.lower(sys.platform)
+_isJPython = string.find(plat, 'java') >= 0 or string.find(plat, 'jdk') >= 0
+del plat
+
+
+##############################################################################
+# Test framework core
+##############################################################################
+
+class TestResult:
+    """Holder for test result information.
+
+    Test results are automatically managed by the TestCase and TestSuite
+    classes, and do not need to be explicitly manipulated by writers of tests.
+
+    Each instance holds the total number of tests run, and collections of
+    failures and errors that occurred among those test runs. The collections
+    contain tuples of (testcase, exceptioninfo), where exceptioninfo is a
+    tuple of values as returned by sys.exc_info().
+    """
+    def __init__(self):
+        self.failures = []
+        self.errors = []
+        self.testsRun = 0
+        self.shouldStop = 0
+
+    def startTest(self, test):
+        "Called when the given test is about to be run"
+        self.testsRun = self.testsRun + 1
+
+    def stopTest(self, test):
+        "Called when the given test has been run"
+        pass
+
+    def addError(self, test, err):
+        "Called when an error has occurred"
+        self.errors.append((test, err))
+
+    def addFailure(self, test, err):
+        "Called when a failure has occurred"
+        self.failures.append((test, err))
+
+    def wasSuccessful(self):
+        "Tells whether or not this result was a success"
+        return len(self.failures) == len(self.errors) == 0
+
+    def stop(self):
+        "Indicates that the tests should be aborted"
+        self.shouldStop = 1
+
+    def __repr__(self):
+        return "<%s run=%i errors=%i failures=%i>" % \
+               (self.__class__, self.testsRun, len(self.errors),
+                len(self.failures))
+
+
+class TestCase:
+    """A class whose instances are single test cases.
+
+    Test authors should subclass TestCase for their own tests. Construction
+    and deconstruction of the test's environment ('fixture') can be
+    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
+
+    By default, the test code itself should be placed in a method named
+    'runTest'.
+
+    If the fixture may be used for many test cases, create as
+    many test methods as are needed. When instantiating such a TestCase
+    subclass, specify in the constructor arguments the name of the test method
+    that the instance is to execute.
+
+    If it is necessary to override the __init__ method, the base class
+    __init__ method must always be called.
+    """
+    def __init__(self, methodName='runTest'):
+        """Create an instance of the class that will use the named test
+           method when executed. Raises a ValueError if the instance does
+           not have a method with the specified name.
+        """
+        try:
+            self.__testMethod = getattr(self,methodName)
+        except AttributeError:
+            raise ValueError, "no such test method in %s: %s" % \
+                  (self.__class__, methodName)
+
+    def setUp(self):
+        "Hook method for setting up the test fixture before exercising it."
+        pass
+
+    def tearDown(self):
+        "Hook method for deconstructing the test fixture after testing it."
+        pass
+
+    def countTestCases(self):
+        return 1
+
+    def defaultTestResult(self):
+        return TestResult()
+
+    def shortDescription(self):
+        """Returns a one-line description of the test, or None if no
+        description has been provided.
+
+        The default implementation of this method returns the first line of
+        the specified test method's docstring.
+        """
+        doc = self.__testMethod.__doc__
+        return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+    def id(self):
+        return "%s.%s" % (self.__class__, self.__testMethod.__name__)
+
+    def __str__(self):
+        return "%s (%s)" % (self.__testMethod.__name__, self.__class__)
+
+    def __repr__(self):
+        return "<%s testMethod=%s>" % \
+               (self.__class__, self.__testMethod.__name__)
+
+    def run(self, result=None):
+        return self(result)
+
+    def __call__(self, result=None):
+        if result is None: result = self.defaultTestResult()
+        result.startTest(self)
+        try:
+            try:
+                self.setUp()
+            except:
+                result.addError(self,self.__exc_info())
+                return
+
+            try:
+                self.__testMethod()
+            except AssertionError, e:
+                result.addFailure(self,self.__exc_info())
+            except:
+                result.addError(self,self.__exc_info())
+
+            try:
+                self.tearDown()
+            except:
+                result.addError(self,self.__exc_info())
+        finally:
+            result.stopTest(self)
+
+    def debug(self):
+        """Run the test without collecting errors in a TestResult"""
+        self.setUp()
+        self.__testMethod()
+        self.tearDown()
+
+    def assert_(self, expr, msg=None):
+        """Equivalent of built-in 'assert', but is not optimised out when
+           __debug__ is false.
+        """
+        if not expr:
+            raise AssertionError, msg
+
+    failUnless = assert_
+
+    def failIf(self, expr, msg=None):
+        "Fail the test if the expression is true."
+        apply(self.assert_,(not expr,msg))
+
+    def assertRaises(self, excClass, callableObj, *args, **kwargs):
+        """Assert that an exception of class excClass is thrown
+           by callableObj when invoked with arguments args and keyword
+           arguments kwargs. If a different type of exception is
+           thrown, it will not be caught, and the test case will be
+           deemed to have suffered an error, exactly as for an
+           unexpected exception.
+        """
+        try:
+            apply(callableObj, args, kwargs)
+        except excClass:
+            return
+        else:
+            if hasattr(excClass,'__name__'): excName = excClass.__name__
+            else: excName = str(excClass)
+            raise AssertionError, excName
+
+    def fail(self, msg=None):
+        """Fail immediately, with the given message."""
+        raise AssertionError, msg
+
+    def __exc_info(self):
+        """Return a version of sys.exc_info() with the traceback frame
+           minimised; usually the top level of the traceback frame is not
+           needed.
+        """
+        exctype, excvalue, tb = sys.exc_info()
+        newtb = tb.tb_next
+        if newtb is None:
+            return (exctype, excvalue, tb)
+        return (exctype, excvalue, newtb)
+
+
+class TestSuite:
+    """A test suite is a composite test consisting of a number of TestCases.
+
+    For use, create an instance of TestSuite, then add test case instances.
+    When all tests have been added, the suite can be passed to a test
+    runner, such as TextTestRunner. It will run the individual test cases
+    in the order in which they were added, aggregating the results. When
+    subclassing, do not forget to call the base class constructor.
+    """
+    def __init__(self, tests=()):
+        self._tests = []
+        self.addTests(tests)
+
+    def __repr__(self):
+        return "<%s tests=%s>" % (self.__class__, self._tests)
+
+    __str__ = __repr__
+
+    def countTestCases(self):
+        cases = 0
+        for test in self._tests:
+            cases = cases + test.countTestCases()
+        return cases
+
+    def addTest(self, test):
+        self._tests.append(test)
+
+    def addTests(self, tests):
+        for test in tests:
+            self.addTest(test)
+
+    def run(self, result):
+        return self(result)
+
+    def __call__(self, result):
+        for test in self._tests:
+            if result.shouldStop:
+                break
+            test(result)
+        return result
+
+    def debug(self):
+        """Run the tests without collecting errors in a TestResult"""
+        for test in self._tests: test.debug()
+
+
+class FunctionTestCase(TestCase):
+    """A test case that wraps a test function.
+
+    This is useful for slipping pre-existing test functions into the
+    PyUnit framework. Optionally, set-up and tidy-up functions can be
+    supplied. As with TestCase, the tidy-up ('tearDown') function will
+    always be called if the set-up ('setUp') function ran successfully.
+    """
+
+    def __init__(self, testFunc, setUp=None, tearDown=None,
+                 description=None):
+        TestCase.__init__(self)
+        self.__setUpFunc = setUp
+        self.__tearDownFunc = tearDown
+        self.__testFunc = testFunc
+        self.__description = description
+
+    def setUp(self):
+        if self.__setUpFunc is not None:
+            self.__setUpFunc()
+
+    def tearDown(self):
+        if self.__tearDownFunc is not None:
+            self.__tearDownFunc()
+
+    def runTest(self):
+        self.__testFunc()
+
+    def id(self):
+        return self.__testFunc.__name__
+
+    def __str__(self):
+        return "%s (%s)" % (self.__class__, self.__testFunc.__name__)
+
+    def __repr__(self):
+        return "<%s testFunc=%s>" % (self.__class__, self.__testFunc)
+
+    def shortDescription(self):
+        if self.__description is not None: return self.__description
+        doc = self.__testFunc.__doc__
+        return doc and string.strip(string.split(doc, "\n")[0]) or None
+
+
+
+##############################################################################
+# Convenience functions
+##############################################################################
+
+def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
+    """Extracts all the names of functions in the given test case class
+       and its base classes that start with the given prefix. This is used
+       by makeSuite().
+    """
+    testFnNames = filter(lambda n,p=prefix: n[:len(p)] == p,
+                         dir(testCaseClass))
+    for baseclass in testCaseClass.__bases__:
+        testFnNames = testFnNames + \
+                      getTestCaseNames(baseclass, prefix, sortUsing=None)
+    if sortUsing:
+        testFnNames.sort(sortUsing)
+    return testFnNames
+
+
+def makeSuite(testCaseClass, prefix='test', sortUsing=cmp):
+    """Returns a TestSuite instance built from all of the test functions
+       in the given test case class whose names begin with the given
+       prefix. The cases are sorted by their function names
+       using the supplied comparison function, which defaults to 'cmp'.
+    """
+    cases = map(testCaseClass,
+                getTestCaseNames(testCaseClass, prefix, sortUsing))
+    return TestSuite(cases)
+
+
+def createTestInstance(name, module=None):
+    """Finds tests by their name, optionally only within the given module.
+
+    Return the newly-constructed test, ready to run. If the name contains a ':'
+    then the portion of the name after the colon is used to find a specific
+    test case within the test case class named before the colon.
+
+    Examples:
+     findTest('examples.listtests.suite')
+        -- returns result of calling 'suite'
+     findTest('examples.listtests.ListTestCase:checkAppend')
+        -- returns result of calling ListTestCase('checkAppend')
+     findTest('examples.listtests.ListTestCase:check-')
+        -- returns result of calling makeSuite(ListTestCase, prefix="check")
+    """
+
+    spec = string.split(name, ':')
+    if len(spec) > 2: raise ValueError, "illegal test name: %s" % name
+    if len(spec) == 1:
+        testName = spec[0]
+        caseName = None
+    else:
+        testName, caseName = spec
+    parts = string.split(testName, '.')
+    if module is None:
+        if len(parts) < 2:
+            raise ValueError, "incomplete test name: %s" % name
+        constructor = __import__(string.join(parts[:-1],'.'))
+        parts = parts[1:]
+    else:
+        constructor = module
+    for part in parts:
+        constructor = getattr(constructor, part)
+    if not callable(constructor):
+        raise ValueError, "%s is not a callable object" % constructor
+    if caseName:
+        if caseName[-1] == '-':
+            prefix = caseName[:-1]
+            if not prefix:
+                raise ValueError, "prefix too short: %s" % name
+            test = makeSuite(constructor, prefix=prefix)
+        else:
+            test = constructor(caseName)
+    else:
+        test = constructor()
+    if not hasattr(test,"countTestCases"):
+        raise TypeError, \
+              "object %s found with spec %s is not a test" % (test, name)
+    return test
+
+
+##############################################################################
+# Text UI
+##############################################################################
+
+class _WritelnDecorator:
+    """Used to decorate file-like objects with a handy 'writeln' method"""
+    def __init__(self,stream):
+        self.stream = stream
+        if _isJPython:
+            import java.lang.System
+            self.linesep = java.lang.System.getProperty("line.separator")
+        else:
+            self.linesep = os.linesep
+
+    def __getattr__(self, attr):
+        return getattr(self.stream,attr)
+
+    def writeln(self, *args):
+        if args: apply(self.write, args)
+        self.write(self.linesep)
+
+
+class _JUnitTextTestResult(TestResult):
+    """A test result class that can print formatted text results to a stream.
+
+    Used by JUnitTextTestRunner.
+    """
+    def __init__(self, stream):
+        self.stream = stream
+        TestResult.__init__(self)
+
+    def addError(self, test, error):
+        TestResult.addError(self,test,error)
+        self.stream.write('E')
+        self.stream.flush()
+        if error[0] is KeyboardInterrupt:
+            self.shouldStop = 1
+
+    def addFailure(self, test, error):
+        TestResult.addFailure(self,test,error)
+        self.stream.write('F')
+        self.stream.flush()
+
+    def startTest(self, test):
+        TestResult.startTest(self,test)
+        self.stream.write('.')
+        self.stream.flush()
+
+    def printNumberedErrors(self,errFlavour,errors):
+        if not errors: return
+        if len(errors) == 1:
+            self.stream.writeln("There was 1 %s:" % errFlavour)
+        else:
+            self.stream.writeln("There were %i %ss:" %
+                                (len(errors), errFlavour))
+        i = 1
+        for test,error in errors:
+            errString = string.join(apply(traceback.format_exception,error),"")
+            self.stream.writeln("%i) %s" % (i, test))
+            self.stream.writeln(errString)
+            i = i + 1
+
+    def printErrors(self):
+        self.printNumberedErrors("error",self.errors)
+
+    def printFailures(self):
+        self.printNumberedErrors("failure",self.failures)
+
+    def printHeader(self):
+        self.stream.writeln()
+        if self.wasSuccessful():
+            self.stream.writeln("OK (%i tests)" % self.testsRun)
+        else:
+            self.stream.writeln("!!!FAILURES!!!")
+            self.stream.writeln("Test Results")
+            self.stream.writeln()
+            self.stream.writeln("Run: %i ; Failures: %i ; Errors: %i" %
+                                (self.testsRun, len(self.failures),
+                                 len(self.errors)))
+
+    def printResult(self):
+        self.printHeader()
+        self.printErrors()
+        self.printFailures()
+
+
+class JUnitTextTestRunner:
+    """A test runner class that displays results in textual form.
+
+    The display format approximates that of JUnit's 'textui' test runner.
+    This test runner may be removed in a future version of PyUnit.
+    """
+    def __init__(self, stream=sys.stderr):
+        self.stream = _WritelnDecorator(stream)
+
+    def run(self, test):
+        "Run the given test case or test suite."
+        result = _JUnitTextTestResult(self.stream)
+        startTime = time.time()
+        test(result)
+        stopTime = time.time()
+        self.stream.writeln()
+        self.stream.writeln("Time: %.3fs" % float(stopTime - startTime))
+        result.printResult()
+        return result
+
+
+##############################################################################
+# Verbose text UI
+##############################################################################
+
+class _VerboseTextTestResult(TestResult):
+    """A test result class that can print formatted text results to a stream.
+
+    Used by VerboseTextTestRunner.
+    """
+    def __init__(self, stream, descriptions):
+        TestResult.__init__(self)
+        self.stream = stream
+        self.lastFailure = None
+        self.descriptions = descriptions
+
+    def startTest(self, test):
+        TestResult.startTest(self, test)
+        if self.descriptions:
+            self.stream.write(test.shortDescription() or str(test))
+        else:
+            self.stream.write(str(test))
+        self.stream.write(" ... ")
+
+    def stopTest(self, test):
+        TestResult.stopTest(self, test)
+        if self.lastFailure is not test:
+            self.stream.writeln("ok")
+
+    def addError(self, test, err):
+        TestResult.addError(self, test, err)
+        self._printError("ERROR", test, err)
+        self.lastFailure = test
+        if err[0] is KeyboardInterrupt:
+            self.shouldStop = 1
+
+    def addFailure(self, test, err):
+        TestResult.addFailure(self, test, err)
+        self._printError("FAIL", test, err)
+        self.lastFailure = test
+
+    def _printError(self, flavour, test, err):
+        errLines = []
+        separator1 = "\t" + '=' * 70
+        separator2 = "\t" + '-' * 70
+        if not self.lastFailure is test:
+            self.stream.writeln()
+            self.stream.writeln(separator1)
+        self.stream.writeln("\t%s" % flavour)
+        self.stream.writeln(separator2)
+        for line in apply(traceback.format_exception, err):
+            for l in string.split(line,"\n")[:-1]:
+                self.stream.writeln("\t%s" % l)
+        self.stream.writeln(separator1)
+
+
+class VerboseTextTestRunner:
+    """A test runner class that displays results in textual form.
+
+    It prints out the names of tests as they are run, errors as they
+    occur, and a summary of the results at the end of the test run.
+    """
+    def __init__(self, stream=sys.stderr, descriptions=1):
+        self.stream = _WritelnDecorator(stream)
+        self.descriptions = descriptions
+
+    def run(self, test):
+        "Run the given test case or test suite."
+        result = _VerboseTextTestResult(self.stream, self.descriptions)
+        startTime = time.time()
+        test(result)
+        stopTime = time.time()
+        timeTaken = float(stopTime - startTime)
+        self.stream.writeln("-" * 78)
+        run = result.testsRun
+        self.stream.writeln("Ran %d test%s in %.3fs" %
+                            (run, run > 1 and "s" or "", timeTaken))
+        self.stream.writeln()
+        if not result.wasSuccessful():
+            self.stream.write("FAILED (")
+            failed, errored = map(len, (result.failures, result.errors))
+            if failed:
+                self.stream.write("failures=%d" % failed)
+            if errored:
+                if failed: self.stream.write(", ")
+                self.stream.write("errors=%d" % errored)
+            self.stream.writeln(")")
+        else:
+            self.stream.writeln("OK")
+        return result
+
+
+# Which flavour of TextTestRunner is the default?
+TextTestRunner = VerboseTextTestRunner
+
+
+##############################################################################
+# Facilities for running tests from the command line
+##############################################################################
+
+class TestProgram:
+    """A command-line program that runs a set of tests; this is primarily
+       for making test modules conveniently executable.
+    """
+    USAGE = """\
+Usage: %(progName)s [-h|--help] [test[:(casename|prefix-)]] [...]
+
+Examples:
+  %(progName)s                               - run default set of tests
+  %(progName)s MyTestSuite                   - run suite 'MyTestSuite'
+  %(progName)s MyTestCase:checkSomething     - run MyTestCase.checkSomething
+  %(progName)s MyTestCase:check-             - run all 'check*' test methods
+                                               in MyTestCase
+"""
+    def __init__(self, module='__main__', defaultTest=None,
+                 argv=None, testRunner=None):
+        if type(module) == type(''):
+            self.module = __import__(module)
+            for part in string.split(module,'.')[1:]:
+                self.module = getattr(self.module, part)
+        else:
+            self.module = module
+        if argv is None:
+            argv = sys.argv
+        self.defaultTest = defaultTest
+        self.testRunner = testRunner
+        self.progName = os.path.basename(argv[0])
+        self.parseArgs(argv)
+        self.createTests()
+        self.runTests()
+
+    def usageExit(self, msg=None):
+        if msg: print msg
+        print self.USAGE % self.__dict__
+        sys.exit(2)
+
+    def parseArgs(self, argv):
+        import getopt
+        try:
+            options, args = getopt.getopt(argv[1:], 'hH', ['help'])
+            opts = {}
+            for opt, value in options:
+                if opt in ('-h','-H','--help'):
+                    self.usageExit()
+            if len(args) == 0 and self.defaultTest is None:
+                raise getopt.error, "No default test is defined."
+            if len(args) > 0:
+                self.testNames = args
+            else:
+                self.testNames = (self.defaultTest,)
+        except getopt.error, msg:
+            self.usageExit(msg)
+
+    def createTests(self):
+        tests = []
+        for testName in self.testNames:
+            tests.append(createTestInstance(testName, self.module))
+        self.test = TestSuite(tests)
+
+    def runTests(self):
+        if self.testRunner is None:
+            self.testRunner = TextTestRunner()
+        result = self.testRunner.run(self.test)
+        sys.exit(not result.wasSuccessful())
+
+main = TestProgram
+
+
+##############################################################################
+# Executing this module from the command line
+##############################################################################
+
+if __name__ == "__main__":
+    main(module=None)