Code

use the upload-supplied content-type if there is one
[roundup.git] / roundup / backends / rdbms_common.py
index 4c5760cdc56d4f3cb4594cb3c8a6ea0395460fc4..d9e2d582b349ce80246c6fabf125a0caa81c3358 100644 (file)
@@ -1,4 +1,4 @@
-# $Id: rdbms_common.py,v 1.66 2003-10-25 22:53:26 richard Exp $
+# $Id: rdbms_common.py,v 1.73 2004-01-20 03:58:38 richard Exp $
 ''' Relational database (SQL) backend common code.
 
 Basics:
@@ -69,13 +69,13 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         self.lockfile = None
 
         # open a connection to the database, creating the "conn" attribute
-        self.open_connection()
+        self.sql_open_connection()
 
     def clearCache(self):
         self.cache = {}
         self.cache_lru = []
 
-    def open_connection(self):
+    def sql_open_connection(self):
         ''' Open a connection to the database, creating it if necessary
         '''
         raise NotImplemented
@@ -93,7 +93,12 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
     def sql_fetchone(self):
         ''' Fetch a single row. If there's nothing to fetch, return None.
         '''
-        raise NotImplemented
+        return self.cursor.fetchone()
+
+    def sql_fetchall(self):
+        ''' Fetch all rows. If there's nothing to fetch, return [].
+        '''
+        return self.cursor.fetchall()
 
     def sql_stringquote(self, value):
         ''' Quote the string so it's safe to put in the 'sql quotes'
@@ -103,12 +108,14 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
     def save_dbschema(self, schema):
         ''' Save the schema definition that the database currently implements
         '''
-        raise NotImplemented
+        s = repr(self.database_schema)
+        self.sql('insert into schema values (%s)', (s,))
 
     def load_dbschema(self):
         ''' Load the schema definition that the database currently implements
         '''
-        raise NotImplemented
+        self.cursor.execute('select schema from schema')
+        return eval(self.cursor.fetchone()[0])
 
     def post_init(self):
         ''' Called once the schema initialisation has finished.
@@ -148,19 +155,7 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         self.conn.commit()
 
     def refresh_database(self):
-        # now detect changes in the schema
-        for classname, spec in self.classes.items():
-            dbspec = self.database_schema[classname]
-            self.update_class(spec, dbspec, force=1)
-            self.database_schema[classname] = spec.schema()
-        # update the database version of the schema
-        self.sql('delete from schema')
-        self.save_dbschema(self.database_schema)
-        # reindex the db 
-        self.reindex()
-        # commit
-        self.conn.commit()
-
+        self.post_init()
 
     def reindex(self):
         for klass in self.classes.values():
@@ -190,6 +185,7 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
     def update_class(self, spec, old_spec, force=0):
         ''' Determine the differences between the current spec and the
             database version of the spec, and update where necessary.
+
             If 'force' is true, update the database anyway.
         '''
         new_has = spec.properties.has_key
@@ -207,25 +203,15 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         old_has = {}
         for name,prop in old_spec[1]:
             old_has[name] = 1
-            if (force or not new_has(name)) and isinstance(prop, Multilink):
-                # it's a multilink, and it's been removed - drop the old
-                # table. First drop indexes.
-                index_sqls = [ 'drop index %s_%s_l_idx'%(spec.classname, ml),
-                               'drop index %s_%s_n_idx'%(spec.classname, ml) ]
-                for index_sql in index_sqls:
-                    if __debug__:
-                        print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-                    try:
-                        self.cursor.execute(index_sql)
-                    except:
-                        # The database may not actually have any indexes.
-                        # assume the worst.
-                        pass
-                sql = 'drop table %s_%s'%(spec.classname, prop)
-                if __debug__:
-                    print >>hyperdb.DEBUG, 'update_class', (self, sql)
-                self.cursor.execute(sql)
+            if new_has(name) or not isinstance(prop, Multilink):
                 continue
+            # it's a multilink, and it's been removed - drop the old
+            # table. First drop indexes.
+            self.drop_multilink_table_indexes(spec.classname, ml)
+            sql = 'drop table %s_%s'%(spec.classname, prop)
+            if __debug__:
+                print >>hyperdb.DEBUG, 'update_class', (self, sql)
+            self.cursor.execute(sql)
         old_has = old_has.has_key
 
         # now figure how we populate the new table
@@ -252,20 +238,8 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         self.cursor.execute(sql)
         olddata = self.cursor.fetchall()
 
-        # drop the old table indexes first
-        index_sqls = [ 'drop index _%s_id_idx'%cn,
-                       'drop index _%s_retired_idx'%cn ]
-        if old_spec[0]:
-            index_sqls.append('drop index _%s_%s_idx'%(cn, old_spec[0]))
-        for index_sql in index_sqls:
-            if __debug__:
-                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-            try:
-                self.cursor.execute(index_sql)
-            except:
-                # The database may not actually have any indexes.
-                # assume the worst.
-                pass
+        # TODO: update all the other index dropping code
+        self.drop_class_table_indexes(cn, old_spec[0])
 
         # drop the old table
         self.cursor.execute('drop table _%s'%cn)
@@ -300,6 +274,13 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
             print >>hyperdb.DEBUG, 'create_class', (self, sql)
         self.cursor.execute(sql)
 
+        self.create_class_table_indexes(spec)
+
+        return cols, mls
+
+    def create_class_table_indexes(self, spec):
+        ''' create the class table for the given spec
+        '''
         # create id index
         index_sql1 = 'create index _%s_id_idx on _%s(id)'%(
                         spec.classname, spec.classname)
@@ -326,7 +307,22 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
                 print >>hyperdb.DEBUG, 'create_index', (self, index_sql3)
             self.cursor.execute(index_sql3)
 
-        return cols, mls
+    def drop_class_table_indexes(self, cn, key):
+        # drop the old table indexes first
+        l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
+        if key:
+            # key prop too?
+            l.append('_%s_%s_idx'%(cn, key))
+
+        # TODO: update all the other index dropping code
+        table_name = '_%s'%cn
+        for index_name in l:
+            if not self.sql_index_exists(table_name, index_name):
+                continue
+            index_sql = 'drop index '+index_name
+            if __debug__:
+                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
+            self.cursor.execute(index_sql)
 
     def create_journal_table(self, spec):
         ''' create the journal table for a class given the spec and 
