Code

Make memorydb pass all tests;
[roundup.git] / test / memorydb.py
index 25901215ddba11ea31297aad95ba996ebd05ac6f..9565050bef68405d3c0eebf54969369261a0b0e3 100644 (file)
@@ -1,3 +1,4 @@
+# $Id: test_memorydb.py,v 1.4 2004-11-03 01:34:21 richard Exp $ 
 '''Implement an in-memory hyperdb for testing purposes.
 '''
 
@@ -10,8 +11,10 @@ from roundup import password
 from roundup import configuration
 from roundup.backends import back_anydbm
 from roundup.backends import indexer_dbm
+from roundup.backends import sessions_dbm
 from roundup.backends import indexer_common
 from roundup.hyperdb import *
+from roundup.support import ensureParentsExist
 
 def new_config():
     config = configuration.CoreConfig()
@@ -129,10 +132,10 @@ class BasicDatabase(dict):
     def clean(self):
         pass
 
-class Sessions(BasicDatabase):
+class Sessions(BasicDatabase, sessions_dbm.Sessions):
     name = 'sessions'
 
-class OneTimeKeys(BasicDatabase):
+class OneTimeKeys(BasicDatabase, sessions_dbm.Sessions):
     name = 'otks'
 
 class Indexer(indexer_dbm.Indexer):
@@ -153,8 +156,12 @@ class Indexer(indexer_dbm.Indexer):
 
     def save_index(self):
         pass
+    def force_reindex(self):
+        # TODO I'm concerned that force_reindex may not be tested by
+        # testForcedReindexing if the functionality can just be removed
+        pass
 
