Code

Make memorydb pass all tests;
authorrichard <richard@57a73879-2fb5-44c3-a270-3262357dd7e2>
Mon, 8 Feb 2010 04:30:48 +0000 (04:30 +0000)
committerrichard <richard@57a73879-2fb5-44c3-a270-3262357dd7e2>
Mon, 8 Feb 2010 04:30:48 +0000 (04:30 +0000)
1. re-base off of anydbm and overwrite persistence bits
2. refactor core DB tests to allow in-memory persistence to work
3. remove bogus testing of file content indexing on import

git-svn-id: http://svn.roundup-tracker.org/svnroot/roundup/roundup/trunk@4453 57a73879-2fb5-44c3-a270-3262357dd7e2

test/db_test_base.py
test/memorydb.py
test/test_memorydb.py [new file with mode: 0644]

index b245ef2a6161c5a31174ea4baf3b85f4f23dbcae..7fd26638fb74eee5182e8925b87b7cdcf8c25166 100644 (file)
@@ -135,9 +135,12 @@ class DBTest(MyTestCase):
         if os.path.exists(config.DATABASE):
             shutil.rmtree(config.DATABASE)
         os.makedirs(config.DATABASE + '/files')
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         setupSchema(self.db, 1, self.module)
 
+    def open_database(self):
+        self.db = self.module.Database(config, 'admin')
+
     def testRefresh(self):
         self.db.refresh_database()
 
@@ -147,11 +150,7 @@ class DBTest(MyTestCase):
     def testCreatorProperty(self):
         i = self.db.issue
         id1 = i.create(title='spam')
-        self.db.commit()
-        self.db.close()
-        self.db = self.module.Database(config, 'fred')
-        setupSchema(self.db, 0, self.module)
-        i = self.db.issue
+        self.db.journaltag = 'fred'
         id2 = i.create(title='spam')
         self.assertNotEqual(id1, id2)
         self.assertNotEqual(i.get(id1, 'creator'), i.get(id2, 'creator'))
@@ -159,11 +158,7 @@ class DBTest(MyTestCase):
     def testActorProperty(self):
         i = self.db.issue
         id1 = i.create(title='spam')
-        self.db.commit()
-        self.db.close()
-        self.db = self.module.Database(config, 'fred')
-        setupSchema(self.db, 0, self.module)
-        i = self.db.issue
+        self.db.journaltag = 'fred'
         i.set(id1, title='asfasd')
         self.assertNotEqual(i.get(id1, 'creator'), i.get(id1, 'actor'))
 
@@ -855,6 +850,7 @@ class DBTest(MyTestCase):
         self.assertEquals(self.db.indexer.search([], self.db.issue), {})
         self.assertEquals(self.db.indexer.search(['hello'], self.db.issue),
             {i1: {'files': [f1]}})
+        # content='world' has the wrong content-type and shouldn't be indexed
         self.assertEquals(self.db.indexer.search(['world'], self.db.issue), {})
         self.assertEquals(self.db.indexer.search(['frooz'], self.db.issue),
             {i2: {}})
@@ -963,45 +959,17 @@ class DBTest(MyTestCase):
         self.assertEquals(self.db.indexer.search(['flebble'], self.db.issue),
             {'1': {}})
 
