Code

Add config-option "nosy" to messages_to_author setting in [nosy] section
[roundup.git] / roundup / backends / back_mysql.py
index 6aa3a633d0b1c288e46d2a6fdf60b213a0c0db78..875cc53e94d2aa316307e78504ca4ee2c9510fc0 100644 (file)
@@ -12,7 +12,7 @@
 How to implement AUTO_INCREMENT:
 
 mysql> create table foo (num integer auto_increment primary key, name
-varchar(255)) AUTO_INCREMENT=1 type=InnoDB;
+varchar(255)) AUTO_INCREMENT=1 ENGINE=InnoDB;
 
 ql> insert into foo (name) values ('foo5');
 Query OK, 1 row affected (0.00 sec)
@@ -38,15 +38,24 @@ from roundup.backends import rdbms_common
 import MySQLdb
 import os, shutil
 from MySQLdb.constants import ER
+import logging
 
+def connection_dict(config, dbnamestr=None):
+    d = rdbms_common.connection_dict(config, dbnamestr)
+    if d.has_key('password'):
+        d['passwd'] = d['password']
+        del d['password']
+    if d.has_key('port'):
+        d['port'] = int(d['port'])
+    return d
 
 def db_nuke(config):
     """Clear all database contents and drop database itself"""
     if db_exists(config):
-        conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
-            config.MYSQL_DBPASSWORD)
+        kwargs = connection_dict(config)
+        conn = MySQLdb.connect(**kwargs)
         try:
-            conn.select_db(config.MYSQL_DBNAME)
+            conn.select_db(config.RDBMS_NAME)
         except:
             # no, it doesn't exist
             pass
@@ -54,13 +63,14 @@ def db_nuke(config):
             cursor = conn.cursor()
             cursor.execute("SHOW TABLES")
             tables = cursor.fetchall()
+            # stupid MySQL bug requires us to drop all the tables first
             for table in tables:
-                if __debug__:
-                    print >>hyperdb.DEBUG, 'DROP TABLE %s'%table[0]
-                cursor.execute("DROP TABLE %s"%table[0])
-            if __debug__:
-                print >>hyperdb.DEBUG, "DROP DATABASE %s"%config.MYSQL_DBNAME
-            cursor.execute("DROP DATABASE %s"%config.MYSQL_DBNAME)
+                command = 'DROP TABLE `%s`'%table[0]
+                logging.debug(command)
+                cursor.execute(command)
+            command = "DROP DATABASE %s"%config.RDBMS_NAME
+            logging.info(command)
+            cursor.execute(command)
             conn.commit()
         conn.close()
 
@@ -69,42 +79,35 @@ def db_nuke(config):
 
 def db_create(config):
     """Create the database."""
-    conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
-        config.MYSQL_DBPASSWORD)
+    kwargs = connection_dict(config)
+    conn = MySQLdb.connect(**kwargs)
     cursor = conn.cursor()
-    if __debug__:
-        print >>hyperdb.DEBUG, "CREATE DATABASE %s"%config.MYSQL_DBNAME
-    cursor.execute("CREATE DATABASE %s"%config.MYSQL_DBNAME)
+    command = "CREATE DATABASE %s"%config.RDBMS_NAME
+    logging.info(command)
+    cursor.execute(command)
     conn.commit()
     conn.close()
 
 def db_exists(config):
     """Check if database already exists."""
-    conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
-        config.MYSQL_DBPASSWORD)
-#    tables = None
+    kwargs = connection_dict(config)
+    conn = MySQLdb.connect(**kwargs)
     try:
         try:
-            conn.select_db(config.MYSQL_DBNAME)
-#            cursor = conn.cursor()
-#            cursor.execute("SHOW TABLES")
-#            tables = cursor.fetchall()
-#            if __debug__:
-#                print >>hyperdb.DEBUG, "tables %s"%(tables,)
+            conn.select_db(config.RDBMS_NAME)
         except MySQLdb.OperationalError:
-            if __debug__:
-                print >>hyperdb.DEBUG, "no database '%s'"%config.MYSQL_DBNAME
             return 0
     finally:
         conn.close()
-    if __debug__:
-        print >>hyperdb.DEBUG, "database '%s' exists"%config.MYSQL_DBNAME
     return 1
 
 
 class Database(Database):
     arg = '%s'
 
+    # used by some code to switch styles of query
+    implements_intersect = 0
+
     # Backend for MySQL to use.
     # InnoDB is faster, but if you're running <4.0.16 then you'll need to
     # use BDB to pass all unit tests.
