Code

1bf33f6a726eba1d1a4752016b775f31cbe9c14a
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # Mysql backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import MySQLdb
14 import os, shutil
15 from MySQLdb.constants import ER
17 # Database maintenance functions
18 def db_nuke(config):
19     """Clear all database contents and drop database itself"""
20     db = Database(config, 'admin')
21     try:
22         db.sql_commit()
23         db.sql("DROP DATABASE %s" % config.MYSQL_DBNAME)
24         db.sql("CREATE DATABASE %s" % config.MYSQL_DBNAME)
25     finally:
26         db.close()
27     if os.path.exists(config.DATABASE):
28         shutil.rmtree(config.DATABASE)
30 def db_exists(config):
31     """Check if database already exists"""
32     # Yes, this is a hack, but we must must open connection without
33     # selecting a database to prevent creation of some tables
34     config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
35         config.MYSQL_DBPASSWORD)        
36     db = Database(config, 'admin')
37     try:
38         db.conn.select_db(config.MYSQL_DBNAME)
39         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
40             config.MYSQL_DBPASSWORD, config.MYSQL_DBNAME)
41         db.sql("SHOW TABLES")
42         tables = db.sql_fetchall()
43     finally:
44         db.close()
45     if tables or os.path.exists(config.DATABASE):
46         return 1
47     return 0        
49 class Database(Database):
50     arg = '%s'
52     # Backend for MySQL to use.
53     # InnoDB is faster, but has a bug in its rollback machinery that causes
54     # some selects in subsequent transactions to fail. BDB does not have
55     # this bug, but is apparently much slower.
56     #mysql_backend = 'InnoDB'
57     mysql_backend = 'BDB'
58     
59     def sql_open_connection(self):
60         db = getattr(self.config, 'MYSQL_DATABASE')
61         try:
62             self.conn = MySQLdb.connect(*db)
63         except MySQLdb.OperationalError, message:
64             raise DatabaseError, message
66         self.cursor = self.conn.cursor()
67         # start transaction
68         self.sql("SET AUTOCOMMIT=0")
69         self.sql("BEGIN")
70         try:
71             self.database_schema = self.load_dbschema()
72         except MySQLdb.OperationalError, message:
73             if message[0] != ER.NO_DB_ERROR:
74                 raise
75         except MySQLdb.ProgrammingError, message:
76             if message[0] != ER.NO_SUCH_TABLE:
77                 raise DatabaseError, message
78             self.database_schema = {}
79             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
80                 self.mysql_backend)
81             # TODO: use AUTO_INCREMENT for generating ids:
82             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
83             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
84                 self.mysql_backend)
85             self.sql("CREATE INDEX ids_name_idx on ids(name)")
87     def __repr__(self):
88         return '<myroundsql 0x%x>'%id(self)
90     def sql_fetchone(self):
91         return self.cursor.fetchone()
93     def sql_fetchall(self):
94         return self.cursor.fetchall()
96     def sql_index_exists(self, table_name, index_name):
97         self.cursor.execute('show index from %s'%table_name)
98         for index in self.cursor.fetchall():
99             if index[2] == index_name:
100                 return 1
101         return 0
103     def save_dbschema(self, schema):
104         s = repr(self.database_schema)
105         self.sql('INSERT INTO schema VALUES (%s)', (s,))
106     
107     def load_dbschema(self):
108         self.cursor.execute('SELECT schema FROM schema')
109         schema = self.cursor.fetchone()
110         if schema:
111             return eval(schema[0])
112         return None
114     def save_journal(self, classname, cols, nodeid, journaldate,
115                 journaltag, action, params):
116         params = repr(params)
117         entry = (nodeid, journaldate, journaltag, action, params)
119         a = self.arg
120         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
121                 cols, a, a, a, a, a)
122         if __debug__:
123           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
124         self.cursor.execute(sql, entry)
126     def load_journal(self, classname, cols, nodeid):
127         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
128                 self.arg)
129         if __debug__:
130             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
131         self.cursor.execute(sql, (nodeid,))
132         res = []
133         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
134           params = eval(params)
135           res.append((nodeid, date.Date(date_stamp), user, action, params))
136         return res
138     def create_class_table(self, spec):
139         cols, mls = self.determine_columns(spec.properties.items())
140         cols.append('id')
141         cols.append('__retired__')
142         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
143         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
144             self.mysql_backend)
145         if __debug__:
146           print >>hyperdb.DEBUG, 'create_class', (self, sql)
147         self.cursor.execute(sql)
148         self.create_class_table_indexes(spec)
149         return cols, mls
151     def drop_class_table_indexes(self, cn, key):
152         # drop the old table indexes first
153         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
154         if key:
155             l.append('_%s_%s_idx'%(cn, key))
157         table_name = '_%s'%cn
158         for index_name in l:
159             if not self.sql_index_exists(table_name, index_name):
160                 continue
161             index_sql = 'drop index %s on %s'%(index_name, table_name)
162             if __debug__:
163                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
164             self.cursor.execute(index_sql)
166     def create_journal_table(self, spec):
167         cols = ',' . join(['`%s` VARCHAR(255)'%x
168           for x in 'nodeid date tag action params' . split()])
169         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
170             cols, self.mysql_backend)
171         if __debug__:
172             print >>hyperdb.DEBUG, 'create_class', (self, sql)
173         self.cursor.execute(sql)
174         self.create_journal_table_indexes(spec)
176     def drop_journal_table_indexes(self, classname):
177         index_name = '%s_journ_idx'%classname
178         if not self.sql_index_exists('%s__journal'%classname, index_name):
179             return
180         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
181         if __debug__:
182             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
183         self.cursor.execute(index_sql)
185     def create_multilink_table(self, spec, ml):
186         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
187             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
188                 self.mysql_backend)
189         if __debug__:
190           print >>hyperdb.DEBUG, 'create_class', (self, sql)
191         self.cursor.execute(sql)
192         self.create_multilink_table_indexes(spec, ml)
194     def drop_multilink_table_indexes(self, classname, ml):
195         l = [
196             '%s_%s_l_idx'%(classname, ml),
197             '%s_%s_n_idx'%(classname, ml)
198         ]
199         for index_name in l:
200             if not self.sql_index_exists(table_name, index_name):
201                 continue
202             index_sql = 'drop index %s on %s'%(index_name, table_name)
203             if __debug__:
204                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
205             self.cursor.execute(index_sql)
207 class MysqlClass:
208     # we're overriding this method for ONE missing bit of functionality.
209     # look for "I can't believe it's not a toy RDBMS" below
210     def filter(self, search_matches, filterspec, sort=(None,None),
211             group=(None,None)):
212         ''' Return a list of the ids of the active nodes in this class that
213             match the 'filter' spec, sorted by the group spec and then the
214             sort spec
216             "filterspec" is {propname: value(s)}
217             "sort" and "group" are (dir, prop) where dir is '+', '-' or None
218                                and prop is a prop name or None
219             "search_matches" is {nodeid: marker}
221             The filter must match all properties specificed - but if the
222             property value to match is a list, any one of the values in the
223             list may match for that property to match.
224         '''
225         # just don't bother if the full-text search matched diddly
226         if search_matches == {}:
227             return []
229         cn = self.classname
231         timezone = self.db.getUserTimezone()
232         
233         # figure the WHERE clause from the filterspec
234         props = self.getprops()
235         frum = ['_'+cn]
236         where = []
237         args = []
238         a = self.db.arg
239         for k, v in filterspec.items():
240             propclass = props[k]
241             # now do other where clause stuff
242             if isinstance(propclass, Multilink):
243                 tn = '%s_%s'%(cn, k)
244                 if v in ('-1', ['-1']):
245                     # only match rows that have count(linkid)=0 in the
246                     # corresponding multilink table)
248                     # "I can't believe it's not a toy RDBMS"
249                     # see, even toy RDBMSes like gadfly and sqlite can do
250                     # sub-selects...
251                     self.db.sql('select nodeid from %s'%tn)
252                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
254                     where.append('id not in (%s)'%s)
255                 elif isinstance(v, type([])):
256                     frum.append(tn)
257                     s = ','.join([a for x in v])
258                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
259                     args = args + v
260                 else:
261                     frum.append(tn)
262                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
263                     args.append(v)
264             elif k == 'id':
265                 if isinstance(v, type([])):
266                     s = ','.join([a for x in v])
267                     where.append('%s in (%s)'%(k, s))
268                     args = args + v
269                 else:
270                     where.append('%s=%s'%(k, a))
271                     args.append(v)
272             elif isinstance(propclass, String):
273                 if not isinstance(v, type([])):
274                     v = [v]
276                 # Quote the bits in the string that need it and then embed
277                 # in a "substring" search. Note - need to quote the '%' so
278                 # they make it through the python layer happily
279                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
281                 # now add to the where clause
282                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
283                 # note: args are embedded in the query string now
284             elif isinstance(propclass, Link):
285                 if isinstance(v, type([])):
286                     if '-1' in v:
287                         v = v[:]
288                         v.remove('-1')
289                         xtra = ' or _%s is NULL'%k
290                     else:
291                         xtra = ''
292                     if v:
293                         s = ','.join([a for x in v])
294                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
295                         args = args + v
296                     else:
297                         where.append('_%s is NULL'%k)
298                 else:
299                     if v == '-1':
300                         v = None
301                         where.append('_%s is NULL'%k)
302                     else:
303                         where.append('_%s=%s'%(k, a))
304                         args.append(v)
305             elif isinstance(propclass, Date):
306                 if isinstance(v, type([])):
307                     s = ','.join([a for x in v])
308                     where.append('_%s in (%s)'%(k, s))
309                     args = args + [date.Date(x).serialise() for x in v]
310                 else:
311                     try:
312                         # Try to filter on range of dates
313                         date_rng = Range(v, date.Date, offset=timezone)
314                         if (date_rng.from_value):
315                             where.append('_%s >= %s'%(k, a))                            
316                             args.append(date_rng.from_value.serialise())
317                         if (date_rng.to_value):
318                             where.append('_%s <= %s'%(k, a))
319                             args.append(date_rng.to_value.serialise())
320                     except ValueError:
321                         # If range creation fails - ignore that search parameter
322                         pass                        
323             elif isinstance(propclass, Interval):
324                 if isinstance(v, type([])):
325                     s = ','.join([a for x in v])
326                     where.append('_%s in (%s)'%(k, s))
327                     args = args + [date.Interval(x).serialise() for x in v]
328                 else:
329                     try:
330                         # Try to filter on range of intervals
331                         date_rng = Range(v, date.Interval)
332                         if (date_rng.from_value):
333                             where.append('_%s >= %s'%(k, a))
334                             args.append(date_rng.from_value.serialise())
335                         if (date_rng.to_value):
336                             where.append('_%s <= %s'%(k, a))
337                             args.append(date_rng.to_value.serialise())
338                     except ValueError:
339                         # If range creation fails - ignore that search parameter
340                         pass                        
341                     #where.append('_%s=%s'%(k, a))
342                     #args.append(date.Interval(v).serialise())
343             else:
344                 if isinstance(v, type([])):
345                     s = ','.join([a for x in v])
346                     where.append('_%s in (%s)'%(k, s))
347                     args = args + v
348                 else:
349                     where.append('_%s=%s'%(k, a))
350                     args.append(v)
352         # don't match retired nodes
353         where.append('__retired__ <> 1')
355         # add results of full text search
356         if search_matches is not None:
357             v = search_matches.keys()
358             s = ','.join([a for x in v])
359             where.append('id in (%s)'%s)
360             args = args + v
362         # "grouping" is just the first-order sorting in the SQL fetch
363         # can modify it...)
364         orderby = []
365         ordercols = []
366         if group[0] is not None and group[1] is not None:
367             if group[0] != '-':
368                 orderby.append('_'+group[1])
369                 ordercols.append('_'+group[1])
370             else:
371                 orderby.append('_'+group[1]+' desc')
372                 ordercols.append('_'+group[1])
374         # now add in the sorting
375         group = ''
376         if sort[0] is not None and sort[1] is not None:
377             direction, colname = sort
378             if direction != '-':
379                 if colname == 'id':
380                     orderby.append(colname)
381                 else:
382                     orderby.append('_'+colname)
383                     ordercols.append('_'+colname)
384             else:
385                 if colname == 'id':
386                     orderby.append(colname+' desc')
387                     ordercols.append(colname)
388                 else:
389                     orderby.append('_'+colname+' desc')
390                     ordercols.append('_'+colname)
392         # construct the SQL
393         frum = ','.join(frum)
394         if where:
395             where = ' where ' + (' and '.join(where))
396         else:
397             where = ''
398         cols = ['id']
399         if orderby:
400             cols = cols + ordercols
401             order = ' order by %s'%(','.join(orderby))
402         else:
403             order = ''
404         cols = ','.join(cols)
405         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
406         args = tuple(args)
407         if __debug__:
408             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
409         self.db.cursor.execute(sql, args)
410         l = self.db.cursor.fetchall()
412         # return the IDs (the first column)
413         return [row[0] for row in l]
415 class Class(MysqlClass, rdbms_common.Class):
416     pass
417 class IssueClass(MysqlClass, rdbms_common.IssueClass):
418     pass
419 class FileClass(MysqlClass, rdbms_common.FileClass):
420     pass
422 #vim: set et