@@ -339,7 +335,9 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         if __debug__:
             print >>hyperdb.DEBUG, 'create_class', (self, sql)
         self.cursor.execute(sql)
+        self.create_journal_table_indexes(spec)
 
+    def create_journal_table_indexes(self, spec):
         # index on nodeid
         index_sql = 'create index %s_journ_idx on %s__journal(nodeid)'%(
                         spec.classname, spec.classname)
@@ -347,6 +345,15 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
             print >>hyperdb.DEBUG, 'create_index', (self, index_sql)
         self.cursor.execute(index_sql)
 
+    def drop_journal_table_indexes(self, classname):
+        index_name = '%s_journ_idx'%classname
+        if not self.sql_index_exists('%s__journal'%classname, index_name):
+            return
+        index_sql = 'drop index '+index_name
+        if __debug__:
+            print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
+        self.cursor.execute(index_sql)
+
     def create_multilink_table(self, spec, ml):
         ''' Create a multilink table for the "ml" property of the class
             given by the spec
@@ -357,7 +364,9 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         if __debug__:
             print >>hyperdb.DEBUG, 'create_class', (self, sql)
         self.cursor.execute(sql)
+        self.create_multilink_table_indexes(spec, ml)
 
+    def create_multilink_table_indexes(self, spec, ml):
         # create index on linkid
         index_sql = 'create index %s_%s_l_idx on %s_%s(linkid)'%(
                         spec.classname, ml, spec.classname, ml)
@@ -372,6 +381,20 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
             print >>hyperdb.DEBUG, 'create_index', (self, index_sql)
         self.cursor.execute(index_sql)
 
+    def drop_multilink_table_indexes(self, classname, ml):
+        l = [
+            '%s_%s_l_idx'%(classname, ml),
+            '%s_%s_n_idx'%(classname, ml)
+        ]
+        table_name = '%s_%s'%(classname, ml)
+        for index_name in l:
+            if not self.sql_index_exists(table_name, index_name):
+                continue
+            index_sql = 'drop index %s'%index_name
+            if __debug__:
+                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
+            self.cursor.execute(index_sql)
+
     def create_class(self, spec):
         ''' Create a database table according to the given spec.
         '''
@@ -401,45 +424,23 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
             if isinstance(prop, Multilink):
                 mls.append(propname)
 
-        index_sqls = [ 'drop index _%s_id_idx'%cn,
-                       'drop index _%s_retired_idx'%cn,
-                       'drop index %s_journ_idx'%cn ]
-        if spec[0]:
-            index_sqls.append('drop index _%s_%s_idx'%(cn, spec[0]))
-        for index_sql in index_sqls:
-            if __debug__:
-                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-            try:
-                self.cursor.execute(index_sql)
-            except:
-                # The database may not actually have any indexes.
-                # assume the worst.
-                pass
-
+        # drop class table and indexes
+        self.drop_class_table_indexes(cn, spec[0])
         sql = 'drop table _%s'%cn
         if __debug__:
             print >>hyperdb.DEBUG, 'drop_class', (self, sql)
         self.cursor.execute(sql)
 
+        # drop journal table and indexes
+        self.drop_journal_table_indexes(cn)
         sql = 'drop table %s__journal'%cn
         if __debug__:
             print >>hyperdb.DEBUG, 'drop_class', (self, sql)
         self.cursor.execute(sql)
 
         for ml in mls:
-            index_sqls = [ 
-                'drop index %s_%s_n_idx'%(cn, ml),
-                'drop index %s_%s_l_idx'%(cn, ml),
-            ]
-            for index_sql in index_sqls:
-                if __debug__:
-                    print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-                try:
-                    self.cursor.execute(index_sql)
-                except:
-                    # The database may not actually have any indexes.
-                    # assume the worst.
-                    pass
+            # drop multilink table and indexes
+            self.drop_multilink_table_indexes(cn, ml)
             sql = 'drop table %s_%s'%(spec.classname, ml)
             if __debug__:
                 print >>hyperdb.DEBUG, 'drop_class', (self, sql)
@@ -741,7 +742,7 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         for col in mls:
             # get the link ids
             sql = 'delete from %s_%s where nodeid=%s'%(classname, col, self.arg)
-            self.cursor.execute(sql, (nodeid,))
+            self.sql(sql, (nodeid,))
 
         # remove journal entries
         sql = 'delete from %s__journal where nodeid=%s'%(classname, self.arg)
@@ -800,8 +801,14 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
                 p = password.Password()
                 p.unpack(v)
                 d[k] = p
-            elif (isinstance(prop, Boolean) or isinstance(prop, Number)) and v is not None:
-                d[k]=float(v)
+            elif isinstance(prop, Boolean) and v is not None:
+                d[k] = int(v)
+            elif isinstance(prop, Number) and v is not None:
+                # try int first, then assume it's a float
+                try:
+                    d[k] = int(v)
+                except ValueError:
+                    d[k] = float(v)
             else:
                 d[k] = v
         return d
@@ -859,12 +866,6 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         self.save_journal(classname, cols, nodeid, journaldate,
             journaltag, action, params)
 
-    def save_journal(self, classname, cols, nodeid, journaldate,
-            journaltag, action, params):
-        ''' Save the journal entry to the database
-        '''
-        raise NotImplemented
-
     def getjournal(self, classname, nodeid):
         ''' get the journal for id
         '''
@@ -875,10 +876,36 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         cols = ','.join('nodeid date tag action params'.split())
         return self.load_journal(classname, cols, nodeid)
 
+    def save_journal(self, classname, cols, nodeid, journaldate,
+            journaltag, action, params):
+        ''' Save the journal entry to the database
+        '''
+        # make the params db-friendly
+        params = repr(params)
+        entry = (nodeid, journaldate, journaltag, action, params)
+
+        # do the insert
+        a = self.arg
+        sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
+            cols, a, a, a, a, a)
+        if __debug__:
+            print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
+        self.cursor.execute(sql, entry)
+
     def load_journal(self, classname, cols, nodeid):
         ''' Load the journal from the database
         '''
-        raise NotImplemented
+        # now get the journal entries
+        sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
+            self.arg)
+        if __debug__:
+            print >>hyperdb.DEBUG, 'load_journal', (self, sql, nodeid)
+        self.cursor.execute(sql, (nodeid,))
+        res = []
+        for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
+            params = eval(params)
+            res.append((nodeid, date.Date(date_stamp), user, action, params))
+        return res
 
     def pack(self, pack_before):
         ''' Delete all journal entries except "create" before 'pack_before'.
