Code

Implemented proper datatypes in mysql and postgresql backends (well,
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
9 '''This module defines a backend implementation for MySQL.
12 How to implement AUTO_INCREMENT:
14 mysql> create table foo (num integer auto_increment primary key, name
15 varchar(255)) AUTO_INCREMENT=1 type=InnoDB;
17 ql> insert into foo (name) values ('foo5');
18 Query OK, 1 row affected (0.00 sec)
20 mysql> SELECT num FROM foo WHERE num IS NULL;
21 +-----+
22 | num |
23 +-----+
24 |   4 |
25 +-----+
26 1 row in set (0.00 sec)
28 mysql> SELECT num FROM foo WHERE num IS NULL;
29 Empty set (0.00 sec)
31 NOTE: we don't need an index on the id column if it's PRIMARY KEY
33 '''
34 __docformat__ = 'restructuredtext'
36 from roundup.backends.rdbms_common import *
37 from roundup.backends import rdbms_common
38 import MySQLdb
39 import os, shutil
40 from MySQLdb.constants import ER
43 def db_nuke(config):
44     """Clear all database contents and drop database itself"""
45     if db_exists(config):
46         conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
47             config.MYSQL_DBPASSWORD)
48         try:
49             conn.select_db(config.MYSQL_DBNAME)
50         except:
51             # no, it doesn't exist
52             pass
53         else:
54             cursor = conn.cursor()
55             cursor.execute("SHOW TABLES")
56             tables = cursor.fetchall()
57             for table in tables:
58                 if __debug__:
59                     print >>hyperdb.DEBUG, 'DROP TABLE %s'%table[0]
60                 cursor.execute("DROP TABLE %s"%table[0])
61             if __debug__:
62                 print >>hyperdb.DEBUG, "DROP DATABASE %s"%config.MYSQL_DBNAME
63             cursor.execute("DROP DATABASE %s"%config.MYSQL_DBNAME)
64             conn.commit()
65         conn.close()
67     if os.path.exists(config.DATABASE):
68         shutil.rmtree(config.DATABASE)
70 def db_create(config):
71     """Create the database."""
72     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
73         config.MYSQL_DBPASSWORD)
74     cursor = conn.cursor()
75     if __debug__:
76         print >>hyperdb.DEBUG, "CREATE DATABASE %s"%config.MYSQL_DBNAME
77     cursor.execute("CREATE DATABASE %s"%config.MYSQL_DBNAME)
78     conn.commit()
79     conn.close()
81 def db_exists(config):
82     """Check if database already exists."""
83     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
84         config.MYSQL_DBPASSWORD)
85 #    tables = None
86     try:
87         try:
88             conn.select_db(config.MYSQL_DBNAME)
89 #            cursor = conn.cursor()
90 #            cursor.execute("SHOW TABLES")
91 #            tables = cursor.fetchall()
92 #            if __debug__:
93 #                print >>hyperdb.DEBUG, "tables %s"%(tables,)
94         except MySQLdb.OperationalError:
95             if __debug__:
96                 print >>hyperdb.DEBUG, "no database '%s'"%config.MYSQL_DBNAME
97             return 0
98     finally:
99         conn.close()
100     if __debug__:
101         print >>hyperdb.DEBUG, "database '%s' exists"%config.MYSQL_DBNAME
102     return 1
105 class Database(Database):
106     arg = '%s'
108     # Backend for MySQL to use.
109     # InnoDB is faster, but if you're running <4.0.16 then you'll need to
110     # use BDB to pass all unit tests.
111     mysql_backend = 'InnoDB'
112     #mysql_backend = 'BDB'
114     hyperdb_to_sql_value = {
115         hyperdb.String : str,
116         # no fractional seconds for MySQL
117         hyperdb.Date   : lambda x: x.formal(sep=' '),
118         hyperdb.Link   : int,
119         hyperdb.Interval  : lambda x: x.serialise(),
120         hyperdb.Password  : str,
121         hyperdb.Boolean   : int,
122         hyperdb.Number    : lambda x: x,
123     }
125     def sql_open_connection(self):
126         db = getattr(self.config, 'MYSQL_DATABASE')
127         try:
128             conn = MySQLdb.connect(*db)
129         except MySQLdb.OperationalError, message:
130             raise DatabaseError, message
131         cursor = conn.cursor()
132         cursor.execute("SET AUTOCOMMIT=0")
133         cursor.execute("BEGIN")
134         return (conn, cursor)
135     
136     def open_connection(self):
137         # make sure the database actually exists
138         if not db_exists(self.config):
139             db_create(self.config)
141         self.conn, self.cursor = self.sql_open_connection()
143         try:
144             self.load_dbschema()
145         except MySQLdb.OperationalError, message:
146             if message[0] != ER.NO_DB_ERROR:
147                 raise
148         except MySQLdb.ProgrammingError, message:
149             if message[0] != ER.NO_SUCH_TABLE:
150                 raise DatabaseError, message
151             self.init_dbschema()
152             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
153                 self.mysql_backend)
154             self.cursor.execute('''CREATE TABLE ids (name VARCHAR(255),
155                 num INTEGER) TYPE=%s'''%self.mysql_backend)
156             self.cursor.execute('create index ids_name_idx on ids(name)')
157             self.create_version_2_tables()
159     def create_version_2_tables(self):
160         # OTK store
161         self.cursor.execute('''CREATE TABLE otks (otk_key VARCHAR(255),
162             otk_value VARCHAR(255), otk_time FLOAT(20))
163             TYPE=%s'''%self.mysql_backend)
164         self.cursor.execute('CREATE INDEX otks_key_idx ON otks(otk_key)')
166         # Sessions store
167         self.cursor.execute('''CREATE TABLE sessions (
168             session_key VARCHAR(255), session_time FLOAT(20),
169             session_value VARCHAR(255)) TYPE=%s'''%self.mysql_backend)
170         self.cursor.execute('''CREATE INDEX sessions_key_idx ON
171             sessions(session_key)''')
173         # full-text indexing store
174         self.cursor.execute('''CREATE TABLE __textids (_class VARCHAR(255),
175             _itemid VARCHAR(255), _prop VARCHAR(255), _textid INT)
176             TYPE=%s'''%self.mysql_backend)
177         self.cursor.execute('''CREATE TABLE __words (_word VARCHAR(30),
178             _textid INT) TYPE=%s'''%self.mysql_backend)
179         self.cursor.execute('CREATE INDEX words_word_ids ON __words(_word)')
180         sql = 'insert into ids (name, num) values (%s,%s)'%(self.arg, self.arg)
181         self.cursor.execute(sql, ('__textids', 1))
183     def add_actor_column(self):
184         ''' While we're adding the actor column, we need to update the
185         tables to have the correct datatypes.'''
186         assert 0, 'FINISH ME!'
188         for spec in self.classes.values():
189             new_has = spec.properties.has_key
190             new_spec = spec.schema()
191             new_spec[1].sort()
192             old_spec[1].sort()
193             if not force and new_spec == old_spec:
194                 # no changes
195                 return 0
197             if __debug__:
198                 print >>hyperdb.DEBUG, 'update_class FIRING'
200             # detect multilinks that have been removed, and drop their table
201             old_has = {}
202             for name,prop in old_spec[1]:
203                 old_has[name] = 1
204                 if new_has(name) or not isinstance(prop, hyperdb.Multilink):
205                     continue
206                 # it's a multilink, and it's been removed - drop the old
207                 # table. First drop indexes.
208                 self.drop_multilink_table_indexes(spec.classname, ml)
209                 sql = 'drop table %s_%s'%(spec.classname, prop)
210                 if __debug__:
211                     print >>hyperdb.DEBUG, 'update_class', (self, sql)
212                 self.cursor.execute(sql)
213             old_has = old_has.has_key
215             # now figure how we populate the new table
216             if adding_actor:
217                 fetch = ['_activity', '_creation', '_creator']
218             else:
219                 fetch = ['_actor', '_activity', '_creation', '_creator']
220             properties = spec.getprops()
221             for propname,x in new_spec[1]:
222                 prop = properties[propname]
223                 if isinstance(prop, hyperdb.Multilink):
224                     if force or not old_has(propname):
225                         # we need to create the new table
226                         self.create_multilink_table(spec, propname)
227                 elif old_has(propname):
228                     # we copy this col over from the old table
229                     fetch.append('_'+propname)
231             # select the data out of the old table
232             fetch.append('id')
233             fetch.append('__retired__')
234             fetchcols = ','.join(fetch)
235             cn = spec.classname
236             sql = 'select %s from _%s'%(fetchcols, cn)
237             if __debug__:
238                 print >>hyperdb.DEBUG, 'update_class', (self, sql)
239             self.cursor.execute(sql)
240             olddata = self.cursor.fetchall()
242             # TODO: update all the other index dropping code
243             self.drop_class_table_indexes(cn, old_spec[0])
245             # drop the old table
246             self.cursor.execute('drop table _%s'%cn)
248             # create the new table
249             self.create_class_table(spec)
251             # do the insert of the old data - the new columns will have
252             # NULL values
253             args = ','.join([self.arg for x in fetch])
254             sql = 'insert into _%s (%s) values (%s)'%(cn, fetchcols, args)
255             if __debug__:
256                 print >>hyperdb.DEBUG, 'update_class', (self, sql, olddata[0])
257             for entry in olddata:
258                 self.cursor.execute(sql, tuple(entry))
260         return 1
262     def __repr__(self):
263         return '<myroundsql 0x%x>'%id(self)
265     def sql_fetchone(self):
266         return self.cursor.fetchone()
268     def sql_fetchall(self):
269         return self.cursor.fetchall()
271     def sql_index_exists(self, table_name, index_name):
272         self.cursor.execute('show index from %s'%table_name)
273         for index in self.cursor.fetchall():
274             if index[2] == index_name:
275                 return 1
276         return 0
278     def save_dbschema(self, schema):
279         s = repr(self.database_schema)
280         self.sql('INSERT INTO schema VALUES (%s)', (s,))
281     
282     def create_class_table(self, spec):
283         cols, mls = self.determine_columns(spec.properties.items())
285         # add on our special columns
286         cols.append(('id', 'INTEGER PRIMARY KEY'))
287         cols.append(('__retired__', 'INTEGER DEFAULT 0'))
289         # create the base table
290         scols = ','.join(['%s %s'%x for x in cols])
291         sql = 'create table _%s (%s) type=%s'%(spec.classname, scols,
292             self.mysql_backend)
293         if __debug__:
294             print >>hyperdb.DEBUG, 'create_class', (self, sql)
295         self.cursor.execute(sql)
297         self.create_class_table_indexes(spec)
298         return cols, mls
300     def drop_class_table_indexes(self, cn, key):
301         # drop the old table indexes first
302         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
303         if key:
304             l.append('_%s_%s_idx'%(cn, key))
306         table_name = '_%s'%cn
307         for index_name in l:
308             if not self.sql_index_exists(table_name, index_name):
309                 continue
310             index_sql = 'drop index %s on %s'%(index_name, table_name)
311             if __debug__:
312                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
313             self.cursor.execute(index_sql)
315     def create_journal_table(self, spec):
316         # journal table
317         cols = ','.join(['%s varchar'%x
318             for x in 'nodeid date tag action params'.split()])
319         sql = '''create table %s__journal (
320             nodeid integer, date timestamp, tag varchar(255),
321             action varchar(255), params varchar(255)) type=%s'''%(
322             spec.classname, self.mysql_backend)
323         if __debug__:
324             print >>hyperdb.DEBUG, 'create_journal_table', (self, sql)
325         self.cursor.execute(sql)
326         self.create_journal_table_indexes(spec)
328     def drop_journal_table_indexes(self, classname):
329         index_name = '%s_journ_idx'%classname
330         if not self.sql_index_exists('%s__journal'%classname, index_name):
331             return
332         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
333         if __debug__:
334             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
335         self.cursor.execute(index_sql)
337     def create_multilink_table(self, spec, ml):
338         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
339             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
340                 self.mysql_backend)
341         if __debug__:
342           print >>hyperdb.DEBUG, 'create_class', (self, sql)
343         self.cursor.execute(sql)
344         self.create_multilink_table_indexes(spec, ml)
346     def drop_multilink_table_indexes(self, classname, ml):
347         l = [
348             '%s_%s_l_idx'%(classname, ml),
349             '%s_%s_n_idx'%(classname, ml)
350         ]
351         table_name = '%s_%s'%(classname, ml)
352         for index_name in l:
353             if not self.sql_index_exists(table_name, index_name):
354                 continue
355             index_sql = 'drop index %s on %s'%(index_name, table_name)
356             if __debug__:
357                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
358             self.cursor.execute(index_sql)
360     def drop_class_table_key_index(self, cn, key):
361         table_name = '_%s'%cn
362         index_name = '_%s_%s_idx'%(cn, key)
363         if not self.sql_index_exists(table_name, index_name):
364             return
365         sql = 'drop index %s on %s'%(index_name, table_name)
366         if __debug__:
367             print >>hyperdb.DEBUG, 'drop_index', (self, sql)
368         self.cursor.execute(sql)
370     # old-skool id generation
371     def newid(self, classname):
372         ''' Generate a new id for the given class
373         '''
374         # get the next ID
375         sql = 'select num from ids where name=%s'%self.arg
376         if __debug__:
377             print >>hyperdb.DEBUG, 'newid', (self, sql, classname)
378         self.cursor.execute(sql, (classname, ))
379         newid = int(self.cursor.fetchone()[0])
381         # update the counter
382         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
383         vals = (int(newid)+1, classname)
384         if __debug__:
385             print >>hyperdb.DEBUG, 'newid', (self, sql, vals)
386         self.cursor.execute(sql, vals)
388         # return as string
389         return str(newid)
391     def setid(self, classname, setid):
392         ''' Set the id counter: used during import of database
394         We add one to make it behave like the seqeunces in postgres.
395         '''
396         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
397         vals = (int(setid)+1, classname)
398         if __debug__:
399             print >>hyperdb.DEBUG, 'setid', (self, sql, vals)
400         self.cursor.execute(sql, vals)
402     def create_class(self, spec):
403         rdbms_common.Database.create_class(self, spec)
404         sql = 'insert into ids (name, num) values (%s, %s)'
405         vals = (spec.classname, 1)
406         if __debug__:
407             print >>hyperdb.DEBUG, 'create_class', (self, sql, vals)
408         self.cursor.execute(sql, vals)
410 class MysqlClass:
411     # we're overriding this method for ONE missing bit of functionality.
412     # look for "I can't believe it's not a toy RDBMS" below
413     def filter(self, search_matches, filterspec, sort=(None,None),
414             group=(None,None)):
415         '''Return a list of the ids of the active nodes in this class that
416         match the 'filter' spec, sorted by the group spec and then the
417         sort spec
419         "filterspec" is {propname: value(s)}
421         "sort" and "group" are (dir, prop) where dir is '+', '-' or None
422         and prop is a prop name or None
424         "search_matches" is {nodeid: marker}
426         The filter must match all properties specificed - but if the
427         property value to match is a list, any one of the values in the
428         list may match for that property to match.
429         '''
430         # just don't bother if the full-text search matched diddly
431         if search_matches == {}:
432             return []
434         cn = self.classname
436         timezone = self.db.getUserTimezone()
437         
438         # figure the WHERE clause from the filterspec
439         props = self.getprops()
440         frum = ['_'+cn]
441         where = []
442         args = []
443         a = self.db.arg
444         for k, v in filterspec.items():
445             propclass = props[k]
446             # now do other where clause stuff
447             if isinstance(propclass, Multilink):
448                 tn = '%s_%s'%(cn, k)
449                 if v in ('-1', ['-1']):
450                     # only match rows that have count(linkid)=0 in the
451                     # corresponding multilink table)
453                     # "I can't believe it's not a toy RDBMS"
454                     # see, even toy RDBMSes like gadfly and sqlite can do
455                     # sub-selects...
456                     self.db.sql('select nodeid from %s'%tn)
457                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
459                     where.append('id not in (%s)'%s)
460                 elif isinstance(v, type([])):
461                     frum.append(tn)
462                     s = ','.join([a for x in v])
463                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
464                     args = args + v
465                 else:
466                     frum.append(tn)
467                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
468                     args.append(v)
469             elif k == 'id':
470                 if isinstance(v, type([])):
471                     s = ','.join([a for x in v])
472                     where.append('%s in (%s)'%(k, s))
473                     args = args + v
474                 else:
475                     where.append('%s=%s'%(k, a))
476                     args.append(v)
477             elif isinstance(propclass, String):
478                 if not isinstance(v, type([])):
479                     v = [v]
481                 # Quote the bits in the string that need it and then embed
482                 # in a "substring" search. Note - need to quote the '%' so
483                 # they make it through the python layer happily
484                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
486                 # now add to the where clause
487                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
488                 # note: args are embedded in the query string now
489             elif isinstance(propclass, Link):
490                 if isinstance(v, type([])):
491                     if '-1' in v:
492                         v = v[:]
493                         v.remove('-1')
494                         xtra = ' or _%s is NULL'%k
495                     else:
496                         xtra = ''
497                     if v:
498                         s = ','.join([a for x in v])
499                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
500                         args = args + v
501                     else:
502                         where.append('_%s is NULL'%k)
503                 else:
504                     if v == '-1':
505                         v = None
506                         where.append('_%s is NULL'%k)
507                     else:
508                         where.append('_%s=%s'%(k, a))
509                         args.append(v)
510             elif isinstance(propclass, Date):
511                 if isinstance(v, type([])):
512                     s = ','.join([a for x in v])
513                     where.append('_%s in (%s)'%(k, s))
514                     args = args + [date.Date(x).serialise() for x in v]
515                 else:
516                     try:
517                         # Try to filter on range of dates
518                         date_rng = Range(v, date.Date, offset=timezone)
519                         if (date_rng.from_value):
520                             where.append('_%s >= %s'%(k, a))                            
521                             args.append(date_rng.from_value.serialise())
522                         if (date_rng.to_value):
523                             where.append('_%s <= %s'%(k, a))
524                             args.append(date_rng.to_value.serialise())
525                     except ValueError:
526                         # If range creation fails - ignore that search parameter
527                         pass                        
528             elif isinstance(propclass, Interval):
529                 if isinstance(v, type([])):
530                     s = ','.join([a for x in v])
531                     where.append('_%s in (%s)'%(k, s))
532                     args = args + [date.Interval(x).serialise() for x in v]
533                 else:
534                     try:
535                         # Try to filter on range of intervals
536                         date_rng = Range(v, date.Interval)
537                         if (date_rng.from_value):
538                             where.append('_%s >= %s'%(k, a))
539                             args.append(date_rng.from_value.serialise())
540                         if (date_rng.to_value):
541                             where.append('_%s <= %s'%(k, a))
542                             args.append(date_rng.to_value.serialise())
543                     except ValueError:
544                         # If range creation fails - ignore that search parameter
545                         pass                        
546                     #where.append('_%s=%s'%(k, a))
547                     #args.append(date.Interval(v).serialise())
548             else:
549                 if isinstance(v, type([])):
550                     s = ','.join([a for x in v])
551                     where.append('_%s in (%s)'%(k, s))
552                     args = args + v
553                 else:
554                     where.append('_%s=%s'%(k, a))
555                     args.append(v)
557         # don't match retired nodes
558         where.append('__retired__ <> 1')
560         # add results of full text search
561         if search_matches is not None:
562             v = search_matches.keys()
563             s = ','.join([a for x in v])
564             where.append('id in (%s)'%s)
565             args = args + v
567         # "grouping" is just the first-order sorting in the SQL fetch
568         # can modify it...)
569         orderby = []
570         ordercols = []
571         if group[0] is not None and group[1] is not None:
572             if group[0] != '-':
573                 orderby.append('_'+group[1])
574                 ordercols.append('_'+group[1])
575             else:
576                 orderby.append('_'+group[1]+' desc')
577                 ordercols.append('_'+group[1])
579         # now add in the sorting
580         group = ''
581         if sort[0] is not None and sort[1] is not None:
582             direction, colname = sort
583             if direction != '-':
584                 if colname == 'id':
585                     orderby.append(colname)
586                 else:
587                     orderby.append('_'+colname)
588                     ordercols.append('_'+colname)
589             else:
590                 if colname == 'id':
591                     orderby.append(colname+' desc')
592                     ordercols.append(colname)
593                 else:
594                     orderby.append('_'+colname+' desc')
595                     ordercols.append('_'+colname)
597         # construct the SQL
598         frum = ','.join(frum)
599         if where:
600             where = ' where ' + (' and '.join(where))
601         else:
602             where = ''
603         cols = ['id']
604         if orderby:
605             cols = cols + ordercols
606             order = ' order by %s'%(','.join(orderby))
607         else:
608             order = ''
609         cols = ','.join(cols)
610         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
611         args = tuple(args)
612         if __debug__:
613             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
614         self.db.cursor.execute(sql, args)
615         l = self.db.cursor.fetchall()
617         # return the IDs (the first column)
618         # XXX numeric ids
619         return [str(row[0]) for row in l]
621 class Class(MysqlClass, rdbms_common.Class):
622     pass
623 class IssueClass(MysqlClass, rdbms_common.IssueClass):
624     pass
625 class FileClass(MysqlClass, rdbms_common.FileClass):
626     pass
628 #vim: set et