Code

6aa3a633d0b1c288e46d2a6fdf60b213a0c0db78
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
9 '''This module defines a backend implementation for MySQL.
12 How to implement AUTO_INCREMENT:
14 mysql> create table foo (num integer auto_increment primary key, name
15 varchar(255)) AUTO_INCREMENT=1 type=InnoDB;
17 ql> insert into foo (name) values ('foo5');
18 Query OK, 1 row affected (0.00 sec)
20 mysql> SELECT num FROM foo WHERE num IS NULL;
21 +-----+
22 | num |
23 +-----+
24 |   4 |
25 +-----+
26 1 row in set (0.00 sec)
28 mysql> SELECT num FROM foo WHERE num IS NULL;
29 Empty set (0.00 sec)
31 NOTE: we don't need an index on the id column if it's PRIMARY KEY
33 '''
34 __docformat__ = 'restructuredtext'
36 from roundup.backends.rdbms_common import *
37 from roundup.backends import rdbms_common
38 import MySQLdb
39 import os, shutil
40 from MySQLdb.constants import ER
43 def db_nuke(config):
44     """Clear all database contents and drop database itself"""
45     if db_exists(config):
46         conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
47             config.MYSQL_DBPASSWORD)
48         try:
49             conn.select_db(config.MYSQL_DBNAME)
50         except:
51             # no, it doesn't exist
52             pass
53         else:
54             cursor = conn.cursor()
55             cursor.execute("SHOW TABLES")
56             tables = cursor.fetchall()
57             for table in tables:
58                 if __debug__:
59                     print >>hyperdb.DEBUG, 'DROP TABLE %s'%table[0]
60                 cursor.execute("DROP TABLE %s"%table[0])
61             if __debug__:
62                 print >>hyperdb.DEBUG, "DROP DATABASE %s"%config.MYSQL_DBNAME
63             cursor.execute("DROP DATABASE %s"%config.MYSQL_DBNAME)
64             conn.commit()
65         conn.close()
67     if os.path.exists(config.DATABASE):
68         shutil.rmtree(config.DATABASE)
70 def db_create(config):
71     """Create the database."""
72     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
73         config.MYSQL_DBPASSWORD)
74     cursor = conn.cursor()
75     if __debug__:
76         print >>hyperdb.DEBUG, "CREATE DATABASE %s"%config.MYSQL_DBNAME
77     cursor.execute("CREATE DATABASE %s"%config.MYSQL_DBNAME)
78     conn.commit()
79     conn.close()
81 def db_exists(config):
82     """Check if database already exists."""
83     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
84         config.MYSQL_DBPASSWORD)
85 #    tables = None
86     try:
87         try:
88             conn.select_db(config.MYSQL_DBNAME)
89 #            cursor = conn.cursor()
90 #            cursor.execute("SHOW TABLES")
91 #            tables = cursor.fetchall()
92 #            if __debug__:
93 #                print >>hyperdb.DEBUG, "tables %s"%(tables,)
94         except MySQLdb.OperationalError:
95             if __debug__:
96                 print >>hyperdb.DEBUG, "no database '%s'"%config.MYSQL_DBNAME
97             return 0
98     finally:
99         conn.close()
100     if __debug__:
101         print >>hyperdb.DEBUG, "database '%s' exists"%config.MYSQL_DBNAME
102     return 1
105 class Database(Database):
106     arg = '%s'
108     # Backend for MySQL to use.
109     # InnoDB is faster, but if you're running <4.0.16 then you'll need to
110     # use BDB to pass all unit tests.
111     mysql_backend = 'InnoDB'
112     #mysql_backend = 'BDB'
114     hyperdb_to_sql_datatypes = {
115         hyperdb.String : 'VARCHAR(255)',
116         hyperdb.Date   : 'DATETIME',
117         hyperdb.Link   : 'INTEGER',
118         hyperdb.Interval  : 'VARCHAR(255)',
119         hyperdb.Password  : 'VARCHAR(255)',
120         hyperdb.Boolean   : 'INTEGER',
121         hyperdb.Number    : 'REAL',
122     }
124     hyperdb_to_sql_value = {
125         hyperdb.String : str,
126         # no fractional seconds for MySQL
127         hyperdb.Date   : lambda x: x.formal(sep=' '),
128         hyperdb.Link   : int,
129         hyperdb.Interval  : lambda x: x.serialise(),
130         hyperdb.Password  : str,
131         hyperdb.Boolean   : int,
132         hyperdb.Number    : lambda x: x,
133     }
135     def sql_open_connection(self):
136         db = getattr(self.config, 'MYSQL_DATABASE')
137         try:
138             conn = MySQLdb.connect(*db)
139         except MySQLdb.OperationalError, message:
140             raise DatabaseError, message
141         cursor = conn.cursor()
142         cursor.execute("SET AUTOCOMMIT=0")
143         cursor.execute("BEGIN")
144         return (conn, cursor)
145     
146     def open_connection(self):
147         # make sure the database actually exists
148         if not db_exists(self.config):
149             db_create(self.config)
151         self.conn, self.cursor = self.sql_open_connection()
153         try:
154             self.load_dbschema()
155         except MySQLdb.OperationalError, message:
156             if message[0] != ER.NO_DB_ERROR:
157                 raise
158         except MySQLdb.ProgrammingError, message:
159             if message[0] != ER.NO_SUCH_TABLE:
160                 raise DatabaseError, message
161             self.init_dbschema()
162             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
163                 self.mysql_backend)
164             self.cursor.execute('''CREATE TABLE ids (name VARCHAR(255),
165                 num INTEGER) TYPE=%s'''%self.mysql_backend)
166             self.cursor.execute('create index ids_name_idx on ids(name)')
167             self.create_version_2_tables()
169     def create_version_2_tables(self):
170         # OTK store
171         self.cursor.execute('''CREATE TABLE otks (otk_key VARCHAR(255),
172             otk_value VARCHAR(255), otk_time FLOAT(20))
173             TYPE=%s'''%self.mysql_backend)
174         self.cursor.execute('CREATE INDEX otks_key_idx ON otks(otk_key)')
176         # Sessions store
177         self.cursor.execute('''CREATE TABLE sessions (
178             session_key VARCHAR(255), session_time FLOAT(20),
179             session_value VARCHAR(255)) TYPE=%s'''%self.mysql_backend)
180         self.cursor.execute('''CREATE INDEX sessions_key_idx ON
181             sessions(session_key)''')
183         # full-text indexing store
184         self.cursor.execute('''CREATE TABLE __textids (_class VARCHAR(255),
185             _itemid VARCHAR(255), _prop VARCHAR(255), _textid INT)
186             TYPE=%s'''%self.mysql_backend)
187         self.cursor.execute('''CREATE TABLE __words (_word VARCHAR(30),
188             _textid INT) TYPE=%s'''%self.mysql_backend)
189         self.cursor.execute('CREATE INDEX words_word_ids ON __words(_word)')
190         sql = 'insert into ids (name, num) values (%s,%s)'%(self.arg, self.arg)
191         self.cursor.execute(sql, ('__textids', 1))
193     def add_actor_column(self):
194         ''' While we're adding the actor column, we need to update the
195         tables to have the correct datatypes.'''
196         assert 0, 'FINISH ME!'
198         for spec in self.classes.values():
199             new_has = spec.properties.has_key
200             new_spec = spec.schema()
201             new_spec[1].sort()
202             old_spec[1].sort()
203             if not force and new_spec == old_spec:
204                 # no changes
205                 return 0
207             if __debug__:
208                 print >>hyperdb.DEBUG, 'update_class FIRING'
210             # detect multilinks that have been removed, and drop their table
211             old_has = {}
212             for name,prop in old_spec[1]:
213                 old_has[name] = 1
214                 if new_has(name) or not isinstance(prop, hyperdb.Multilink):
215                     continue
216                 # it's a multilink, and it's been removed - drop the old
217                 # table. First drop indexes.
218                 self.drop_multilink_table_indexes(spec.classname, ml)
219                 sql = 'drop table %s_%s'%(spec.classname, prop)
220                 if __debug__:
221                     print >>hyperdb.DEBUG, 'update_class', (self, sql)
222                 self.cursor.execute(sql)
223             old_has = old_has.has_key
225             # now figure how we populate the new table
226             if adding_actor:
227                 fetch = ['_activity', '_creation', '_creator']
228             else:
229                 fetch = ['_actor', '_activity', '_creation', '_creator']
230             properties = spec.getprops()
231             for propname,x in new_spec[1]:
232                 prop = properties[propname]
233                 if isinstance(prop, hyperdb.Multilink):
234                     if force or not old_has(propname):
235                         # we need to create the new table
236                         self.create_multilink_table(spec, propname)
237                 elif old_has(propname):
238                     # we copy this col over from the old table
239                     fetch.append('_'+propname)
241             # select the data out of the old table
242             fetch.append('id')
243             fetch.append('__retired__')
244             fetchcols = ','.join(fetch)
245             cn = spec.classname
246             sql = 'select %s from _%s'%(fetchcols, cn)
247             if __debug__:
248                 print >>hyperdb.DEBUG, 'update_class', (self, sql)
249             self.cursor.execute(sql)
250             olddata = self.cursor.fetchall()
252             # TODO: update all the other index dropping code
253             self.drop_class_table_indexes(cn, old_spec[0])
255             # drop the old table
256             self.cursor.execute('drop table _%s'%cn)
258             # create the new table
259             self.create_class_table(spec)
261             # do the insert of the old data - the new columns will have
262             # NULL values
263             args = ','.join([self.arg for x in fetch])
264             sql = 'insert into _%s (%s) values (%s)'%(cn, fetchcols, args)
265             if __debug__:
266                 print >>hyperdb.DEBUG, 'update_class', (self, sql, olddata[0])
267             for entry in olddata:
268                 self.cursor.execute(sql, tuple(entry))
270         return 1
272     def __repr__(self):
273         return '<myroundsql 0x%x>'%id(self)
275     def sql_fetchone(self):
276         return self.cursor.fetchone()
278     def sql_fetchall(self):
279         return self.cursor.fetchall()
281     def sql_index_exists(self, table_name, index_name):
282         self.cursor.execute('show index from %s'%table_name)
283         for index in self.cursor.fetchall():
284             if index[2] == index_name:
285                 return 1
286         return 0
288     def save_dbschema(self, schema):
289         s = repr(self.database_schema)
290         self.sql('INSERT INTO schema VALUES (%s)', (s,))
291     
292     def create_class_table(self, spec):
293         cols, mls = self.determine_columns(spec.properties.items())
295         # add on our special columns
296         cols.append(('id', 'INTEGER PRIMARY KEY'))
297         cols.append(('__retired__', 'INTEGER DEFAULT 0'))
299         # create the base table
300         scols = ','.join(['%s %s'%x for x in cols])
301         sql = 'create table _%s (%s) type=%s'%(spec.classname, scols,
302             self.mysql_backend)
303         if __debug__:
304             print >>hyperdb.DEBUG, 'create_class', (self, sql)
305         self.cursor.execute(sql)
307         self.create_class_table_indexes(spec)
308         return cols, mls
310     def drop_class_table_indexes(self, cn, key):
311         # drop the old table indexes first
312         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
313         if key:
314             l.append('_%s_%s_idx'%(cn, key))
316         table_name = '_%s'%cn
317         for index_name in l:
318             if not self.sql_index_exists(table_name, index_name):
319                 continue
320             index_sql = 'drop index %s on %s'%(index_name, table_name)
321             if __debug__:
322                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
323             self.cursor.execute(index_sql)
325     def create_journal_table(self, spec):
326         # journal table
327         cols = ','.join(['%s varchar'%x
328             for x in 'nodeid date tag action params'.split()])
329         sql = '''create table %s__journal (
330             nodeid integer, date timestamp, tag varchar(255),
331             action varchar(255), params varchar(255)) type=%s'''%(
332             spec.classname, self.mysql_backend)
333         if __debug__:
334             print >>hyperdb.DEBUG, 'create_journal_table', (self, sql)
335         self.cursor.execute(sql)
336         self.create_journal_table_indexes(spec)
338     def drop_journal_table_indexes(self, classname):
339         index_name = '%s_journ_idx'%classname
340         if not self.sql_index_exists('%s__journal'%classname, index_name):
341             return
342         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
343         if __debug__:
344             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
345         self.cursor.execute(index_sql)
347     def create_multilink_table(self, spec, ml):
348         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
349             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
350                 self.mysql_backend)
351         if __debug__:
352           print >>hyperdb.DEBUG, 'create_class', (self, sql)
353         self.cursor.execute(sql)
354         self.create_multilink_table_indexes(spec, ml)
356     def drop_multilink_table_indexes(self, classname, ml):
357         l = [
358             '%s_%s_l_idx'%(classname, ml),
359             '%s_%s_n_idx'%(classname, ml)
360         ]
361         table_name = '%s_%s'%(classname, ml)
362         for index_name in l:
363             if not self.sql_index_exists(table_name, index_name):
364                 continue
365             index_sql = 'drop index %s on %s'%(index_name, table_name)
366             if __debug__:
367                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
368             self.cursor.execute(index_sql)
370     def drop_class_table_key_index(self, cn, key):
371         table_name = '_%s'%cn
372         index_name = '_%s_%s_idx'%(cn, key)
373         if not self.sql_index_exists(table_name, index_name):
374             return
375         sql = 'drop index %s on %s'%(index_name, table_name)
376         if __debug__:
377             print >>hyperdb.DEBUG, 'drop_index', (self, sql)
378         self.cursor.execute(sql)
380     # old-skool id generation
381     def newid(self, classname):
382         ''' Generate a new id for the given class
383         '''
384         # get the next ID
385         sql = 'select num from ids where name=%s'%self.arg
386         if __debug__:
387             print >>hyperdb.DEBUG, 'newid', (self, sql, classname)
388         self.cursor.execute(sql, (classname, ))
389         newid = int(self.cursor.fetchone()[0])
391         # update the counter
392         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
393         vals = (int(newid)+1, classname)
394         if __debug__:
395             print >>hyperdb.DEBUG, 'newid', (self, sql, vals)
396         self.cursor.execute(sql, vals)
398         # return as string
399         return str(newid)
401     def setid(self, classname, setid):
402         ''' Set the id counter: used during import of database
404         We add one to make it behave like the seqeunces in postgres.
405         '''
406         sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
407         vals = (int(setid)+1, classname)
408         if __debug__:
409             print >>hyperdb.DEBUG, 'setid', (self, sql, vals)
410         self.cursor.execute(sql, vals)
412     def create_class(self, spec):
413         rdbms_common.Database.create_class(self, spec)
414         sql = 'insert into ids (name, num) values (%s, %s)'
415         vals = (spec.classname, 1)
416         if __debug__:
417             print >>hyperdb.DEBUG, 'create_class', (self, sql, vals)
418         self.cursor.execute(sql, vals)
420 class MysqlClass:
421     # we're overriding this method for ONE missing bit of functionality.
422     # look for "I can't believe it's not a toy RDBMS" below
423     def filter(self, search_matches, filterspec, sort=(None,None),
424             group=(None,None)):
425         '''Return a list of the ids of the active nodes in this class that
426         match the 'filter' spec, sorted by the group spec and then the
427         sort spec
429         "filterspec" is {propname: value(s)}
431         "sort" and "group" are (dir, prop) where dir is '+', '-' or None
432         and prop is a prop name or None
434         "search_matches" is {nodeid: marker}
436         The filter must match all properties specificed - but if the
437         property value to match is a list, any one of the values in the
438         list may match for that property to match.
439         '''
440         # just don't bother if the full-text search matched diddly
441         if search_matches == {}:
442             return []
444         cn = self.classname
446         timezone = self.db.getUserTimezone()
447         
448         # figure the WHERE clause from the filterspec
449         props = self.getprops()
450         frum = ['_'+cn]
451         where = []
452         args = []
453         a = self.db.arg
454         for k, v in filterspec.items():
455             propclass = props[k]
456             # now do other where clause stuff
457             if isinstance(propclass, Multilink):
458                 tn = '%s_%s'%(cn, k)
459                 if v in ('-1', ['-1']):
460                     # only match rows that have count(linkid)=0 in the
461                     # corresponding multilink table)
463                     # "I can't believe it's not a toy RDBMS"
464                     # see, even toy RDBMSes like gadfly and sqlite can do
465                     # sub-selects...
466                     self.db.sql('select nodeid from %s'%tn)
467                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
469                     where.append('id not in (%s)'%s)
470                 elif isinstance(v, type([])):
471                     frum.append(tn)
472                     s = ','.join([a for x in v])
473                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
474                     args = args + v
475                 else:
476                     frum.append(tn)
477                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
478                     args.append(v)
479             elif k == 'id':
480                 if isinstance(v, type([])):
481                     s = ','.join([a for x in v])
482                     where.append('%s in (%s)'%(k, s))
483                     args = args + v
484                 else:
485                     where.append('%s=%s'%(k, a))
486                     args.append(v)
487             elif isinstance(propclass, String):
488                 if not isinstance(v, type([])):
489                     v = [v]
491                 # Quote the bits in the string that need it and then embed
492                 # in a "substring" search. Note - need to quote the '%' so
493                 # they make it through the python layer happily
494                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
496                 # now add to the where clause
497                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
498                 # note: args are embedded in the query string now
499             elif isinstance(propclass, Link):
500                 if isinstance(v, type([])):
501                     if '-1' in v:
502                         v = v[:]
503                         v.remove('-1')
504                         xtra = ' or _%s is NULL'%k
505                     else:
506                         xtra = ''
507                     if v:
508                         s = ','.join([a for x in v])
509                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
510                         args = args + v
511                     else:
512                         where.append('_%s is NULL'%k)
513                 else:
514                     if v == '-1':
515                         v = None
516                         where.append('_%s is NULL'%k)
517                     else:
518                         where.append('_%s=%s'%(k, a))
519                         args.append(v)
520             elif isinstance(propclass, Date):
521                 if isinstance(v, type([])):
522                     s = ','.join([a for x in v])
523                     where.append('_%s in (%s)'%(k, s))
524                     args = args + [date.Date(x).serialise() for x in v]
525                 else:
526                     try:
527                         # Try to filter on range of dates
528                         date_rng = Range(v, date.Date, offset=timezone)
529                         if (date_rng.from_value):
530                             where.append('_%s >= %s'%(k, a))                            
531                             args.append(date_rng.from_value.serialise())
532                         if (date_rng.to_value):
533                             where.append('_%s <= %s'%(k, a))
534                             args.append(date_rng.to_value.serialise())
535                     except ValueError:
536                         # If range creation fails - ignore that search parameter
537                         pass                        
538             elif isinstance(propclass, Interval):
539                 if isinstance(v, type([])):
540                     s = ','.join([a for x in v])
541                     where.append('_%s in (%s)'%(k, s))
542                     args = args + [date.Interval(x).serialise() for x in v]
543                 else:
544                     try:
545                         # Try to filter on range of intervals
546                         date_rng = Range(v, date.Interval)
547                         if (date_rng.from_value):
548                             where.append('_%s >= %s'%(k, a))
549                             args.append(date_rng.from_value.serialise())
550                         if (date_rng.to_value):
551                             where.append('_%s <= %s'%(k, a))
552                             args.append(date_rng.to_value.serialise())
553                     except ValueError:
554                         # If range creation fails - ignore that search parameter
555                         pass                        
556                     #where.append('_%s=%s'%(k, a))
557                     #args.append(date.Interval(v).serialise())
558             else:
559                 if isinstance(v, type([])):
560                     s = ','.join([a for x in v])
561                     where.append('_%s in (%s)'%(k, s))
562                     args = args + v
563                 else:
564                     where.append('_%s=%s'%(k, a))
565                     args.append(v)
567         # don't match retired nodes
568         where.append('__retired__ <> 1')
570         # add results of full text search
571         if search_matches is not None:
572             v = search_matches.keys()
573             s = ','.join([a for x in v])
574             where.append('id in (%s)'%s)
575             args = args + v
577         # "grouping" is just the first-order sorting in the SQL fetch
578         # can modify it...)
579         orderby = []
580         ordercols = []
581         if group[0] is not None and group[1] is not None:
582             if group[0] != '-':
583                 orderby.append('_'+group[1])
584                 ordercols.append('_'+group[1])
585             else:
586                 orderby.append('_'+group[1]+' desc')
587                 ordercols.append('_'+group[1])
589         # now add in the sorting
590         group = ''
591         if sort[0] is not None and sort[1] is not None:
592             direction, colname = sort
593             if direction != '-':
594                 if colname == 'id':
595                     orderby.append(colname)
596                 else:
597                     orderby.append('_'+colname)
598                     ordercols.append('_'+colname)
599             else:
600                 if colname == 'id':
601                     orderby.append(colname+' desc')
602                     ordercols.append(colname)
603                 else:
604                     orderby.append('_'+colname+' desc')
605                     ordercols.append('_'+colname)
607         # construct the SQL
608         frum = ','.join(frum)
609         if where:
610             where = ' where ' + (' and '.join(where))
611         else:
612             where = ''
613         cols = ['id']
614         if orderby:
615             cols = cols + ordercols
616             order = ' order by %s'%(','.join(orderby))
617         else:
618             order = ''
619         cols = ','.join(cols)
620         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
621         args = tuple(args)
622         if __debug__:
623             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
624         self.db.cursor.execute(sql, args)
625         l = self.db.cursor.fetchall()
627         # return the IDs (the first column)
628         # XXX numeric ids
629         return [str(row[0]) for row in l]
631 class Class(MysqlClass, rdbms_common.Class):
632     pass
633 class IssueClass(MysqlClass, rdbms_common.IssueClass):
634     pass
635 class FileClass(MysqlClass, rdbms_common.FileClass):
636     pass
638 #vim: set et