-    def testIndexingOnImport(self):
-        # import a message
-        msgcontent = 'Glrk'
-        msgid = self.db.msg.import_list(['content', 'files', 'recipients'],
-                                        [repr(msgcontent), '[]', '[]'])
-        msg_filename = self.db.filename(self.db.msg.classname, msgid,
-                                        create=1)
-        support.ensureParentsExist(msg_filename)
-        msg_file = open(msg_filename, 'w')
-        msg_file.write(msgcontent)
-        msg_file.close()
-
-        # import a file
-        filecontent = 'Brrk'
-        fileid = self.db.file.import_list(['content'], [repr(filecontent)])
-        file_filename = self.db.filename(self.db.file.classname, fileid,
-                                         create=1)
-        support.ensureParentsExist(file_filename)
-        file_file = open(file_filename, 'w')
-        file_file.write(filecontent)
-        file_file.close()
-
+    def testIndexingPropertiesOnImport(self):
         # import an issue
         title = 'Bzzt'
         nodeid = self.db.issue.import_list(['title', 'messages', 'files',
-            'spam', 'nosy', 'superseder'], [repr(title), repr([msgid]),
-            repr([fileid]), '[]', '[]', '[]'])
+            'spam', 'nosy', 'superseder'], [repr(title), '[]', '[]',
+            '[]', '[]', '[]'])
         self.db.commit()
 
         # Content of title attribute is indexed
         self.assertEquals(self.db.indexer.search([title], self.db.issue),
             {str(nodeid):{}})
-        # Content of message is indexed
-        self.assertEquals(self.db.indexer.search([msgcontent], self.db.issue),
-            {str(nodeid):{'messages':[str(msgid)]}})
-        # Content of file is indexed
-        self.assertEquals(self.db.indexer.search([filecontent], self.db.issue),
-            {str(nodeid):{'files':[str(fileid)]}})
-
 
 
     #
@@ -1627,7 +1595,6 @@ class DBTest(MyTestCase):
         self.db = self.module.Database(config, 'admin')
         setupSchema(self.db, 0, self.module)
 
-
     def testImportExport(self):
         # use the filtering setup to create a bunch of items
         ae, filt = self.filteringSetup()