@@ -112,12 +115,12 @@ class Database(Database):
     #mysql_backend = 'BDB'
 
     hyperdb_to_sql_datatypes = {
-        hyperdb.String : 'VARCHAR(255)',
+        hyperdb.String : 'TEXT',
         hyperdb.Date   : 'DATETIME',
         hyperdb.Link   : 'INTEGER',
         hyperdb.Interval  : 'VARCHAR(255)',
         hyperdb.Password  : 'VARCHAR(255)',
-        hyperdb.Boolean   : 'INTEGER',
+        hyperdb.Boolean   : 'BOOL',
         hyperdb.Number    : 'REAL',
     }
 
@@ -126,23 +129,25 @@ class Database(Database):
         # no fractional seconds for MySQL
         hyperdb.Date   : lambda x: x.formal(sep=' '),
         hyperdb.Link   : int,
-        hyperdb.Interval  : lambda x: x.serialise(),
+        hyperdb.Interval  : str,
         hyperdb.Password  : str,
         hyperdb.Boolean   : int,
         hyperdb.Number    : lambda x: x,
+        hyperdb.Multilink : lambda x: x,    # used in journal marshalling
     }
 
     def sql_open_connection(self):
-        db = getattr(self.config, 'MYSQL_DATABASE')
+        kwargs = connection_dict(self.config, 'db')
+        self.log_info('open database %r'%(kwargs['db'],))
         try:
-            conn = MySQLdb.connect(*db)
+            conn = MySQLdb.connect(**kwargs)
         except MySQLdb.OperationalError, message:
             raise DatabaseError, message
         cursor = conn.cursor()
         cursor.execute("SET AUTOCOMMIT=0")
-        cursor.execute("BEGIN")
+        cursor.execute("START TRANSACTION")
         return (conn, cursor)
-    
+
     def open_connection(self):
         # make sure the database actually exists
         if not db_exists(self.config):
@@ -159,115 +164,206 @@ class Database(Database):
             if message[0] != ER.NO_SUCH_TABLE:
                 raise DatabaseError, message
             self.init_dbschema()
-            self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
+            self.sql("CREATE TABLE `schema` (`schema` TEXT) ENGINE=%s"%
                 self.mysql_backend)
-            self.cursor.execute('''CREATE TABLE ids (name VARCHAR(255),
-                num INTEGER) TYPE=%s'''%self.mysql_backend)
-            self.cursor.execute('create index ids_name_idx on ids(name)')
+            self.sql('''CREATE TABLE ids (name VARCHAR(255),
+                num INTEGER) ENGINE=%s'''%self.mysql_backend)
+            self.sql('create index ids_name_idx on ids(name)')
             self.create_version_2_tables()
 
+    def load_dbschema(self):
+        ''' Load the schema definition that the database currently implements
+        '''
+        self.cursor.execute('select `schema` from `schema`')
+        schema = self.cursor.fetchone()
+        if schema:
+            self.database_schema = eval(schema[0])
+        else:
+            self.database_schema = {}
+
+    def save_dbschema(self):
+        ''' Save the schema definition that the database currently implements
+        '''
+        s = repr(self.database_schema)
+        self.sql('delete from `schema`')
+        self.sql('insert into `schema` values (%s)', (s,))
+
     def create_version_2_tables(self):
         # OTK store
-        self.cursor.execute('''CREATE TABLE otks (otk_key VARCHAR(255),
-            otk_value VARCHAR(255), otk_time FLOAT(20))
-            TYPE=%s'''%self.mysql_backend)
-        self.cursor.execute('CREATE INDEX otks_key_idx ON otks(otk_key)')
+        self.sql('''CREATE TABLE otks (otk_key VARCHAR(255),
+            otk_value TEXT, otk_time FLOAT(20))
+            ENGINE=%s'''%self.mysql_backend)
+        self.sql('CREATE INDEX otks_key_idx ON otks(otk_key)')
 
         # Sessions store
