Code

make mysql / postgresql work again. beginnings of otk/session store in rdbmses
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
9 '''This module defines a backend implementation for MySQL.'''
10 __docformat__ = 'restructuredtext'
12 from roundup.backends.rdbms_common import *
13 from roundup.backends import rdbms_common
14 import MySQLdb
15 import os, shutil
16 from MySQLdb.constants import ER
19 def db_nuke(config):
20     """Clear all database contents and drop database itself"""
21     if db_exists(config):
22         conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
23             config.MYSQL_DBPASSWORD)
24         try:
25             conn.select_db(config.MYSQL_DBNAME)
26         except:
27             # no, it doesn't exist
28             pass
29         else:
30             cursor = conn.cursor()
31             cursor.execute("SHOW TABLES")
32             tables = cursor.fetchall()
33             for table in tables:
34                 if __debug__:
35                     print >>hyperdb.DEBUG, 'DROP TABLE %s'%table[0]
36                 cursor.execute("DROP TABLE %s"%table[0])
37             if __debug__:
38                 print >>hyperdb.DEBUG, "DROP DATABASE %s"%config.MYSQL_DBNAME
39             cursor.execute("DROP DATABASE %s"%config.MYSQL_DBNAME)
40             conn.commit()
41         conn.close()
43     if os.path.exists(config.DATABASE):
44         shutil.rmtree(config.DATABASE)
46 def db_create(config):
47     """Create the database."""
48     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
49         config.MYSQL_DBPASSWORD)
50     cursor = conn.cursor()
51     if __debug__:
52         print >>hyperdb.DEBUG, "CREATE DATABASE %s"%config.MYSQL_DBNAME
53     cursor.execute("CREATE DATABASE %s"%config.MYSQL_DBNAME)
54     conn.commit()
55     conn.close()
57 def db_exists(config):
58     """Check if database already exists."""
59     conn = MySQLdb.connect(config.MYSQL_DBHOST, config.MYSQL_DBUSER,
60         config.MYSQL_DBPASSWORD)
61 #    tables = None
62     try:
63         try:
64             conn.select_db(config.MYSQL_DBNAME)
65 #            cursor = conn.cursor()
66 #            cursor.execute("SHOW TABLES")
67 #            tables = cursor.fetchall()
68 #            if __debug__:
69 #                print >>hyperdb.DEBUG, "tables %s"%(tables,)
70         except MySQLdb.OperationalError:
71             if __debug__:
72                 print >>hyperdb.DEBUG, "no database '%s'"%config.MYSQL_DBNAME
73             return 0
74     finally:
75         conn.close()
76     if __debug__:
77         print >>hyperdb.DEBUG, "database '%s' exists"%config.MYSQL_DBNAME
78     return 1
81 class Database(Database):
82     arg = '%s'
84     # Backend for MySQL to use.
85     # InnoDB is faster, but if you're running <4.0.16 then you'll need to
86     # use BDB to pass all unit tests.
87     mysql_backend = 'InnoDB'
88     #mysql_backend = 'BDB'
89     
90     def sql_open_connection(self):
91         # make sure the database actually exists
92         if not db_exists(self.config):
93             db_create(self.config)
95         db = getattr(self.config, 'MYSQL_DATABASE')
96         try:
97             self.conn = MySQLdb.connect(*db)
98         except MySQLdb.OperationalError, message:
99             raise DatabaseError, message
101         self.cursor = self.conn.cursor()
102         # start transaction
103         self.sql("SET AUTOCOMMIT=0")
104         self.sql("BEGIN")
105         try:
106             self.load_dbschema()
107         except MySQLdb.OperationalError, message:
108             if message[0] != ER.NO_DB_ERROR:
109                 raise
110         except MySQLdb.ProgrammingError, message:
111             if message[0] != ER.NO_SUCH_TABLE:
112                 raise DatabaseError, message
113             self.init_dbschema()
114             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
115                 self.mysql_backend)
116             # TODO: use AUTO_INCREMENT for generating ids:
117             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
118             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
119                 self.mysql_backend)
120             self.sql("CREATE INDEX ids_name_idx ON ids(name)")
121             self.create_version_2_tables()
123     def create_version_2_tables(self):
124         self.cursor.execute('CREATE TABLE otks (otk_key VARCHAR(255), '
125             'otk_value VARCHAR(255), otk_time FLOAT(20))')
126         self.cursor.execute('CREATE INDEX otks_key_idx ON otks(otk_key)')
127         self.cursor.execute('CREATE TABLE sessions (s_key VARCHAR(255), '
128             's_last_use FLOAT(20), s_user VARCHAR(255))')
129         self.cursor.execute('CREATE INDEX sessions_key_idx ON sessions(s_key)')
131     def __repr__(self):
132         return '<myroundsql 0x%x>'%id(self)
134     def sql_fetchone(self):
135         return self.cursor.fetchone()
137     def sql_fetchall(self):
138         return self.cursor.fetchall()
140     def sql_index_exists(self, table_name, index_name):
141         self.cursor.execute('show index from %s'%table_name)
142         for index in self.cursor.fetchall():
143             if index[2] == index_name:
144                 return 1
145         return 0
147     def save_dbschema(self, schema):
148         s = repr(self.database_schema)
149         self.sql('INSERT INTO schema VALUES (%s)', (s,))
150     
151     def save_journal(self, classname, cols, nodeid, journaldate,
152                 journaltag, action, params):
153         params = repr(params)
154         entry = (nodeid, journaldate, journaltag, action, params)
156         a = self.arg
157         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
158                 cols, a, a, a, a, a)
159         if __debug__:
160           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
161         self.cursor.execute(sql, entry)
163     def load_journal(self, classname, cols, nodeid):
164         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
165                 self.arg)
166         if __debug__:
167             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
168         self.cursor.execute(sql, (nodeid,))
169         res = []
170         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
171           params = eval(params)
172           res.append((nodeid, date.Date(date_stamp), user, action, params))
173         return res
175     def create_class_table(self, spec):
176         cols, mls = self.determine_columns(spec.properties.items())
177         cols.append('id')
178         cols.append('__retired__')
179         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
180         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
181             self.mysql_backend)
182         if __debug__:
183           print >>hyperdb.DEBUG, 'create_class', (self, sql)
184         self.cursor.execute(sql)
185         self.create_class_table_indexes(spec)
186         return cols, mls
188     def drop_class_table_indexes(self, cn, key):
189         # drop the old table indexes first
190         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
191         if key:
192             l.append('_%s_%s_idx'%(cn, key))
194         table_name = '_%s'%cn
195         for index_name in l:
196             if not self.sql_index_exists(table_name, index_name):
197                 continue
198             index_sql = 'drop index %s on %s'%(index_name, table_name)
199             if __debug__:
200                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
201             self.cursor.execute(index_sql)
203     def create_journal_table(self, spec):
204         cols = ',' . join(['`%s` VARCHAR(255)'%x
205           for x in 'nodeid date tag action params' . split()])
206         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
207             cols, self.mysql_backend)
208         if __debug__:
209             print >>hyperdb.DEBUG, 'create_class', (self, sql)
210         self.cursor.execute(sql)
211         self.create_journal_table_indexes(spec)
213     def drop_journal_table_indexes(self, classname):
214         index_name = '%s_journ_idx'%classname
215         if not self.sql_index_exists('%s__journal'%classname, index_name):
216             return
217         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
218         if __debug__:
219             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
220         self.cursor.execute(index_sql)
222     def create_multilink_table(self, spec, ml):
223         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
224             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
225                 self.mysql_backend)
226         if __debug__:
227           print >>hyperdb.DEBUG, 'create_class', (self, sql)
228         self.cursor.execute(sql)
229         self.create_multilink_table_indexes(spec, ml)
231     def drop_multilink_table_indexes(self, classname, ml):
232         l = [
233             '%s_%s_l_idx'%(classname, ml),
234             '%s_%s_n_idx'%(classname, ml)
235         ]
236         for index_name in l:
237             if not self.sql_index_exists(table_name, index_name):
238                 continue
239             index_sql = 'drop index %s on %s'%(index_name, table_name)
240             if __debug__:
241                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
242             self.cursor.execute(index_sql)
244 class MysqlClass:
245     # we're overriding this method for ONE missing bit of functionality.
246     # look for "I can't believe it's not a toy RDBMS" below
247     def filter(self, search_matches, filterspec, sort=(None,None),
248             group=(None,None)):
249         '''Return a list of the ids of the active nodes in this class that
250         match the 'filter' spec, sorted by the group spec and then the
251         sort spec
253         "filterspec" is {propname: value(s)}
255         "sort" and "group" are (dir, prop) where dir is '+', '-' or None
256         and prop is a prop name or None
258         "search_matches" is {nodeid: marker}
260         The filter must match all properties specificed - but if the
261         property value to match is a list, any one of the values in the
262         list may match for that property to match.
263         '''
264         # just don't bother if the full-text search matched diddly
265         if search_matches == {}:
266             return []
268         cn = self.classname
270         timezone = self.db.getUserTimezone()
271         
272         # figure the WHERE clause from the filterspec
273         props = self.getprops()
274         frum = ['_'+cn]
275         where = []
276         args = []
277         a = self.db.arg
278         for k, v in filterspec.items():
279             propclass = props[k]
280             # now do other where clause stuff
281             if isinstance(propclass, Multilink):
282                 tn = '%s_%s'%(cn, k)
283                 if v in ('-1', ['-1']):
284                     # only match rows that have count(linkid)=0 in the
285                     # corresponding multilink table)
287                     # "I can't believe it's not a toy RDBMS"
288                     # see, even toy RDBMSes like gadfly and sqlite can do
289                     # sub-selects...
290                     self.db.sql('select nodeid from %s'%tn)
291                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
293                     where.append('id not in (%s)'%s)
294                 elif isinstance(v, type([])):
295                     frum.append(tn)
296                     s = ','.join([a for x in v])
297                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
298                     args = args + v
299                 else:
300                     frum.append(tn)
301                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
302                     args.append(v)
303             elif k == 'id':
304                 if isinstance(v, type([])):
305                     s = ','.join([a for x in v])
306                     where.append('%s in (%s)'%(k, s))
307                     args = args + v
308                 else:
309                     where.append('%s=%s'%(k, a))
310                     args.append(v)
311             elif isinstance(propclass, String):
312                 if not isinstance(v, type([])):
313                     v = [v]
315                 # Quote the bits in the string that need it and then embed
316                 # in a "substring" search. Note - need to quote the '%' so
317                 # they make it through the python layer happily
318                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
320                 # now add to the where clause
321                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
322                 # note: args are embedded in the query string now
323             elif isinstance(propclass, Link):
324                 if isinstance(v, type([])):
325                     if '-1' in v:
326                         v = v[:]
327                         v.remove('-1')
328                         xtra = ' or _%s is NULL'%k
329                     else:
330                         xtra = ''
331                     if v:
332                         s = ','.join([a for x in v])
333                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
334                         args = args + v
335                     else:
336                         where.append('_%s is NULL'%k)
337                 else:
338                     if v == '-1':
339                         v = None
340                         where.append('_%s is NULL'%k)
341                     else:
342                         where.append('_%s=%s'%(k, a))
343                         args.append(v)
344             elif isinstance(propclass, Date):
345                 if isinstance(v, type([])):
346                     s = ','.join([a for x in v])
347                     where.append('_%s in (%s)'%(k, s))
348                     args = args + [date.Date(x).serialise() for x in v]
349                 else:
350                     try:
351                         # Try to filter on range of dates
352                         date_rng = Range(v, date.Date, offset=timezone)
353                         if (date_rng.from_value):
354                             where.append('_%s >= %s'%(k, a))                            
355                             args.append(date_rng.from_value.serialise())
356                         if (date_rng.to_value):
357                             where.append('_%s <= %s'%(k, a))
358                             args.append(date_rng.to_value.serialise())
359                     except ValueError:
360                         # If range creation fails - ignore that search parameter
361                         pass                        
362             elif isinstance(propclass, Interval):
363                 if isinstance(v, type([])):
364                     s = ','.join([a for x in v])
365                     where.append('_%s in (%s)'%(k, s))
366                     args = args + [date.Interval(x).serialise() for x in v]
367                 else:
368                     try:
369                         # Try to filter on range of intervals
370                         date_rng = Range(v, date.Interval)
371                         if (date_rng.from_value):
372                             where.append('_%s >= %s'%(k, a))
373                             args.append(date_rng.from_value.serialise())
374                         if (date_rng.to_value):
375                             where.append('_%s <= %s'%(k, a))
376                             args.append(date_rng.to_value.serialise())
377                     except ValueError:
378                         # If range creation fails - ignore that search parameter
379                         pass                        
380                     #where.append('_%s=%s'%(k, a))
381                     #args.append(date.Interval(v).serialise())
382             else:
383                 if isinstance(v, type([])):
384                     s = ','.join([a for x in v])
385                     where.append('_%s in (%s)'%(k, s))
386                     args = args + v
387                 else:
388                     where.append('_%s=%s'%(k, a))
389                     args.append(v)
391         # don't match retired nodes
392         where.append('__retired__ <> 1')
394         # add results of full text search
395         if search_matches is not None:
396             v = search_matches.keys()
397             s = ','.join([a for x in v])
398             where.append('id in (%s)'%s)
399             args = args + v
401         # "grouping" is just the first-order sorting in the SQL fetch
402         # can modify it...)
403         orderby = []
404         ordercols = []
405         if group[0] is not None and group[1] is not None:
406             if group[0] != '-':
407                 orderby.append('_'+group[1])
408                 ordercols.append('_'+group[1])
409             else:
410                 orderby.append('_'+group[1]+' desc')
411                 ordercols.append('_'+group[1])
413         # now add in the sorting
414         group = ''
415         if sort[0] is not None and sort[1] is not None:
416             direction, colname = sort
417             if direction != '-':
418                 if colname == 'id':
419                     orderby.append(colname)
420                 else:
421                     orderby.append('_'+colname)
422                     ordercols.append('_'+colname)
423             else:
424                 if colname == 'id':
425                     orderby.append(colname+' desc')
426                     ordercols.append(colname)
427                 else:
428                     orderby.append('_'+colname+' desc')
429                     ordercols.append('_'+colname)
431         # construct the SQL
432         frum = ','.join(frum)
433         if where:
434             where = ' where ' + (' and '.join(where))
435         else:
436             where = ''
437         cols = ['id']
438         if orderby:
439             cols = cols + ordercols
440             order = ' order by %s'%(','.join(orderby))
441         else:
442             order = ''
443         cols = ','.join(cols)
444         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
445         args = tuple(args)
446         if __debug__:
447             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
448         self.db.cursor.execute(sql, args)
449         l = self.db.cursor.fetchall()
451         # return the IDs (the first column)
452         return [row[0] for row in l]
454 class Class(MysqlClass, rdbms_common.Class):
455     pass
456 class IssueClass(MysqlClass, rdbms_common.IssueClass):
457     pass
458 class FileClass(MysqlClass, rdbms_common.FileClass):
459     pass
461 #vim: set et