-class Database(hyperdb.Database, roundupdb.Database):
+class Database(back_anydbm.Database):
     """A database for storing records containing flexible data types.
 
     Transaction stuff TODO:
@@ -170,18 +177,30 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.ids = {}
         self.journals = {}
         self.files = {}
+        self.tx_files = {}
         self.security = security.Security(self)
         self.stats = {'cache_hits': 0, 'cache_misses': 0, 'get_items': 0,
             'filtering': 0}
         self.sessions = Sessions()
         self.otks = OneTimeKeys()
         self.indexer = Indexer(self)
+        self.sessions = Sessions()
+
+        # anydbm bits
+        self.cache = {}         # cache of nodes loaded or created
+        self.dirtynodes = {}    # keep track of the dirty nodes by class
+        self.newnodes = {}      # keep track of the new nodes by class
+        self.destroyednodes = {}# keep track of the destroyed nodes by class
+        self.transactions = []
 
 
     def filename(self, classname, nodeid, property=None, create=0):
         shutil.copyfile(__file__, __file__+'.dummy')
         return __file__+'.dummy'
 
+    def filesize(self, classname, nodeid, property=None, create=0):
+        return len(self.getfile(classname, nodeid, property))
+
     def post_init(self):
         pass
 
@@ -201,13 +220,31 @@ class Database(hyperdb.Database, roundupdb.Database):
         return '<memorydb instance at %x>'%id(self)
 
     def storefile(self, classname, nodeid, property, content):
-        self.files[classname, nodeid, property] = content
+        self.tx_files[classname, nodeid, property] = content
+        self.transactions.append((self.doStoreFile, (classname, nodeid,
+            property)))
 
     def getfile(self, classname, nodeid, property):
+        if (classname, nodeid, property) in self.tx_files:
+            return self.tx_files[classname, nodeid, property]
         return self.files[classname, nodeid, property]
 
+    def doStoreFile(self, classname, nodeid, property, **databases):
+        self.files[classname, nodeid, property] = self.tx_files[classname, nodeid, property]
+        return (classname, nodeid)
+
+    def rollbackStoreFile(self, classname, nodeid, property, **databases):
+        del self.tx_files[classname, nodeid, property]
+
     def numfiles(self):
-        return len(self.files)
+        return len(self.files) + len(self.tx_files)
+
+    def close(self):
+        self.clearCache()
+        self.tx_files = {}
+        # kill the schema too
+        self.classes = {}
+        # just keep the .items
 
     #
     # Classes
@@ -223,8 +260,9 @@ class Database(hyperdb.Database, roundupdb.Database):
         if self.classes.has_key(cn):
             raise ValueError, cn
         self.classes[cn] = cl
-        self.items[cn] = cldb()
-        self.ids[cn] = 0
+        if cn not in self.items:
+            self.items[cn] = cldb()
+            self.ids[cn] = 0
 
         # add default Edit and View permissions
         self.security.addPermission(name="Create", klass=cn,
@@ -256,47 +294,29 @@ class Database(hyperdb.Database, roundupdb.Database):
     def clear(self):
         self.items = {}
 
-    def getclassdb(self, classname):
+    def getclassdb(self, classname, mode='r'):
         """ grab a connection to the class db that will be used for
             multiple actions
         """
         return self.items[classname]
 
+    def getCachedJournalDB(self, classname):
+        return self.journals.setdefault(classname, {})
+
     #
     # Node IDs
     #
     def newid(self, classname):
         self.ids[classname] += 1
         return str(self.ids[classname])
-
-    #
-    # Nodes
-    #
-    def addnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def setnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def getnode(self, classname, nodeid, cldb=None):
-        if cldb is not None:
-            return cldb[nodeid]
-        return self.getclassdb(classname)[nodeid]
-
-    def destroynode(self, classname, nodeid):
-        del self.getclassdb(classname)[nodeid]
-
-    def hasnode(self, classname, nodeid):
-        return nodeid in self.getclassdb(classname)
-
-    def countnodes(self, classname, db=None):
-        return len(self.getclassdb(classname))
+    def setid(self, classname, id):
+        self.ids[classname] = int(id)
 
     #
     # Journal
     #
-    def addjournal(self, classname, nodeid, action, params, creator=None,
-            creation=None):
+    def doSaveJournal(self, classname, nodeid, action, params, creator,
+            creation):
         if creator is None:
             creator = self.getuid()
         if creation is None:
@@ -304,32 +324,59 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.journals.setdefault(classname, {}).setdefault(nodeid,
             []).append((nodeid, creation, creator, action, params))
 
-    def setjournal(self, classname, nodeid, journal):
+    def doSetJournal(self, classname, nodeid, journal):
         self.journals.setdefault(classname, {})[nodeid] = journal
 
     def getjournal(self, classname, nodeid):
-        return self.journals.get(classname, {}).get(nodeid, [])
+        # our journal result
+        res = []
+
+        # add any journal entries for transactions not committed to the
+        # database
+        for method, args in self.transactions:
+            if method != self.doSaveJournal:
+                continue
+            (cache_classname, cache_nodeid, cache_action, cache_params,
+                cache_creator, cache_creation) = args
+            if cache_classname == classname and cache_nodeid == nodeid:
+                if not cache_creator:
+                    cache_creator = self.getuid()
+                if not cache_creation:
+                    cache_creation = date.Date()
+                res.append((cache_nodeid, cache_creation, cache_creator,
+                    cache_action, cache_params))
+        try:
+            res += self.journals.get(classname, {})[nodeid]
+        except KeyError:
+            if res: return res
+            raise IndexError, nodeid
+        return res
 
     def pack(self, pack_before):
-        TODO
-
-    #
-    # Basic transaction support
-    #
-    def commit(self, fail_ok=False):
-        pass
-
-    def rollback(self):
-        TODO
-
-    def close(self):
-        pass
+        """ Delete all journal entries except "create" before 'pack_before'.
+        """
+        pack_before = pack_before.serialise()
+        for classname in self.journals:
+            db = self.journals[classname]
+            for key in db:
+                # get the journal for this db entry
+                l = []
+                last_set_entry = None
+                for entry in db[key]:
+                    # unpack the entry
+                    (nodeid, date_stamp, self.journaltag, action,
+                        params) = entry
+                    date_stamp = date_stamp.serialise()
+                    # if the entry is after the pack date, _or_ the initial
+                    # create entry, then it stays
+                    if date_stamp > pack_before or action == 'create':
+                        l.append(entry)
+                db[key] = l
 
 class Class(back_anydbm.Class):
-    def getnodeids(self, db=None, retired=None):
-        return self.db.getclassdb(self.classname).keys()
+    pass
 
-class FileClass(back_anydbm.Class):
+class FileClass(back_anydbm.FileClass):
     def __init__(self, db, classname, **properties):
         if not properties.has_key('content'):
             properties['content'] = hyperdb.String(indexme='yes')
@@ -337,8 +384,27 @@ class FileClass(back_anydbm.Class):
             properties['type'] = hyperdb.String()
         back_anydbm.Class.__init__(self, db, classname, **properties)
 
-    def getnodeids(self, db=None, retired=None):
-        return self.db.getclassdb(self.classname).keys()
+    def export_files(self, dirname, nodeid):
+        dest = self.exportFilename(dirname, nodeid)
+        ensureParentsExist(dest)
+        f = open(dest, 'wb')
+        f.write(self.db.files[self.classname, nodeid, None])
+        f.close()
+
+    def import_files(self, dirname, nodeid):
+        source = self.exportFilename(dirname, nodeid)
+        f = open(source, 'rb')
+        self.db.files[self.classname, nodeid, None] = f.read()
+        f.close()
+        mime_type = None
+        props = self.getprops()
+        if props.has_key('type'):
+            mime_type = self.get(nodeid, 'type')
+        if not mime_type:
+            mime_type = self.default_mime_type
+        if props['content'].indexme:
+            self.db.indexer.add_text((self.classname, nodeid, 'content'),
+                self.get(nodeid, 'content'), mime_type)
 
 # deviation from spec - was called ItemClass
 class IssueClass(Class, roundupdb.IssueClass):