-        self.cursor.execute('''CREATE TABLE sessions (
-            session_key VARCHAR(255), session_time FLOAT(20),
-            session_value VARCHAR(255)) TYPE=%s'''%self.mysql_backend)
-        self.cursor.execute('''CREATE INDEX sessions_key_idx ON
+        self.sql('''CREATE TABLE sessions (session_key VARCHAR(255),
+            session_time FLOAT(20), session_value TEXT)
+            ENGINE=%s'''%self.mysql_backend)
+        self.sql('''CREATE INDEX sessions_key_idx ON
             sessions(session_key)''')
 
         # full-text indexing store
-        self.cursor.execute('''CREATE TABLE __textids (_class VARCHAR(255),
+        self.sql('''CREATE TABLE __textids (_class VARCHAR(255),
             _itemid VARCHAR(255), _prop VARCHAR(255), _textid INT)
-            TYPE=%s'''%self.mysql_backend)
-        self.cursor.execute('''CREATE TABLE __words (_word VARCHAR(30),
-            _textid INT) TYPE=%s'''%self.mysql_backend)
-        self.cursor.execute('CREATE INDEX words_word_ids ON __words(_word)')
+            ENGINE=%s'''%self.mysql_backend)
+        self.sql('''CREATE TABLE __words (_word VARCHAR(30),
+            _textid INT) ENGINE=%s'''%self.mysql_backend)
+        self.sql('CREATE INDEX words_word_ids ON __words(_word)')
+        self.sql('CREATE INDEX words_by_id ON __words (_textid)')
+        self.sql('CREATE UNIQUE INDEX __textids_by_props ON '
+                 '__textids (_class, _itemid, _prop)')
         sql = 'insert into ids (name, num) values (%s,%s)'%(self.arg, self.arg)
-        self.cursor.execute(sql, ('__textids', 1))
+        self.sql(sql, ('__textids', 1))
 
-    def add_actor_column(self):
-        ''' While we're adding the actor column, we need to update the
+    def add_new_columns_v2(self):
+        '''While we're adding the actor column, we need to update the
         tables to have the correct datatypes.'''
-        assert 0, 'FINISH ME!'
-
-        for spec in self.classes.values():
-            new_has = spec.properties.has_key
-            new_spec = spec.schema()
-            new_spec[1].sort()
-            old_spec[1].sort()
-            if not force and new_spec == old_spec:
-                # no changes
-                return 0
-
-            if __debug__:
-                print >>hyperdb.DEBUG, 'update_class FIRING'
-
-            # detect multilinks that have been removed, and drop their table
-            old_has = {}
-            for name,prop in old_spec[1]:
-                old_has[name] = 1
-                if new_has(name) or not isinstance(prop, hyperdb.Multilink):
+        for klass in self.classes.values():
+            cn = klass.classname
+            properties = klass.getprops()
+            old_spec = self.database_schema['tables'][cn]
+
+            # figure the non-Multilink properties to copy over
+            propnames = ['activity', 'creation', 'creator']
+
+            # figure actions based on data type
+            for name, s_prop in old_spec[1]:
+                # s_prop is a repr() string of a hyperdb type object
+                if s_prop.find('Multilink') == -1:
+                    if properties.has_key(name):
+                        propnames.append(name)
                     continue
-                # it's a multilink, and it's been removed - drop the old
-                # table. First drop indexes.
-                self.drop_multilink_table_indexes(spec.classname, ml)
-                sql = 'drop table %s_%s'%(spec.classname, prop)
-                if __debug__:
-                    print >>hyperdb.DEBUG, 'update_class', (self, sql)
-                self.cursor.execute(sql)
-            old_has = old_has.has_key
-
-            # now figure how we populate the new table
-            if adding_actor:
-                fetch = ['_activity', '_creation', '_creator']
-            else:
-                fetch = ['_actor', '_activity', '_creation', '_creator']
-            properties = spec.getprops()
-            for propname,x in new_spec[1]:
-                prop = properties[propname]
-                if isinstance(prop, hyperdb.Multilink):
-                    if force or not old_has(propname):
-                        # we need to create the new table
-                        self.create_multilink_table(spec, propname)
-                elif old_has(propname):
-                    # we copy this col over from the old table
-                    fetch.append('_'+propname)
+                tn = '%s_%s'%(cn, name)
+
+                if properties.has_key(name):
+                    # grabe the current values
+                    sql = 'select linkid, nodeid from %s'%tn
+                    self.sql(sql)
+                    rows = self.cursor.fetchall()
+
+                # drop the old table
+                self.drop_multilink_table_indexes(cn, name)
+                sql = 'drop table %s'%tn
+                self.sql(sql)
+
+                if properties.has_key(name):
+                    # re-create and populate the new table
+                    self.create_multilink_table(klass, name)
+                    sql = '''insert into %s (linkid, nodeid) values
+                        (%s, %s)'''%(tn, self.arg, self.arg)
+                    for linkid, nodeid in rows:
+                        self.sql(sql, (int(linkid), int(nodeid)))
+
+            # figure the column names to fetch
+            fetch = ['_%s'%name for name in propnames]
 
             # select the data out of the old table
             fetch.append('id')
             fetch.append('__retired__')
             fetchcols = ','.join(fetch)
-            cn = spec.classname
             sql = 'select %s from _%s'%(fetchcols, cn)
-            if __debug__:
-                print >>hyperdb.DEBUG, 'update_class', (self, sql)
-            self.cursor.execute(sql)
-            olddata = self.cursor.fetchall()
+            self.sql(sql)
+
+            # unserialise the old data
+            olddata = []
+            propnames = propnames + ['id', '__retired__']
+            cols = []
+            first = 1
+            for entry in self.cursor.fetchall():
+                l = []
+                olddata.append(l)
+                for i in range(len(propnames)):
+                    name = propnames[i]
+                    v = entry[i]
+
+                    if name in ('id', '__retired__'):
+                        if first:
+                            cols.append(name)
+                        l.append(int(v))
+                        continue
+                    if first:
+                        cols.append('_' + name)
+                    prop = properties[name]
+                    if isinstance(prop, Date) and v is not None:
+                        v = date.Date(v)
+                    elif isinstance(prop, Interval) and v is not None:
+                        v = date.Interval(v)
+                    elif isinstance(prop, Password) and v is not None:
+                        v = password.Password(encrypted=v)
+                    elif (isinstance(prop, Boolean) or
+                            isinstance(prop, Number)) and v is not None:
+                        v = float(v)
+
+                    # convert to new MySQL data type
+                    prop = properties[name]
+                    if v is not None:
+                        e = self.to_sql_value(prop.__class__)(v)
+                    else:
+                        e = None
+                    l.append(e)
+
+                    # Intervals store the seconds value too
+                    if isinstance(prop, Interval):
+                        if first:
+                            cols.append('__' + name + '_int__')
+                        if v is not None:
+                            l.append(v.as_seconds())
+                        else:
+                            l.append(e)
+                first = 0
 
-            # TODO: update all the other index dropping code
             self.drop_class_table_indexes(cn, old_spec[0])
 
             # drop the old table
-            self.cursor.execute('drop table _%s'%cn)
+            self.sql('drop table _%s'%cn)
 
             # create the new table
-            self.create_class_table(spec)
-
-            # do the insert of the old data - the new columns will have
-            # NULL values
-            args = ','.join([self.arg for x in fetch])
-            sql = 'insert into _%s (%s) values (%s)'%(cn, fetchcols, args)
-            if __debug__:
-                print >>hyperdb.DEBUG, 'update_class', (self, sql, olddata[0])
-            for entry in olddata:
-                self.cursor.execute(sql, tuple(entry))
+            self.create_class_table(klass)
 
-        return 1
+            # do the insert of the old data
+            args = ','.join([self.arg for x in cols])
+            cols = ','.join(cols)
+            sql = 'insert into _%s (%s) values (%s)'%(cn, cols, args)
+            for entry in olddata:
+                self.sql(sql, tuple(entry))
+
+            # now load up the old journal data to migrate it
+            cols = ','.join('nodeid date tag action params'.split())
+            sql = 'select %s from %s__journal'%(cols, cn)
+            self.sql(sql)
+
+            # data conversions
+            olddata = []
+            for nodeid, journaldate, journaltag, action, params in \
+                    self.cursor.fetchall():
+                #nodeid = int(nodeid)
+                journaldate = date.Date(journaldate)
+                #params = eval(params)
+                olddata.append((nodeid, journaldate, journaltag, action,
+                    params))
+
+            # drop journal table and indexes
+            self.drop_journal_table_indexes(cn)
+            sql = 'drop table %s__journal'%cn
+            self.sql(sql)
+
+            # re-create journal table
+            self.create_journal_table(klass)
+            dc = self.to_sql_value(hyperdb.Date)
+            for nodeid, journaldate, journaltag, action, params in olddata:
+                self.save_journal(cn, cols, nodeid, dc(journaldate),
+                    journaltag, action, params)
+
+            # make sure the normal schema update code doesn't try to
+            # change things
+            self.database_schema['tables'][cn] = klass.schema()
+
+    def fix_version_2_tables(self):
+        # Convert journal date column to TIMESTAMP, params column to TEXT
+        self._convert_journal_tables()
+
+        # Convert all String properties to TEXT
+        self._convert_string_properties()
 
     def __repr__(self):
         return '<myroundsql 0x%x>'%id(self)
@@ -279,17 +375,13 @@ class Database(Database):
         return self.cursor.fetchall()
 
     def sql_index_exists(self, table_name, index_name):
-        self.cursor.execute('show index from %s'%table_name)
+        self.sql('show index from %s'%table_name)
         for index in self.cursor.fetchall():
             if index[2] == index_name:
                 return 1
         return 0
 
-    def save_dbschema(self, schema):
-        s = repr(self.database_schema)
-        self.sql('INSERT INTO schema VALUES (%s)', (s,))
-    
-    def create_class_table(self, spec):
+    def create_class_table(self, spec, create_sequence=1):
         cols, mls = self.determine_columns(spec.properties.items())
 
         # add on our special columns
@@ -298,15 +390,55 @@ class Database(Database):
 
         # create the base table
         scols = ','.join(['%s %s'%x for x in cols])
-        sql = 'create table _%s (%s) type=%s'%(spec.classname, scols,
+        sql = 'create table _%s (%s) ENGINE=%s'%(spec.classname, scols,
             self.mysql_backend)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_class', (self, sql)
-        self.cursor.execute(sql)
+        self.sql(sql)
 
         self.create_class_table_indexes(spec)
         return cols, mls
 
+    def create_class_table_indexes(self, spec):
+        ''' create the class table for the given spec
+        '''
+        # create __retired__ index
+        index_sql2 = 'create index _%s_retired_idx on _%s(__retired__)'%(
+                        spec.classname, spec.classname)
+        self.sql(index_sql2)
+
+        # create index for key property
+        if spec.key:
+            if isinstance(spec.properties[spec.key], String):
+                idx = spec.key + '(255)'
+            else:
+                idx = spec.key
+            index_sql3 = 'create index _%s_%s_idx on _%s(_%s)'%(
+                        spec.classname, spec.key,
+                        spec.classname, idx)
+            self.sql(index_sql3)
+
+        # TODO: create indexes on (selected?) Link property columns, as
+        # they're more likely to be used for lookup
+
+    def add_class_key_required_unique_constraint(self, cn, key):
+        # mysql requires sizes on TEXT indexes
+        prop = self.classes[cn].getprops()[key]
+        if isinstance(prop, String):
+            sql = '''create unique index _%s_key_retired_idx
+                on _%s(__retired__, _%s(255))'''%(cn, cn, key)
+        else:
+            sql = '''create unique index _%s_key_retired_idx
+                on _%s(__retired__, _%s)'''%(cn, cn, key)
+        self.sql(sql)
+
+    def create_class_table_key_index(self, cn, key):
+        # mysql requires sizes on TEXT indexes
+        prop = self.classes[cn].getprops()[key]
+        if isinstance(prop, String):
+            sql = 'create index _%s_%s_idx on _%s(_%s(255))'%(cn, key, cn, key)
+        else:
+            sql = 'create index _%s_%s_idx on _%s(_%s)'%(cn, key, cn, key)
+        self.sql(sql)
+
     def drop_class_table_indexes(self, cn, key):
         # drop the old table indexes first
         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
@@ -318,21 +450,20 @@ class Database(Database):
             if not self.sql_index_exists(table_name, index_name):
                 continue
             index_sql = 'drop index %s on %s'%(index_name, table_name)
-            if __debug__:
-                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-            self.cursor.execute(index_sql)
+            self.sql(index_sql)
 
     def create_journal_table(self, spec):
+        ''' create the journal table for a class given the spec and
+            already-determined cols
+        '''
         # journal table
         cols = ','.join(['%s varchar'%x
             for x in 'nodeid date tag action params'.split()])
         sql = '''create table %s__journal (
-            nodeid integer, date timestamp, tag varchar(255),
-            action varchar(255), params varchar(255)) type=%s'''%(
+            nodeid integer, date datetime, tag varchar(255),
+            action varchar(255), params text) ENGINE=%s'''%(
             spec.classname, self.mysql_backend)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_journal_table', (self, sql)
-        self.cursor.execute(sql)
+        self.sql(sql)
         self.create_journal_table_indexes(spec)
 
     def drop_journal_table_indexes(self, classname):
@@ -340,17 +471,13 @@ class Database(Database):
         if not self.sql_index_exists('%s__journal'%classname, index_name):
             return
         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-        self.cursor.execute(index_sql)
+        self.sql(index_sql)
 
     def create_multilink_table(self, spec, ml):
         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
-            nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
+            nodeid VARCHAR(255)) ENGINE=%s'''%(spec.classname, ml,
                 self.mysql_backend)
-        if __debug__:
-          print >>hyperdb.DEBUG, 'create_class', (self, sql)
-        self.cursor.execute(sql)
+        self.sql(sql)
         self.create_multilink_table_indexes(spec, ml)
 
     def drop_multilink_table_indexes(self, classname, ml):
@@ -362,10 +489,8 @@ class Database(Database):
         for index_name in l:
             if not self.sql_index_exists(table_name, index_name):
                 continue
-            index_sql = 'drop index %s on %s'%(index_name, table_name)
-            if __debug__:
-                print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
-            self.cursor.execute(index_sql)
+            sql = 'drop index %s on %s'%(index_name, table_name)
+            self.sql(sql)
 
     def drop_class_table_key_index(self, cn, key):
         table_name = '_%s'%cn
@@ -373,27 +498,21 @@ class Database(Database):
         if not self.sql_index_exists(table_name, index_name):
             return
         sql = 'drop index %s on %s'%(index_name, table_name)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'drop_index', (self, sql)
-        self.cursor.execute(sql)
+        self.sql(sql)
 
     # old-skool id generation
     def newid(self, classname):
         ''' Generate a new id for the given class
         '''
-        # get the next ID
-        sql = 'select num from ids where name=%s'%self.arg
-        if __debug__:
-            print >>hyperdb.DEBUG, 'newid', (self, sql, classname)
-        self.cursor.execute(sql, (classname, ))
+        # get the next ID - "FOR UPDATE" will lock the row for us
+        sql = 'select num from ids where name=%s FOR UPDATE'%self.arg
+        self.sql(sql, (classname, ))
         newid = int(self.cursor.fetchone()[0])
 
         # update the counter
         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
         vals = (int(newid)+1, classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'newid', (self, sql, vals)
-        self.cursor.execute(sql, vals)
+        self.sql(sql, vals)
 
         # return as string
         return str(newid)
@@ -405,228 +524,122 @@ class Database(Database):
         '''
         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
         vals = (int(setid)+1, classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'setid', (self, sql, vals)
-        self.cursor.execute(sql, vals)
+        self.sql(sql, vals)
+
+    def clear(self):
+        rdbms_common.Database.clear(self)
+
+        # set the id counters to 0 (setid adds one) so we start at 1
+        for cn in self.classes.keys():
+            self.setid(cn, 0)
 
     def create_class(self, spec):
         rdbms_common.Database.create_class(self, spec)
         sql = 'insert into ids (name, num) values (%s, %s)'
         vals = (spec.classname, 1)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_class', (self, sql, vals)
-        self.cursor.execute(sql, vals)
+        self.sql(sql, vals)
 
-class MysqlClass:
-    # we're overriding this method for ONE missing bit of functionality.
-    # look for "I can't believe it's not a toy RDBMS" below
-    def filter(self, search_matches, filterspec, sort=(None,None),
-            group=(None,None)):
-        '''Return a list of the ids of the active nodes in this class that
-        match the 'filter' spec, sorted by the group spec and then the
-        sort spec
+    def sql_commit(self, fail_ok=False):
+        ''' Actually commit to the database.
+        '''
+        self.log_info('commit')
+
+        # MySQL commits don't seem to ever fail, the latest update winning.
+        # makes you wonder why they have transactions...
+        self.conn.commit()
+
+        # open a new cursor for subsequent work
+        self.cursor = self.conn.cursor()
 
-        "filterspec" is {propname: value(s)}
+        # make sure we're in a new transaction and not autocommitting
+        self.sql("SET AUTOCOMMIT=0")
+        self.sql("START TRANSACTION")
 
-        "sort" and "group" are (dir, prop) where dir is '+', '-' or None
-        and prop is a prop name or None
+    def sql_close(self):
+        self.log_info('close')
+        try:
+            self.conn.close()
+        except MySQLdb.ProgrammingError, message:
+            if str(message) != 'closing a closed connection':
+                raise
+
+class MysqlClass:
 
-        "search_matches" is {nodeid: marker}
+    def supports_subselects(self):
+        # TODO: AFAIK its version dependent for MySQL
+        return False
 
-        The filter must match all properties specificed - but if the
-        property value to match is a list, any one of the values in the
-        list may match for that property to match.
+    def _subselect(self, classname, multilink_table):
+        ''' "I can't believe it's not a toy RDBMS"
+           see, even toy RDBMSes like gadfly and sqlite can do sub-selects...
         '''
-        # just don't bother if the full-text search matched diddly
-        if search_matches == {}:
-            return []
+        self.db.sql('select nodeid from %s'%multilink_table)
+        s = ','.join([x[0] for x in self.db.sql_fetchall()])
+        return '_%s.id not in (%s)'%(classname, s)
 
-        cn = self.classname
+    def create_inner(self, **propvalues):
+        try:
+            return rdbms_common.Class.create_inner(self, **propvalues)
+        except MySQLdb.IntegrityError, e:
+            self._handle_integrity_error(e, propvalues)
 
-        timezone = self.db.getUserTimezone()
+    def set_inner(self, nodeid, **propvalues):
+        try:
+            return rdbms_common.Class.set_inner(self, nodeid,
+                                                **propvalues)
+        except MySQLdb.IntegrityError, e:
+            self._handle_integrity_error(e, propvalues)
+
+    def _handle_integrity_error(self, e, propvalues):
+        ''' Handle a MySQL IntegrityError.
+
+        If the error is recognized, then it may be converted into an
+        alternative exception.  Otherwise, it is raised unchanged from
+        this function.'''
+
+        # There are checks in create_inner/set_inner to see if a node
+        # is being created with the same key as an existing node.
+        # But, there is a race condition -- we may pass those checks,
+        # only to find out that a parallel session has created the
+        # node by by the time we actually issue the SQL command to
+        # create the node.  Fortunately, MySQL gives us a unique error
+        # code for this situation, so we can detect it here and handle
+        # it appropriately.
+        # 
+        # The details of the race condition are as follows, where
+        # "X" is a classname, and the term "thread" is meant to
+        # refer generically to both threads and processes:
+        #
+        # Thread A                    Thread B
+        # --------                    --------
+        #                             read table for X
+        # create new X object
+        # commit
+        #                             create new X object
+        #
+        # In Thread B, the check in create_inner does not notice that
+        # the new X object is a duplicate of that committed in Thread
+        # A because MySQL's default "consistent nonlocking read"
+        # behavior means that Thread B sees a snapshot of the database
+        # at the point at which its transaction began -- which was
+        # before Thread A created the object.  However, the attempt
+        # to *write* to the table for X, creating a duplicate entry,
+        # triggers an error at the point of the write.
+        #
+        # If both A and B's transaction begins with creating a new X
+        # object, then this bug cannot occur because creating the
+        # object requires getting a new ID, and newid() locks the id
+        # table until the transaction is committed or rolledback.  So,
+        # B will block until A's commit is complete, and will not
+        # actually get its snapshot until A's transaction completes.
+        # But, if the transaction has begun prior to calling newid,
+        # then the snapshot has already been established.
+        if e[0] == ER.DUP_ENTRY:
+            key = propvalues[self.key]
+            raise ValueError, 'node with key "%s" exists' % key
+        # We don't know what this exception is; reraise it.
+        raise
         
-        # figure the WHERE clause from the filterspec
-        props = self.getprops()
-        frum = ['_'+cn]
-        where = []
-        args = []
-        a = self.db.arg
-        for k, v in filterspec.items():
-            propclass = props[k]
-            # now do other where clause stuff
-            if isinstance(propclass, Multilink):
-                tn = '%s_%s'%(cn, k)
-                if v in ('-1', ['-1']):
-                    # only match rows that have count(linkid)=0 in the
-                    # corresponding multilink table)
-
-                    # "I can't believe it's not a toy RDBMS"
-                    # see, even toy RDBMSes like gadfly and sqlite can do
-                    # sub-selects...
-                    self.db.sql('select nodeid from %s'%tn)
-                    s = ','.join([x[0] for x in self.db.sql_fetchall()])
-
-                    where.append('id not in (%s)'%s)
-                elif isinstance(v, type([])):
-                    frum.append(tn)
-                    s = ','.join([a for x in v])
-                    where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
-                    args = args + v
-                else:
-                    frum.append(tn)
-                    where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
-                    args.append(v)
-            elif k == 'id':
-                if isinstance(v, type([])):
-                    s = ','.join([a for x in v])
-                    where.append('%s in (%s)'%(k, s))
-                    args = args + v
-                else:
-                    where.append('%s=%s'%(k, a))
-                    args.append(v)
-            elif isinstance(propclass, String):
-                if not isinstance(v, type([])):
-                    v = [v]
-
-                # Quote the bits in the string that need it and then embed
-                # in a "substring" search. Note - need to quote the '%' so
-                # they make it through the python layer happily
-                v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
-
-                # now add to the where clause
-                where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
-                # note: args are embedded in the query string now
-            elif isinstance(propclass, Link):
-                if isinstance(v, type([])):
-                    if '-1' in v:
-                        v = v[:]
-                        v.remove('-1')
-                        xtra = ' or _%s is NULL'%k
-                    else:
-                        xtra = ''
-                    if v:
-                        s = ','.join([a for x in v])
-                        where.append('(_%s in (%s)%s)'%(k, s, xtra))
-                        args = args + v
-                    else:
-                        where.append('_%s is NULL'%k)
-                else:
-                    if v == '-1':
-                        v = None
-                        where.append('_%s is NULL'%k)
-                    else:
-                        where.append('_%s=%s'%(k, a))
-                        args.append(v)
-            elif isinstance(propclass, Date):
-                if isinstance(v, type([])):
-                    s = ','.join([a for x in v])
-                    where.append('_%s in (%s)'%(k, s))
-                    args = args + [date.Date(x).serialise() for x in v]
-                else:
-                    try:
-                        # Try to filter on range of dates
-                        date_rng = Range(v, date.Date, offset=timezone)
-                        if (date_rng.from_value):
-                            where.append('_%s >= %s'%(k, a))                            
-                            args.append(date_rng.from_value.serialise())
-                        if (date_rng.to_value):
-                            where.append('_%s <= %s'%(k, a))
-                            args.append(date_rng.to_value.serialise())
-                    except ValueError:
-                        # If range creation fails - ignore that search parameter
-                        pass                        
-            elif isinstance(propclass, Interval):
-                if isinstance(v, type([])):
-                    s = ','.join([a for x in v])
-                    where.append('_%s in (%s)'%(k, s))
-                    args = args + [date.Interval(x).serialise() for x in v]
-                else:
-                    try:
-                        # Try to filter on range of intervals
-                        date_rng = Range(v, date.Interval)
-                        if (date_rng.from_value):
-                            where.append('_%s >= %s'%(k, a))
-                            args.append(date_rng.from_value.serialise())
-                        if (date_rng.to_value):
-                            where.append('_%s <= %s'%(k, a))
-                            args.append(date_rng.to_value.serialise())
-                    except ValueError:
-                        # If range creation fails - ignore that search parameter
-                        pass                        
-                    #where.append('_%s=%s'%(k, a))
-                    #args.append(date.Interval(v).serialise())
-            else:
-                if isinstance(v, type([])):
-                    s = ','.join([a for x in v])
-                    where.append('_%s in (%s)'%(k, s))
-                    args = args + v
-                else:
-                    where.append('_%s=%s'%(k, a))
-                    args.append(v)
-
-        # don't match retired nodes
-        where.append('__retired__ <> 1')
-
-        # add results of full text search
-        if search_matches is not None:
-            v = search_matches.keys()
-            s = ','.join([a for x in v])
-            where.append('id in (%s)'%s)
-            args = args + v
-
-        # "grouping" is just the first-order sorting in the SQL fetch
-        # can modify it...)
-        orderby = []
-        ordercols = []
-        if group[0] is not None and group[1] is not None:
-            if group[0] != '-':
-                orderby.append('_'+group[1])
-                ordercols.append('_'+group[1])
-            else:
-                orderby.append('_'+group[1]+' desc')
-                ordercols.append('_'+group[1])
-
-        # now add in the sorting
-        group = ''
-        if sort[0] is not None and sort[1] is not None:
-            direction, colname = sort
-            if direction != '-':
-                if colname == 'id':
-                    orderby.append(colname)
-                else:
-                    orderby.append('_'+colname)
-                    ordercols.append('_'+colname)
-            else:
-                if colname == 'id':
-                    orderby.append(colname+' desc')
-                    ordercols.append(colname)
-                else:
-                    orderby.append('_'+colname+' desc')
-                    ordercols.append('_'+colname)
-
-        # construct the SQL
-        frum = ','.join(frum)
-        if where:
-            where = ' where ' + (' and '.join(where))
-        else:
-            where = ''
-        cols = ['id']
-        if orderby:
-            cols = cols + ordercols
-            order = ' order by %s'%(','.join(orderby))
-        else:
-            order = ''
-        cols = ','.join(cols)
-        sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
-        args = tuple(args)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'filter', (self, sql, args)
-        self.db.cursor.execute(sql, args)
-        l = self.db.cursor.fetchall()
-
-        # return the IDs (the first column)
-        # XXX numeric ids
-        return [str(row[0]) for row in l]
 
 class Class(MysqlClass, rdbms_common.Class):
     pass
@@ -635,4 +648,4 @@ class IssueClass(MysqlClass, rdbms_common.IssueClass):
 class FileClass(MysqlClass, rdbms_common.FileClass):
     pass
 
-#vim: set et
+# vim: set et sts=4 sw=4 :