@@ -927,6 +954,9 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         # clear out the transactions
         self.transactions = []
 
+    def sql_rollback(self):
+        self.conn.rollback()
+
     def rollback(self):
         ''' Reverse all actions from the current transaction.
 
@@ -936,8 +966,7 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         if __debug__:
             print >>hyperdb.DEBUG, 'rollback', (self,)
 
-        # roll back
-        self.conn.rollback()
+        self.sql_rollback()
 
         # roll back "other" transaction stuff
         for method, args in self.transactions:
@@ -955,10 +984,13 @@ class Database(FileStorage, hyperdb.Database, roundupdb.Database):
         # return the classname, nodeid so we reindex this content
         return (classname, nodeid)
 
+    def sql_close(self):
+        self.conn.close()
+
     def close(self):
         ''' Close off the connection.
         '''
-        self.conn.close()
+        self.sql_close()
         if self.lockfile is not None:
             locking.release_lock(self.lockfile)
         if self.lockfile is not None:
@@ -1192,7 +1224,7 @@ class Class(hyperdb.Class):
             elif isinstance(proptype, hyperdb.Password):
                 value = str(value)
             l.append(repr(value))
-        l.append(self.is_retired(nodeid))
+        l.append(repr(self.is_retired(nodeid)))
         return l
 
     def import_list(self, propnames, proplist):
@@ -1249,6 +1281,9 @@ class Class(hyperdb.Class):
         if newid is None:
             newid = self.db.newid(self.classname)
 
+        # add the node and journal
+        self.db.addnode(self.classname, newid, d)
+
         # retire?
         if retire:
             # use the arg for __retired__ to cope with any odd database type
@@ -1259,9 +1294,6 @@ class Class(hyperdb.Class):
                 print >>hyperdb.DEBUG, 'retire', (self, sql, newid)
             self.db.cursor.execute(sql, (1, newid))
 
-        # add the node and journal
-        self.db.addnode(self.classname, newid, d)
-
         # extract the extraneous journalling gumpf and nuke it
         if d.has_key('creator'):
             creator = d['creator']
@@ -1329,16 +1361,6 @@ class Class(hyperdb.Class):
 
         return d[propname]
 
-    def getnode(self, nodeid, cache=1):
-        ''' Return a convenience wrapper for the node.
-
-        'nodeid' must be the id of an existing node of this class or an
-        IndexError is raised.
-
-        'cache' exists for backwards compatibility, and is not used.
-        '''
-        return Node(self, nodeid)
-
     def set(self, nodeid, **propvalues):
         '''Modify a property on an existing node of this class.
         
