Code

use the upload-supplied content-type if there is one
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # Mysql backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import MySQLdb
14 import os, shutil
15 from MySQLdb.constants import ER
17 # Database maintenance functions
18 def db_nuke(config):
19     """Clear all database contents and drop database itself"""
20     db = Database(config, 'admin')
21     try:
22         db.sql_commit()
23         db.sql("DROP DATABASE %s" % config.MYSQL_DBNAME)
24         db.sql("CREATE DATABASE %s" % config.MYSQL_DBNAME)
25     finally:
26         db.close()
27     if os.path.exists(config.DATABASE):
28         shutil.rmtree(config.DATABASE)
30 def db_exists(config):
31     """Check if database already exists"""
32     # Yes, this is a hack, but we must must open connection without
33     # selecting a database to prevent creation of some tables
34     config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
35         config.MYSQL_DBPASSWORD)        
36     db = Database(config, 'admin')
37     try:
38         db.conn.select_db(config.MYSQL_DBNAME)
39         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
40             config.MYSQL_DBPASSWORD, config.MYSQL_DBNAME)
41         db.sql("SHOW TABLES")
42         tables = db.sql_fetchall()
43     finally:
44         db.close()
45     if tables or os.path.exists(config.DATABASE):
46         return 1
47     return 0        
49 class Database(Database):
50     arg = '%s'
52     # Backend for MySQL to use.
53     # InnoDB is faster, but if you're running <4.0.16 then you'll need to
54     # use BDB to pass all unit tests.
55     mysql_backend = 'InnoDB'
56     #mysql_backend = 'BDB'
57     
58     def sql_open_connection(self):
59         db = getattr(self.config, 'MYSQL_DATABASE')
60         try:
61             self.conn = MySQLdb.connect(*db)
62         except MySQLdb.OperationalError, message:
63             raise DatabaseError, message
65         self.cursor = self.conn.cursor()
66         # start transaction
67         self.sql("SET AUTOCOMMIT=0")
68         self.sql("BEGIN")
69         try:
70             self.database_schema = self.load_dbschema()
71         except MySQLdb.OperationalError, message:
72             if message[0] != ER.NO_DB_ERROR:
73                 raise
74         except MySQLdb.ProgrammingError, message:
75             if message[0] != ER.NO_SUCH_TABLE:
76                 raise DatabaseError, message
77             self.database_schema = {}
78             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
79                 self.mysql_backend)
80             # TODO: use AUTO_INCREMENT for generating ids:
81             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
82             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
83                 self.mysql_backend)
84             self.sql("CREATE INDEX ids_name_idx on ids(name)")
86     def __repr__(self):
87         return '<myroundsql 0x%x>'%id(self)
89     def sql_fetchone(self):
90         return self.cursor.fetchone()
92     def sql_fetchall(self):
93         return self.cursor.fetchall()
95     def sql_index_exists(self, table_name, index_name):
96         self.cursor.execute('show index from %s'%table_name)
97         for index in self.cursor.fetchall():
98             if index[2] == index_name:
99                 return 1
100         return 0
102     def save_dbschema(self, schema):
103         s = repr(self.database_schema)
104         self.sql('INSERT INTO schema VALUES (%s)', (s,))
105     
106     def load_dbschema(self):
107         self.cursor.execute('SELECT schema FROM schema')
108         schema = self.cursor.fetchone()
109         if schema:
110             return eval(schema[0])
111         return None
113     def save_journal(self, classname, cols, nodeid, journaldate,
114                 journaltag, action, params):
115         params = repr(params)
116         entry = (nodeid, journaldate, journaltag, action, params)
118         a = self.arg
119         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
120                 cols, a, a, a, a, a)
121         if __debug__:
122           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
123         self.cursor.execute(sql, entry)
125     def load_journal(self, classname, cols, nodeid):
126         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
127                 self.arg)
128         if __debug__:
129             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
130         self.cursor.execute(sql, (nodeid,))
131         res = []
132         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
133           params = eval(params)
134           res.append((nodeid, date.Date(date_stamp), user, action, params))
135         return res
137     def create_class_table(self, spec):
138         cols, mls = self.determine_columns(spec.properties.items())
139         cols.append('id')
140         cols.append('__retired__')
141         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
142         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
143             self.mysql_backend)
144         if __debug__:
145           print >>hyperdb.DEBUG, 'create_class', (self, sql)
146         self.cursor.execute(sql)
147         self.create_class_table_indexes(spec)
148         return cols, mls
150     def drop_class_table_indexes(self, cn, key):
151         # drop the old table indexes first
152         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
153         if key:
154             l.append('_%s_%s_idx'%(cn, key))
156         table_name = '_%s'%cn
157         for index_name in l:
158             if not self.sql_index_exists(table_name, index_name):
159                 continue
160             index_sql = 'drop index %s on %s'%(index_name, table_name)
161             if __debug__:
162                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
163             self.cursor.execute(index_sql)
165     def create_journal_table(self, spec):
166         cols = ',' . join(['`%s` VARCHAR(255)'%x
167           for x in 'nodeid date tag action params' . split()])
168         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
169             cols, self.mysql_backend)
170         if __debug__:
171             print >>hyperdb.DEBUG, 'create_class', (self, sql)
172         self.cursor.execute(sql)
173         self.create_journal_table_indexes(spec)
175     def drop_journal_table_indexes(self, classname):
176         index_name = '%s_journ_idx'%classname
177         if not self.sql_index_exists('%s__journal'%classname, index_name):
178             return
179         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
180         if __debug__:
181             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
182         self.cursor.execute(index_sql)
184     def create_multilink_table(self, spec, ml):
185         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
186             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
187                 self.mysql_backend)
188         if __debug__:
189           print >>hyperdb.DEBUG, 'create_class', (self, sql)
190         self.cursor.execute(sql)
191         self.create_multilink_table_indexes(spec, ml)
193     def drop_multilink_table_indexes(self, classname, ml):
194         l = [
195             '%s_%s_l_idx'%(classname, ml),
196             '%s_%s_n_idx'%(classname, ml)
197         ]
198         for index_name in l:
199             if not self.sql_index_exists(table_name, index_name):
200                 continue
201             index_sql = 'drop index %s on %s'%(index_name, table_name)
202             if __debug__:
203                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
204             self.cursor.execute(index_sql)
206 class MysqlClass:
207     # we're overriding this method for ONE missing bit of functionality.
208     # look for "I can't believe it's not a toy RDBMS" below
209     def filter(self, search_matches, filterspec, sort=(None,None),
210             group=(None,None)):
211         ''' Return a list of the ids of the active nodes in this class that
212             match the 'filter' spec, sorted by the group spec and then the
213             sort spec
215             "filterspec" is {propname: value(s)}
216             "sort" and "group" are (dir, prop) where dir is '+', '-' or None
217                                and prop is a prop name or None
218             "search_matches" is {nodeid: marker}
220             The filter must match all properties specificed - but if the
221             property value to match is a list, any one of the values in the
222             list may match for that property to match.
223         '''
224         # just don't bother if the full-text search matched diddly
225         if search_matches == {}:
226             return []
228         cn = self.classname
230         timezone = self.db.getUserTimezone()
231         
232         # figure the WHERE clause from the filterspec
233         props = self.getprops()
234         frum = ['_'+cn]
235         where = []
236         args = []
237         a = self.db.arg
238         for k, v in filterspec.items():
239             propclass = props[k]
240             # now do other where clause stuff
241             if isinstance(propclass, Multilink):
242                 tn = '%s_%s'%(cn, k)
243                 if v in ('-1', ['-1']):
244                     # only match rows that have count(linkid)=0 in the
245                     # corresponding multilink table)
247                     # "I can't believe it's not a toy RDBMS"
248                     # see, even toy RDBMSes like gadfly and sqlite can do
249                     # sub-selects...
250                     self.db.sql('select nodeid from %s'%tn)
251                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
253                     where.append('id not in (%s)'%s)
254                 elif isinstance(v, type([])):
255                     frum.append(tn)
256                     s = ','.join([a for x in v])
257                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
258                     args = args + v
259                 else:
260                     frum.append(tn)
261                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
262                     args.append(v)
263             elif k == 'id':
264                 if isinstance(v, type([])):
265                     s = ','.join([a for x in v])
266                     where.append('%s in (%s)'%(k, s))
267                     args = args + v
268                 else:
269                     where.append('%s=%s'%(k, a))
270                     args.append(v)
271             elif isinstance(propclass, String):
272                 if not isinstance(v, type([])):
273                     v = [v]
275                 # Quote the bits in the string that need it and then embed
276                 # in a "substring" search. Note - need to quote the '%' so
277                 # they make it through the python layer happily
278                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
280                 # now add to the where clause
281                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
282                 # note: args are embedded in the query string now
283             elif isinstance(propclass, Link):
284                 if isinstance(v, type([])):
285                     if '-1' in v:
286                         v = v[:]
287                         v.remove('-1')
288                         xtra = ' or _%s is NULL'%k
289                     else:
290                         xtra = ''
291                     if v:
292                         s = ','.join([a for x in v])
293                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
294                         args = args + v
295                     else:
296                         where.append('_%s is NULL'%k)
297                 else:
298                     if v == '-1':
299                         v = None
300                         where.append('_%s is NULL'%k)
301                     else:
302                         where.append('_%s=%s'%(k, a))
303                         args.append(v)
304             elif isinstance(propclass, Date):
305                 if isinstance(v, type([])):
306                     s = ','.join([a for x in v])
307                     where.append('_%s in (%s)'%(k, s))
308                     args = args + [date.Date(x).serialise() for x in v]
309                 else:
310                     try:
311                         # Try to filter on range of dates
312                         date_rng = Range(v, date.Date, offset=timezone)
313                         if (date_rng.from_value):
314                             where.append('_%s >= %s'%(k, a))                            
315                             args.append(date_rng.from_value.serialise())
316                         if (date_rng.to_value):
317                             where.append('_%s <= %s'%(k, a))
318                             args.append(date_rng.to_value.serialise())
319                     except ValueError:
320                         # If range creation fails - ignore that search parameter
321                         pass                        
322             elif isinstance(propclass, Interval):
323                 if isinstance(v, type([])):
324                     s = ','.join([a for x in v])
325                     where.append('_%s in (%s)'%(k, s))
326                     args = args + [date.Interval(x).serialise() for x in v]
327                 else:
328                     try:
329                         # Try to filter on range of intervals
330                         date_rng = Range(v, date.Interval)
331                         if (date_rng.from_value):
332                             where.append('_%s >= %s'%(k, a))
333                             args.append(date_rng.from_value.serialise())
334                         if (date_rng.to_value):
335                             where.append('_%s <= %s'%(k, a))
336                             args.append(date_rng.to_value.serialise())
337                     except ValueError:
338                         # If range creation fails - ignore that search parameter
339                         pass                        
340                     #where.append('_%s=%s'%(k, a))
341                     #args.append(date.Interval(v).serialise())
342             else:
343                 if isinstance(v, type([])):
344                     s = ','.join([a for x in v])
345                     where.append('_%s in (%s)'%(k, s))
346                     args = args + v
347                 else:
348                     where.append('_%s=%s'%(k, a))
349                     args.append(v)
351         # don't match retired nodes
352         where.append('__retired__ <> 1')
354         # add results of full text search
355         if search_matches is not None:
356             v = search_matches.keys()
357             s = ','.join([a for x in v])
358             where.append('id in (%s)'%s)
359             args = args + v
361         # "grouping" is just the first-order sorting in the SQL fetch
362         # can modify it...)
363         orderby = []
364         ordercols = []
365         if group[0] is not None and group[1] is not None:
366             if group[0] != '-':
367                 orderby.append('_'+group[1])
368                 ordercols.append('_'+group[1])
369             else:
370                 orderby.append('_'+group[1]+' desc')
371                 ordercols.append('_'+group[1])
373         # now add in the sorting
374         group = ''
375         if sort[0] is not None and sort[1] is not None:
376             direction, colname = sort
377             if direction != '-':
378                 if colname == 'id':
379                     orderby.append(colname)
380                 else:
381                     orderby.append('_'+colname)
382                     ordercols.append('_'+colname)
383             else:
384                 if colname == 'id':
385                     orderby.append(colname+' desc')
386                     ordercols.append(colname)
387                 else:
388                     orderby.append('_'+colname+' desc')
389                     ordercols.append('_'+colname)
391         # construct the SQL
392         frum = ','.join(frum)
393         if where:
394             where = ' where ' + (' and '.join(where))
395         else:
396             where = ''
397         cols = ['id']
398         if orderby:
399             cols = cols + ordercols
400             order = ' order by %s'%(','.join(orderby))
401         else:
402             order = ''
403         cols = ','.join(cols)
404         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
405         args = tuple(args)
406         if __debug__:
407             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
408         self.db.cursor.execute(sql, args)
409         l = self.db.cursor.fetchall()
411         # return the IDs (the first column)
412         return [row[0] for row in l]
414     # mysql doesn't implement INTERSECT
415     def find(self, **propspec):
416         '''Get the ids of nodes in this class which link to the given nodes.
418         'propspec' consists of keyword args propname=nodeid or
419                    propname={nodeid:1, }
420         'propname' must be the name of a property in this class, or a
421                    KeyError is raised.  That property must be a Link or
422                    Multilink property, or a TypeError is raised.
424         Any node in this class whose 'propname' property links to any of the
425         nodeids will be returned. Used by the full text indexing, which knows
426         that "foo" occurs in msg1, msg3 and file7, so we have hits on these
427         issues:
429             db.issue.find(messages={'1':1,'3':1}, files={'7':1})
430         '''
431         if __debug__:
432             print >>hyperdb.DEBUG, 'find', (self, propspec)
434         # shortcut
435         if not propspec:
436             return []
438         # validate the args
439         props = self.getprops()
440         propspec = propspec.items()
441         for propname, nodeids in propspec:
442             # check the prop is OK
443             prop = props[propname]
444             if not isinstance(prop, Link) and not isinstance(prop, Multilink):
445                 raise TypeError, "'%s' not a Link/Multilink property"%propname
447         # first, links
448         a = self.db.arg
449         where = ['__retired__ <> %s'%a]
450         allvalues = (1,)
451         for prop, values in propspec:
452             if not isinstance(props[prop], hyperdb.Link):
453                 continue
454             if type(values) is type({}) and len(values) == 1:
455                 values = values.keys()[0]
456             if type(values) is type(''):
457                 allvalues += (values,)
458                 where.append('_%s = %s'%(prop, a))
459             elif values is None:
460                 where.append('_%s is NULL'%prop)
461             else:
462                 allvalues += tuple(values.keys())
463                 where.append('_%s in (%s)'%(prop, ','.join([a]*len(values))))
464         tables = []
465         if where:
466             tables.append('select id as nodeid from _%s where %s'%(
467                 self.classname, ' and '.join(where)))
469         # now multilinks
470         for prop, values in propspec:
471             if not isinstance(props[prop], hyperdb.Multilink):
472                 continue
473             if type(values) is type(''):
474                 allvalues += (values,)
475                 s = a
476             else:
477                 allvalues += tuple(values.keys())
478                 s = ','.join([a]*len(values))
479             tables.append('select nodeid from %s_%s where linkid in (%s)'%(
480                 self.classname, prop, s))
482         raise NotImplemented, "XXX this code's farked"
483         d = {}
484         self.db.sql(sql, allvalues)
485         for result in self.db.sql_fetchall():
486             d[result[0]] = 1
488         for query in tables[1:]:
489             self.db.sql(sql, allvalues)
490             for result in self.db.sql_fetchall():
491                 if not d.has_key(result[0]):
492                     continue
494         if __debug__:
495             print >>hyperdb.DEBUG, 'find ... ', l
496         l = d.keys()
497         l.sort()
498         return l
500 class Class(MysqlClass, rdbms_common.Class):
501     pass
502 class IssueClass(MysqlClass, rdbms_common.IssueClass):
503     pass
504 class FileClass(MysqlClass, rdbms_common.FileClass):
505     pass
507 #vim: set et