@@ -1897,7 +1864,7 @@ class SchemaTest(MyTestCase):
         os.makedirs(config.DATABASE + '/files')
 
     def test_reservedProperties(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         self.assertRaises(ValueError, self.module.Class, self.db, "a",
             creation=String())
         self.assertRaises(ValueError, self.module.Class, self.db, "a",
@@ -1908,13 +1875,13 @@ class SchemaTest(MyTestCase):
             actor=String())
 
     def init_a(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.Class(self.db, "a", name=String())
         a.setkey("name")
         self.db.post_init()
 
     def test_fileClassProps(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.FileClass(self.db, 'a')
         l = a.getprops().keys()
         l.sort()
@@ -1922,7 +1889,7 @@ class SchemaTest(MyTestCase):
             'creation', 'type'])
 
     def init_ab(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.Class(self.db, "a", name=String())
         a.setkey("name")
         b = self.module.Class(self.db, "b", name=String(),
@@ -1960,7 +1927,7 @@ class SchemaTest(MyTestCase):
         self.db.getjournal('b', bid)
 
     def init_amod(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.Class(self.db, "a", name=String(), newstr=String(),
             newint=Interval(), newnum=Number(), newbool=Boolean(),
             newdate=Date())
@@ -2004,7 +1971,7 @@ class SchemaTest(MyTestCase):
         self.db.getjournal('a', aid2)
 
     def init_amodkey(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.Class(self.db, "a", name=String(), newstr=String())
         a.setkey("newstr")
         b = self.module.Class(self.db, "b", name=String())
@@ -2047,7 +2014,7 @@ class SchemaTest(MyTestCase):
 
 
     def init_amodml(self):
-        self.db = self.module.Database(config, 'admin')
+        self.open_database()
         a = self.module.Class(self.db, "a", name=String(),
             newml=Multilink('a'))
         a.setkey('name')
index ea76be7dd458e5d735766245ec452615430591d5..9565050bef68405d3c0eebf54969369261a0b0e3 100644 (file)
@@ -11,8 +11,10 @@ from roundup import password
 from roundup import configuration
 from roundup.backends import back_anydbm
 from roundup.backends import indexer_dbm
+from roundup.backends import sessions_dbm
 from roundup.backends import indexer_common
 from roundup.hyperdb import *
+from roundup.support import ensureParentsExist
 
 def new_config():
     config = configuration.CoreConfig()
@@ -130,10 +132,10 @@ class BasicDatabase(dict):
     def clean(self):
         pass
 
-class Sessions(BasicDatabase):
+class Sessions(BasicDatabase, sessions_dbm.Sessions):
     name = 'sessions'
 
-class OneTimeKeys(BasicDatabase):
+class OneTimeKeys(BasicDatabase, sessions_dbm.Sessions):
     name = 'otks'
 
 class Indexer(indexer_dbm.Indexer):
@@ -154,8 +156,12 @@ class Indexer(indexer_dbm.Indexer):
 
     def save_index(self):
         pass
+    def force_reindex(self):
+        # TODO I'm concerned that force_reindex may not be tested by
+        # testForcedReindexing if the functionality can just be removed
+        pass
 
-class Database(hyperdb.Database, roundupdb.Database):
+class Database(back_anydbm.Database):
     """A database for storing records containing flexible data types.
 
     Transaction stuff TODO:
@@ -171,12 +177,21 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.ids = {}
         self.journals = {}
         self.files = {}
+        self.tx_files = {}
         self.security = security.Security(self)
         self.stats = {'cache_hits': 0, 'cache_misses': 0, 'get_items': 0,
             'filtering': 0}
         self.sessions = Sessions()
         self.otks = OneTimeKeys()
         self.indexer = Indexer(self)
+        self.sessions = Sessions()
+
+        # anydbm bits
+        self.cache = {}         # cache of nodes loaded or created
+        self.dirtynodes = {}    # keep track of the dirty nodes by class
+        self.newnodes = {}      # keep track of the new nodes by class
+        self.destroyednodes = {}# keep track of the destroyed nodes by class
+        self.transactions = []
 
 
     def filename(self, classname, nodeid, property=None, create=0):
@@ -184,7 +199,7 @@ class Database(hyperdb.Database, roundupdb.Database):
         return __file__+'.dummy'
 
     def filesize(self, classname, nodeid, property=None, create=0):
-        return len(self.getnode(classname, nodeid)[property or 'content'])
+        return len(self.getfile(classname, nodeid, property))
 
     def post_init(self):
         pass
@@ -205,13 +220,31 @@ class Database(hyperdb.Database, roundupdb.Database):
         return '<memorydb instance at %x>'%id(self)
 
     def storefile(self, classname, nodeid, property, content):
-        self.files[classname, nodeid, property] = content
+        self.tx_files[classname, nodeid, property] = content
+        self.transactions.append((self.doStoreFile, (classname, nodeid,
+            property)))
 
     def getfile(self, classname, nodeid, property):
+        if (classname, nodeid, property) in self.tx_files:
+            return self.tx_files[classname, nodeid, property]
         return self.files[classname, nodeid, property]
 
+    def doStoreFile(self, classname, nodeid, property, **databases):
+        self.files[classname, nodeid, property] = self.tx_files[classname, nodeid, property]
+        return (classname, nodeid)
+
+    def rollbackStoreFile(self, classname, nodeid, property, **databases):
+        del self.tx_files[classname, nodeid, property]
+
     def numfiles(self):
-        return len(self.files)
+        return len(self.files) + len(self.tx_files)
+
+    def close(self):
+        self.clearCache()
+        self.tx_files = {}
+        # kill the schema too
+        self.classes = {}
+        # just keep the .items
 
     #
     # Classes
@@ -227,8 +260,9 @@ class Database(hyperdb.Database, roundupdb.Database):
         if self.classes.has_key(cn):
             raise ValueError, cn
         self.classes[cn] = cl
-        self.items[cn] = cldb()
-        self.ids[cn] = 0
+        if cn not in self.items:
+            self.items[cn] = cldb()
+            self.ids[cn] = 0
 
         # add default Edit and View permissions
         self.security.addPermission(name="Create", klass=cn,
@@ -260,12 +294,15 @@ class Database(hyperdb.Database, roundupdb.Database):
     def clear(self):
         self.items = {}
 
-    def getclassdb(self, classname):
+    def getclassdb(self, classname, mode='r'):
         """ grab a connection to the class db that will be used for
             multiple actions
         """
         return self.items[classname]
 
+    def getCachedJournalDB(self, classname):
+        return self.journals.setdefault(classname, {})
+
     #
     # Node IDs
     #
@@ -273,39 +310,13 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.ids[classname] += 1
         return str(self.ids[classname])
     def setid(self, classname, id):
-        self.ids[classname] = id
-
-    #
-    # Nodes
-    #
-    def addnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def setnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def getnode(self, classname, nodeid, db=None):
-        if db is not None:
-            return db[nodeid]
-        d = self.getclassdb(classname)
-        if nodeid not in d:
-            raise IndexError(nodeid)
-        return d[nodeid]
-
-    def destroynode(self, classname, nodeid):
-        del self.getclassdb(classname)[nodeid]
-
-    def hasnode(self, classname, nodeid):
-        return nodeid in self.getclassdb(classname)
-
-    def countnodes(self, classname, db=None):
-        return len(self.getclassdb(classname))
+        self.ids[classname] = int(id)
 
     #
     # Journal
     #
-    def addjournal(self, classname, nodeid, action, params, creator=None,
-            creation=None):
+    def doSaveJournal(self, classname, nodeid, action, params, creator,
+            creation):
         if creator is None:
             creator = self.getuid()
         if creation is None:
@@ -313,35 +324,59 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.journals.setdefault(classname, {}).setdefault(nodeid,
             []).append((nodeid, creation, creator, action, params))
 
-    def setjournal(self, classname, nodeid, journal):
+    def doSetJournal(self, classname, nodeid, journal):
         self.journals.setdefault(classname, {})[nodeid] = journal
 
     def getjournal(self, classname, nodeid):
-        return self.journals.get(classname, {}).get(nodeid, [])
+        # our journal result
+        res = []
+
+        # add any journal entries for transactions not committed to the
+        # database
+        for method, args in self.transactions:
+            if method != self.doSaveJournal:
+                continue
+            (cache_classname, cache_nodeid, cache_action, cache_params,
+                cache_creator, cache_creation) = args
+            if cache_classname == classname and cache_nodeid == nodeid:
+                if not cache_creator:
+                    cache_creator = self.getuid()
+                if not cache_creation:
+                    cache_creation = date.Date()
+                res.append((cache_nodeid, cache_creation, cache_creator,
+                    cache_action, cache_params))
+        try:
+            res += self.journals.get(classname, {})[nodeid]
+        except KeyError:
+            if res: return res
+            raise IndexError, nodeid
+        return res
 
     def pack(self, pack_before):
-        TODO
-
-    #
-    # Basic transaction support
-    #
-    def commit(self, fail_ok=False):
-        pass
-
-    def rollback(self):
-        TODO
-
-    def close(self):
-        pass
+        """ Delete all journal entries except "create" before 'pack_before'.
+        """
+        pack_before = pack_before.serialise()
+        for classname in self.journals:
+            db = self.journals[classname]
+            for key in db:
+                # get the journal for this db entry
+                l = []
+                last_set_entry = None
+                for entry in db[key]:
+                    # unpack the entry
+                    (nodeid, date_stamp, self.journaltag, action,
+                        params) = entry
+                    date_stamp = date_stamp.serialise()
+                    # if the entry is after the pack date, _or_ the initial
+                    # create entry, then it stays
+                    if date_stamp > pack_before or action == 'create':
+                        l.append(entry)
+                db[key] = l
 
 class Class(back_anydbm.Class):
-    def getnodeids(self, db=None, retired=None):
-        d = self.db.getclassdb(self.classname)
-        if retired is None:
-            return d.keys()
-        return [k for k in d if d[k].get(self.db.RETIRED_FLAG, False) == retired]
+    pass
 
-class FileClass(back_anydbm.Class):
+class FileClass(back_anydbm.FileClass):
     def __init__(self, db, classname, **properties):
         if not properties.has_key('content'):
             properties['content'] = hyperdb.String(indexme='yes')
@@ -349,8 +384,27 @@ class FileClass(back_anydbm.Class):
             properties['type'] = hyperdb.String()
         back_anydbm.Class.__init__(self, db, classname, **properties)
 
-    def getnodeids(self, db=None, retired=None):
-        return self.db.getclassdb(self.classname).keys()
+    def export_files(self, dirname, nodeid):
+        dest = self.exportFilename(dirname, nodeid)
+        ensureParentsExist(dest)
+        f = open(dest, 'wb')
+        f.write(self.db.files[self.classname, nodeid, None])
+        f.close()
+
+    def import_files(self, dirname, nodeid):
+        source = self.exportFilename(dirname, nodeid)
+        f = open(source, 'rb')
+        self.db.files[self.classname, nodeid, None] = f.read()
+        f.close()
+        mime_type = None
+        props = self.getprops()
+        if props.has_key('type'):
+            mime_type = self.get(nodeid, 'type')
+        if not mime_type:
+            mime_type = self.default_mime_type
+        if props['content'].indexme:
+            self.db.indexer.add_text((self.classname, nodeid, 'content'),
+                self.get(nodeid, 'content'), mime_type)
 
 # deviation from spec - was called ItemClass
 class IssueClass(Class, roundupdb.IssueClass):
diff --git a/test/test_memorydb.py b/test/test_memorydb.py
new file mode 100644 (file)
index 0000000..2ceebec
--- /dev/null
@@ -0,0 +1,71 @@
+# $Id: test_memorydb.py,v 1.4 2004-11-03 01:34:21 richard Exp $ 
+
+import unittest, os, shutil, time
+
+from roundup import hyperdb
+
+from db_test_base import DBTest, ROTest, SchemaTest, config, setupSchema
+import memorydb
+
+class memorydbOpener:
+    module = memorydb
+
+    def nuke_database(self):
+        # really kill it
+        self.db = None
+
+    db = None
+    def open_database(self):
+        if self.db is None:
+            self.db = self.module.Database(config, 'admin')
+        return self.db
+
+    def setUp(self):
+        self.open_database()
+        setupSchema(self.db, 1, self.module)
+
+    def tearDown(self):
+        if self.db is not None:
+            self.db.close()
+
+    # nuke and re-create db for restore
+    def nukeAndCreate(self):
+        self.db.close()
+        self.nuke_database()
+        self.db = self.module.Database(config, 'admin')
+        setupSchema(self.db, 0, self.module)
+
+class memorydbDBTest(memorydbOpener, DBTest):
+    pass
+
+class memorydbROTest(memorydbOpener, ROTest):
+    def setUp(self):
+        self.db = self.module.Database(config)
+        setupSchema(self.db, 0, self.module)
+
+class memorydbSchemaTest(memorydbOpener, SchemaTest):
+    pass
+
+from session_common import DBMTest
+class memorydbSessionTest(memorydbOpener, DBMTest):
+    def setUp(self):
+        self.db = self.module.Database(config, 'admin')
+        setupSchema(self.db, 1, self.module)
+        self.sessions = self.db.sessions
+
+def test_suite():
+    suite = unittest.TestSuite()
+    print 'Including memorydb tests'
+    suite.addTest(unittest.makeSuite(memorydbDBTest))
+    suite.addTest(unittest.makeSuite(memorydbROTest))
+    suite.addTest(unittest.makeSuite(memorydbSchemaTest))
+    suite.addTest(unittest.makeSuite(memorydbSessionTest))
+    return suite
+
+if __name__ == '__main__':
+    runner = unittest.TextTestRunner()
+    unittest.main(testRunner=runner)
+
+
+# vim: set filetype=python ts=4 sw=4 et si
+