@@ -1739,8 +1761,8 @@ class Class(hyperdb.Class):
         'propspec' consists of keyword args propname=nodeid or
                    propname={nodeid:1, }
         'propname' must be the name of a property in this class, or a
-        KeyError is raised.  That property must be a Link or Multilink
-        property, or a TypeError is raised.
+                   KeyError is raised.  That property must be a Link or
+                   Multilink property, or a TypeError is raised.
 
         Any node in this class whose 'propname' property links to any of the
         nodeids will be returned. Used by the full text indexing, which knows
@@ -1766,12 +1788,14 @@ class Class(hyperdb.Class):
                 raise TypeError, "'%s' not a Link/Multilink property"%propname
 
         # first, links
-        where = []
-        allvalues = ()
         a = self.db.arg
+        where = ['__retired__ <> %s'%a]
+        allvalues = (1,)
         for prop, values in propspec:
             if not isinstance(props[prop], hyperdb.Link):
                 continue
+            if type(values) is type({}) and len(values) == 1:
+                values = values.keys()[0]
             if type(values) is type(''):
                 allvalues += (values,)
                 where.append('_%s = %s'%(prop, a))
@@ -1797,7 +1821,8 @@ class Class(hyperdb.Class):
                 s = ','.join([a]*len(values))
             tables.append('select nodeid from %s_%s where linkid in (%s)'%(
                 self.classname, prop, s))
-        sql = '\nunion\n'.join(tables)
+
+        sql = '\nintersect\n'.join(tables)
         self.db.sql(sql, allvalues)
         l = [x[0] for x in self.db.sql_fetchall()]
         if __debug__:
@@ -1816,14 +1841,16 @@ class Class(hyperdb.Class):
         args = []
         for propname in requirements.keys():
             prop = self.properties[propname]
-            if isinstance(not prop, String):
+            if not isinstance(prop, String):
                 raise TypeError, "'%s' not a String property"%propname
             where.append(propname)
             args.append(requirements[propname].lower())
 
         # generate the where clause
         s = ' and '.join(['lower(_%s)=%s'%(col, self.db.arg) for col in where])
-        sql = 'select id from _%s where %s'%(self.classname, s)
+        sql = 'select id from _%s where %s and __retired__=%s'%(self.classname,
+            s, self.db.arg)
+        args.append(0)
         self.db.sql(sql, tuple(args))
         l = [x[0] for x in self.db.sql_fetchall()]
         if __debug__:
@@ -2050,13 +2077,15 @@ class Class(hyperdb.Class):
         args = tuple(args)
         if __debug__:
             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
-        self.db.cursor.execute(sql, args)
-        l = self.db.cursor.fetchall()
+        if args:
+            self.db.cursor.execute(sql, args)
+        else:
+            # psycopg doesn't like empty args
+            self.db.cursor.execute(sql)
+        l = self.db.sql_fetchall()
 
         # return the IDs (the first column)
-        # XXX The filter(None, l) bit is sqlite-specific... if there's _NO_
-        # XXX matches to a fetch, it returns NULL instead of nothing!?!
-        return filter(None, [row[0] for row in l])
+        return [row[0] for row in l]
 
     def count(self):
         '''Get the number of nodes in this class.