Code

more fixes to search permissions:
[roundup.git] / test / memorydb.py
index ea76be7dd458e5d735766245ec452615430591d5..cd830cbd8acee32bbbe307c25cc7c959da535ddd 100644 (file)
@@ -11,20 +11,24 @@ from roundup import password
 from roundup import configuration
 from roundup.backends import back_anydbm
 from roundup.backends import indexer_dbm
+from roundup.backends import sessions_dbm
 from roundup.backends import indexer_common
 from roundup.hyperdb import *
+from roundup.support import ensureParentsExist
 
-def new_config():
+def new_config(debug=False):
     config = configuration.CoreConfig()
     config.DATABASE = "db"
     #config.logging = MockNull()
     # these TRACKER_WEB and MAIL_DOMAIN values are used in mailgw tests
+    if debug:
+        config.LOGGING_LEVEL = "DEBUG"
     config.MAIL_DOMAIN = "your.tracker.email.domain.example"
     config.TRACKER_WEB = "http://tracker.example/cgi-bin/roundup.cgi/bugs/"
     return config
 
-def create(journaltag, create=True):
-    db = Database(new_config(), journaltag)
+def create(journaltag, create=True, debug=False):
+    db = Database(new_config(debug), journaltag)
 
     # load standard schema
     schema = os.path.join(os.path.dirname(__file__),
@@ -114,6 +118,8 @@ class BasicDatabase(dict):
     def get(self, infoid, value, default=None):
         return self[infoid].get(value, default)
     def getall(self, infoid):
+        if infoid not in self:
+            raise KeyError(infoid)
         return self[infoid]
     def set(self, infoid, **newvalues):
         self[infoid].update(newvalues)
@@ -130,10 +136,10 @@ class BasicDatabase(dict):
     def clean(self):
         pass
 
-class Sessions(BasicDatabase):
+class Sessions(BasicDatabase, sessions_dbm.Sessions):
     name = 'sessions'
 
-class OneTimeKeys(BasicDatabase):
+class OneTimeKeys(BasicDatabase, sessions_dbm.Sessions):
     name = 'otks'
 
 class Indexer(indexer_dbm.Indexer):
@@ -154,8 +160,12 @@ class Indexer(indexer_dbm.Indexer):
 
     def save_index(self):
         pass
+    def force_reindex(self):
+        # TODO I'm concerned that force_reindex may not be tested by
+        # testForcedReindexing if the functionality can just be removed
+        pass
 
-class Database(hyperdb.Database, roundupdb.Database):
+class Database(back_anydbm.Database):
     """A database for storing records containing flexible data types.
 
     Transaction stuff TODO:
@@ -171,6 +181,7 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.ids = {}
         self.journals = {}
         self.files = {}
+        self.tx_files = {}
         self.security = security.Security(self)
         self.stats = {'cache_hits': 0, 'cache_misses': 0, 'get_items': 0,
             'filtering': 0}
@@ -178,13 +189,19 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.otks = OneTimeKeys()
         self.indexer = Indexer(self)
 
+        # anydbm bits
+        self.cache = {}         # cache of nodes loaded or created
+        self.dirtynodes = {}    # keep track of the dirty nodes by class
+        self.newnodes = {}      # keep track of the new nodes by class
+        self.destroyednodes = {}# keep track of the destroyed nodes by class
+        self.transactions = []
 
     def filename(self, classname, nodeid, property=None, create=0):
         shutil.copyfile(__file__, __file__+'.dummy')
         return __file__+'.dummy'
 
     def filesize(self, classname, nodeid, property=None, create=0):
-        return len(self.getnode(classname, nodeid)[property or 'content'])
+        return len(self.getfile(classname, nodeid, property))
 
     def post_init(self):
         pass
@@ -205,13 +222,31 @@ class Database(hyperdb.Database, roundupdb.Database):
         return '<memorydb instance at %x>'%id(self)
 
     def storefile(self, classname, nodeid, property, content):
-        self.files[classname, nodeid, property] = content
+        self.tx_files[classname, nodeid, property] = content
+        self.transactions.append((self.doStoreFile, (classname, nodeid,
+            property)))
 
     def getfile(self, classname, nodeid, property):
+        if (classname, nodeid, property) in self.tx_files:
+            return self.tx_files[classname, nodeid, property]
         return self.files[classname, nodeid, property]
 
+    def doStoreFile(self, classname, nodeid, property, **databases):
+        self.files[classname, nodeid, property] = self.tx_files[classname, nodeid, property]
+        return (classname, nodeid)
+
+    def rollbackStoreFile(self, classname, nodeid, property, **databases):
+        del self.tx_files[classname, nodeid, property]
+
     def numfiles(self):
-        return len(self.files)
+        return len(self.files) + len(self.tx_files)
+
+    def close(self):
+        self.clearCache()
+        self.tx_files = {}
+        # kill the schema too
+        self.classes = {}
+        # just keep the .items
 
     #
     # Classes
@@ -227,8 +262,9 @@ class Database(hyperdb.Database, roundupdb.Database):
         if self.classes.has_key(cn):
             raise ValueError, cn
         self.classes[cn] = cl
-        self.items[cn] = cldb()
-        self.ids[cn] = 0
+        if cn not in self.items:
+            self.items[cn] = cldb()
+            self.ids[cn] = 0
 
         # add default Edit and View permissions
         self.security.addPermission(name="Create", klass=cn,
@@ -260,12 +296,15 @@ class Database(hyperdb.Database, roundupdb.Database):
     def clear(self):
         self.items = {}
 
-    def getclassdb(self, classname):
+    def getclassdb(self, classname, mode='r'):
         """ grab a connection to the class db that will be used for
             multiple actions
         """
         return self.items[classname]
 
+    def getCachedJournalDB(self, classname):
+        return self.journals.setdefault(classname, {})
+
     #
     # Node IDs
     #
@@ -273,39 +312,13 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.ids[classname] += 1
         return str(self.ids[classname])
     def setid(self, classname, id):
-        self.ids[classname] = id
-
-    #
-    # Nodes
-    #
-    def addnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def setnode(self, classname, nodeid, node):
-        self.getclassdb(classname)[nodeid] = node
-
-    def getnode(self, classname, nodeid, db=None):
-        if db is not None:
-            return db[nodeid]
-        d = self.getclassdb(classname)
-        if nodeid not in d:
-            raise IndexError(nodeid)
-        return d[nodeid]
-
-    def destroynode(self, classname, nodeid):
-        del self.getclassdb(classname)[nodeid]
-
-    def hasnode(self, classname, nodeid):
-        return nodeid in self.getclassdb(classname)
-
-    def countnodes(self, classname, db=None):
-        return len(self.getclassdb(classname))
+        self.ids[classname] = int(id)
 
     #
     # Journal
     #
-    def addjournal(self, classname, nodeid, action, params, creator=None,
-            creation=None):
+    def doSaveJournal(self, classname, nodeid, action, params, creator,
+            creation):
         if creator is None:
             creator = self.getuid()
         if creation is None:
@@ -313,35 +326,59 @@ class Database(hyperdb.Database, roundupdb.Database):
         self.journals.setdefault(classname, {}).setdefault(nodeid,
             []).append((nodeid, creation, creator, action, params))
 
-    def setjournal(self, classname, nodeid, journal):
+    def doSetJournal(self, classname, nodeid, journal):
         self.journals.setdefault(classname, {})[nodeid] = journal
 
     def getjournal(self, classname, nodeid):
-        return self.journals.get(classname, {}).get(nodeid, [])
+        # our journal result
+        res = []
+
+        # add any journal entries for transactions not committed to the
+        # database
+        for method, args in self.transactions:
+            if method != self.doSaveJournal:
+                continue
+            (cache_classname, cache_nodeid, cache_action, cache_params,
+                cache_creator, cache_creation) = args
+            if cache_classname == classname and cache_nodeid == nodeid:
+                if not cache_creator:
+                    cache_creator = self.getuid()
+                if not cache_creation:
+                    cache_creation = date.Date()
+                res.append((cache_nodeid, cache_creation, cache_creator,
+                    cache_action, cache_params))
+        try:
+            res += self.journals.get(classname, {})[nodeid]
+        except KeyError:
+            if res: return res
+            raise IndexError, nodeid
+        return res
 
     def pack(self, pack_before):
-        TODO
-
-    #
-    # Basic transaction support
-    #
-    def commit(self, fail_ok=False):
-        pass
-
-    def rollback(self):
-        TODO
-
-    def close(self):
-        pass
+        """ Delete all journal entries except "create" before 'pack_before'.
+        """
+        pack_before = pack_before.serialise()
+        for classname in self.journals:
+            db = self.journals[classname]
+            for key in db:
+                # get the journal for this db entry
+                l = []
+                last_set_entry = None
+                for entry in db[key]:
+                    # unpack the entry
+                    (nodeid, date_stamp, self.journaltag, action,
+                        params) = entry
+                    date_stamp = date_stamp.serialise()
+                    # if the entry is after the pack date, _or_ the initial
+                    # create entry, then it stays
+                    if date_stamp > pack_before or action == 'create':
+                        l.append(entry)
+                db[key] = l
 
 class Class(back_anydbm.Class):
-    def getnodeids(self, db=None, retired=None):
-        d = self.db.getclassdb(self.classname)
-        if retired is None:
-            return d.keys()
-        return [k for k in d if d[k].get(self.db.RETIRED_FLAG, False) == retired]
+    pass
 
-class FileClass(back_anydbm.Class):
+class FileClass(back_anydbm.FileClass):
     def __init__(self, db, classname, **properties):
         if not properties.has_key('content'):
             properties['content'] = hyperdb.String(indexme='yes')
@@ -349,8 +386,27 @@ class FileClass(back_anydbm.Class):
             properties['type'] = hyperdb.String()
         back_anydbm.Class.__init__(self, db, classname, **properties)
 
-    def getnodeids(self, db=None, retired=None):
-        return self.db.getclassdb(self.classname).keys()
+    def export_files(self, dirname, nodeid):
+        dest = self.exportFilename(dirname, nodeid)
+        ensureParentsExist(dest)
+        f = open(dest, 'wb')
+        f.write(self.db.files[self.classname, nodeid, None])
+        f.close()
+
+    def import_files(self, dirname, nodeid):
+        source = self.exportFilename(dirname, nodeid)
+        f = open(source, 'rb')
+        self.db.files[self.classname, nodeid, None] = f.read()
+        f.close()
+        mime_type = None
+        props = self.getprops()
+        if props.has_key('type'):
+            mime_type = self.get(nodeid, 'type')
+        if not mime_type:
+            mime_type = self.default_mime_type
+        if props['content'].indexme:
+            self.db.indexer.add_text((self.classname, nodeid, 'content'),
+                self.get(nodeid, 'content'), mime_type)
 
 # deviation from spec - was called ItemClass
 class IssueClass(Class, roundupdb.IssueClass):