Code

130327bc45921dc92a771bbee17fe5bc0dde4783
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # Mysql backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import MySQLdb
14 import os, shutil
15 from MySQLdb.constants import ER
17 class Maintenance:
18     """ Database maintenance functions """
19     def db_nuke(self, config):
20         """Clear all database contents and drop database itself"""
21         db = Database(config, 'admin')
22         db.sql_commit()
23         db.sql("DROP DATABASE %s" % config.MYSQL_DBNAME)
24         db.sql("CREATE DATABASE %s" % config.MYSQL_DBNAME)
25         if os.path.exists(config.DATABASE):
26             shutil.rmtree(config.DATABASE)
28     def db_exists(self, config):
29         """Check if database already exists"""
30         # Yes, this is a hack, but we must must open connection without
31         # selecting a database to prevent creation of some tables
32         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
33             config.MYSQL_DBPASSWORD)        
34         db = Database(config, 'admin')
35         db.conn.select_db(config.MYSQL_DBNAME)
36         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
37             config.MYSQL_DBPASSWORD, config.MYSQL_DBNAME)
38         db.sql("SHOW TABLES")
39         tables = db.sql_fetchall()
40         if tables or os.path.exists(config.DATABASE):
41             return 1
42         return 0        
44 class Database(Database):
45     arg = '%s'
47     # backend for MySQL to use
48     mysql_backend = 'InnoDB'
49     #mysql_backend = 'BDB'    # much slower, only use if you have no choice
50     
51     def open_connection(self):
52         db = getattr(self.config, 'MYSQL_DATABASE')
53         try:
54             self.conn = MySQLdb.connect(*db)
55         except MySQLdb.OperationalError, message:
56             raise DatabaseError, message
58         self.cursor = self.conn.cursor()
59         # start transaction
60         self.sql("SET AUTOCOMMIT=0")
61         self.sql("BEGIN")
62         try:
63             self.database_schema = self.load_dbschema()
64         except MySQLdb.OperationalError, message:
65             if message[0] != ER.NO_DB_ERROR:
66                 raise
67         except MySQLdb.ProgrammingError, message:
68             if message[0] != ER.NO_SUCH_TABLE:
69                 raise DatabaseError, message
70             self.database_schema = {}
71             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
72                 self.mysql_backend)
73             # TODO: use AUTO_INCREMENT for generating ids:
74             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
75             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
76                 self.mysql_backend)
77             self.sql("CREATE INDEX ids_name_idx on ids(name)")
79     def close(self):
80         self.conn.close()
82     def __repr__(self):
83         return '<myroundsql 0x%x>'%id(self)
85     def sql_fetchone(self):
86         return self.cursor.fetchone()
88     def sql_fetchall(self):
89         return self.cursor.fetchall()
90     
91     def save_dbschema(self, schema):
92         s = repr(self.database_schema)
93         self.sql('INSERT INTO schema VALUES (%s)', (s,))
94     
95     def load_dbschema(self):
96         self.cursor.execute('SELECT schema FROM schema')
97         schema = self.cursor.fetchone()
98         if schema:
99             return eval(schema[0])
100         return None
102     def save_journal(self, classname, cols, nodeid, journaldate,
103                 journaltag, action, params):
104         params = repr(params)
105         entry = (nodeid, journaldate, journaltag, action, params)
107         a = self.arg
108         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
109                 cols, a, a, a, a, a)
110         if __debug__:
111           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
112         self.cursor.execute(sql, entry)
114     def load_journal(self, classname, cols, nodeid):
115         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
116                 self.arg)
117         if __debug__:
118             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
119         self.cursor.execute(sql, (nodeid,))
120         res = []
121         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
122           params = eval(params)
123           res.append((nodeid, date.Date(date_stamp), user, action, params))
124         return res
126     def create_class_table(self, spec):
127         cols, mls = self.determine_columns(spec.properties.items())
128         cols.append('id')
129         cols.append('__retired__')
130         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
131         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
132             self.mysql_backend)
133         if __debug__:
134           print >>hyperdb.DEBUG, 'create_class', (self, sql)
135         self.cursor.execute(sql)
136         return cols, mls
138     def create_journal_table(self, spec):
139         cols = ',' . join(['`%s` VARCHAR(255)'%x
140           for x in 'nodeid date tag action params' . split()])
141         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
142             cols, self.mysql_backend)
143         if __debug__:
144             print >>hyperdb.DEBUG, 'create_class', (self, sql)
145         self.cursor.execute(sql)
147     def create_multilink_table(self, spec, ml):
148         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
149             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
150                 self.mysql_backend)
151         if __debug__:
152           print >>hyperdb.DEBUG, 'create_class', (self, sql)
153         self.cursor.execute(sql)
155     # Static methods
156     nuke = Maintenance().db_nuke
157     exists = Maintenance().db_exists
159 class MysqlClass:
160     # we're overriding this method for ONE missing bit of functionality.
161     # look for "I can't believe it's not a toy RDBMS" below
162     def filter(self, search_matches, filterspec, sort=(None,None),
163             group=(None,None)):
164         ''' Return a list of the ids of the active nodes in this class that
165             match the 'filter' spec, sorted by the group spec and then the
166             sort spec
168             "filterspec" is {propname: value(s)}
169             "sort" and "group" are (dir, prop) where dir is '+', '-' or None
170                                and prop is a prop name or None
171             "search_matches" is {nodeid: marker}
173             The filter must match all properties specificed - but if the
174             property value to match is a list, any one of the values in the
175             list may match for that property to match.
176         '''
177         # just don't bother if the full-text search matched diddly
178         if search_matches == {}:
179             return []
181         cn = self.classname
183         timezone = self.db.getUserTimezone()
184         
185         # figure the WHERE clause from the filterspec
186         props = self.getprops()
187         frum = ['_'+cn]
188         where = []
189         args = []
190         a = self.db.arg
191         for k, v in filterspec.items():
192             propclass = props[k]
193             # now do other where clause stuff
194             if isinstance(propclass, Multilink):
195                 tn = '%s_%s'%(cn, k)
196                 if v in ('-1', ['-1']):
197                     # only match rows that have count(linkid)=0 in the
198                     # corresponding multilink table)
200                     # "I can't believe it's not a toy RDBMS"
201                     # see, even toy RDBMSes like gadfly and sqlite can do
202                     # sub-selects...
203                     self.db.sql('select nodeid from %s'%tn)
204                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
206                     where.append('id not in (%s)'%s)
207                 elif isinstance(v, type([])):
208                     frum.append(tn)
209                     s = ','.join([a for x in v])
210                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
211                     args = args + v
212                 else:
213                     frum.append(tn)
214                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
215                     args.append(v)
216             elif k == 'id':
217                 if isinstance(v, type([])):
218                     s = ','.join([a for x in v])
219                     where.append('%s in (%s)'%(k, s))
220                     args = args + v
221                 else:
222                     where.append('%s=%s'%(k, a))
223                     args.append(v)
224             elif isinstance(propclass, String):
225                 if not isinstance(v, type([])):
226                     v = [v]
228                 # Quote the bits in the string that need it and then embed
229                 # in a "substring" search. Note - need to quote the '%' so
230                 # they make it through the python layer happily
231                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
233                 # now add to the where clause
234                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
235                 # note: args are embedded in the query string now
236             elif isinstance(propclass, Link):
237                 if isinstance(v, type([])):
238                     if '-1' in v:
239                         v = v[:]
240                         v.remove('-1')
241                         xtra = ' or _%s is NULL'%k
242                     else:
243                         xtra = ''
244                     if v:
245                         s = ','.join([a for x in v])
246                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
247                         args = args + v
248                     else:
249                         where.append('_%s is NULL'%k)
250                 else:
251                     if v == '-1':
252                         v = None
253                         where.append('_%s is NULL'%k)
254                     else:
255                         where.append('_%s=%s'%(k, a))
256                         args.append(v)
257             elif isinstance(propclass, Date):
258                 if isinstance(v, type([])):
259                     s = ','.join([a for x in v])
260                     where.append('_%s in (%s)'%(k, s))
261                     args = args + [date.Date(x).serialise() for x in v]
262                 else:
263                     try:
264                         # Try to filter on range of dates
265                         date_rng = Range(v, date.Date, offset=timezone)
266                         if (date_rng.from_value):
267                             where.append('_%s >= %s'%(k, a))                            
268                             args.append(date_rng.from_value.serialise())
269                         if (date_rng.to_value):
270                             where.append('_%s <= %s'%(k, a))
271                             args.append(date_rng.to_value.serialise())
272                     except ValueError:
273                         # If range creation fails - ignore that search parameter
274                         pass                        
275             elif isinstance(propclass, Interval):
276                 if isinstance(v, type([])):
277                     s = ','.join([a for x in v])
278                     where.append('_%s in (%s)'%(k, s))
279                     args = args + [date.Interval(x).serialise() for x in v]
280                 else:
281                     try:
282                         # Try to filter on range of intervals
283                         date_rng = Range(v, date.Interval)
284                         if (date_rng.from_value):
285                             where.append('_%s >= %s'%(k, a))
286                             args.append(date_rng.from_value.serialise())
287                         if (date_rng.to_value):
288                             where.append('_%s <= %s'%(k, a))
289                             args.append(date_rng.to_value.serialise())
290                     except ValueError:
291                         # If range creation fails - ignore that search parameter
292                         pass                        
293                     #where.append('_%s=%s'%(k, a))
294                     #args.append(date.Interval(v).serialise())
295             else:
296                 if isinstance(v, type([])):
297                     s = ','.join([a for x in v])
298                     where.append('_%s in (%s)'%(k, s))
299                     args = args + v
300                 else:
301                     where.append('_%s=%s'%(k, a))
302                     args.append(v)
304         # don't match retired nodes
305         where.append('__retired__ <> 1')
307         # add results of full text search
308         if search_matches is not None:
309             v = search_matches.keys()
310             s = ','.join([a for x in v])
311             where.append('id in (%s)'%s)
312             args = args + v
314         # "grouping" is just the first-order sorting in the SQL fetch
315         # can modify it...)
316         orderby = []
317         ordercols = []
318         if group[0] is not None and group[1] is not None:
319             if group[0] != '-':
320                 orderby.append('_'+group[1])
321                 ordercols.append('_'+group[1])
322             else:
323                 orderby.append('_'+group[1]+' desc')
324                 ordercols.append('_'+group[1])
326         # now add in the sorting
327         group = ''
328         if sort[0] is not None and sort[1] is not None:
329             direction, colname = sort
330             if direction != '-':
331                 if colname == 'id':
332                     orderby.append(colname)
333                 else:
334                     orderby.append('_'+colname)
335                     ordercols.append('_'+colname)
336             else:
337                 if colname == 'id':
338                     orderby.append(colname+' desc')
339                     ordercols.append(colname)
340                 else:
341                     orderby.append('_'+colname+' desc')
342                     ordercols.append('_'+colname)
344         # construct the SQL
345         frum = ','.join(frum)
346         if where:
347             where = ' where ' + (' and '.join(where))
348         else:
349             where = ''
350         cols = ['id']
351         if orderby:
352             cols = cols + ordercols
353             order = ' order by %s'%(','.join(orderby))
354         else:
355             order = ''
356         cols = ','.join(cols)
357         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
358         args = tuple(args)
359         if __debug__:
360             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
361         self.db.cursor.execute(sql, args)
362         l = self.db.cursor.fetchall()
364         # return the IDs (the first column)
365         return [row[0] for row in l]
367 class Class(MysqlClass, rdbms_common.Class):
368     pass
369 class IssueClass(MysqlClass, rdbms_common.IssueClass):
370     pass
371 class FileClass(MysqlClass, rdbms_common.FileClass):
372     pass
374 #vim: set et