Code

- using Zope3's test runner now, allowing GC checks, nicer controls and
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # Mysql backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import MySQLdb
14 import os, shutil
15 from MySQLdb.constants import ER
17 # Database maintenance functions
18 def db_nuke(config):
19     """Clear all database contents and drop database itself"""
20     db = Database(config, 'admin')
21     try:
22         db.sql_commit()
23         db.sql("DROP DATABASE %s" % config.MYSQL_DBNAME)
24         db.sql("CREATE DATABASE %s" % config.MYSQL_DBNAME)
25     finally:
26         db.close()
27     if os.path.exists(config.DATABASE):
28         shutil.rmtree(config.DATABASE)
30 def db_exists(config):
31     """Check if database already exists"""
32     # Yes, this is a hack, but we must must open connection without
33     # selecting a database to prevent creation of some tables
34     config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
35         config.MYSQL_DBPASSWORD)        
36     db = Database(config, 'admin')
37     try:
38         db.conn.select_db(config.MYSQL_DBNAME)
39         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
40             config.MYSQL_DBPASSWORD, config.MYSQL_DBNAME)
41         db.sql("SHOW TABLES")
42         tables = db.sql_fetchall()
43     finally:
44         db.close()
45     if tables or os.path.exists(config.DATABASE):
46         return 1
47     return 0        
49 class Database(Database):
50     arg = '%s'
52     # backend for MySQL to use
53     mysql_backend = 'InnoDB'
54     #mysql_backend = 'BDB'    # much slower, only use if you have no choice
55     
56     def open_connection(self):
57         db = getattr(self.config, 'MYSQL_DATABASE')
58         try:
59             self.conn = MySQLdb.connect(*db)
60         except MySQLdb.OperationalError, message:
61             raise DatabaseError, message
63         self.cursor = self.conn.cursor()
64         # start transaction
65         self.sql("SET AUTOCOMMIT=0")
66         self.sql("BEGIN")
67         try:
68             self.database_schema = self.load_dbschema()
69         except MySQLdb.OperationalError, message:
70             if message[0] != ER.NO_DB_ERROR:
71                 raise
72         except MySQLdb.ProgrammingError, message:
73             if message[0] != ER.NO_SUCH_TABLE:
74                 raise DatabaseError, message
75             self.database_schema = {}
76             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
77                 self.mysql_backend)
78             # TODO: use AUTO_INCREMENT for generating ids:
79             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
80             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
81                 self.mysql_backend)
82             self.sql("CREATE INDEX ids_name_idx on ids(name)")
84     def close(self):
85         self.conn.close()
87     def __repr__(self):
88         return '<myroundsql 0x%x>'%id(self)
90     def sql_fetchone(self):
91         return self.cursor.fetchone()
93     def sql_fetchall(self):
94         return self.cursor.fetchall()
95     
96     def save_dbschema(self, schema):
97         s = repr(self.database_schema)
98         self.sql('INSERT INTO schema VALUES (%s)', (s,))
99     
100     def load_dbschema(self):
101         self.cursor.execute('SELECT schema FROM schema')
102         schema = self.cursor.fetchone()
103         if schema:
104             return eval(schema[0])
105         return None
107     def save_journal(self, classname, cols, nodeid, journaldate,
108                 journaltag, action, params):
109         params = repr(params)
110         entry = (nodeid, journaldate, journaltag, action, params)
112         a = self.arg
113         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
114                 cols, a, a, a, a, a)
115         if __debug__:
116           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
117         self.cursor.execute(sql, entry)
119     def load_journal(self, classname, cols, nodeid):
120         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
121                 self.arg)
122         if __debug__:
123             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
124         self.cursor.execute(sql, (nodeid,))
125         res = []
126         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
127           params = eval(params)
128           res.append((nodeid, date.Date(date_stamp), user, action, params))
129         return res
131     def create_class_table(self, spec):
132         cols, mls = self.determine_columns(spec.properties.items())
133         cols.append('id')
134         cols.append('__retired__')
135         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
136         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
137             self.mysql_backend)
138         if __debug__:
139           print >>hyperdb.DEBUG, 'create_class', (self, sql)
140         self.cursor.execute(sql)
141         return cols, mls
143     def create_journal_table(self, spec):
144         cols = ',' . join(['`%s` VARCHAR(255)'%x
145           for x in 'nodeid date tag action params' . split()])
146         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
147             cols, self.mysql_backend)
148         if __debug__:
149             print >>hyperdb.DEBUG, 'create_class', (self, sql)
150         self.cursor.execute(sql)
152     def create_multilink_table(self, spec, ml):
153         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
154             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
155                 self.mysql_backend)
156         if __debug__:
157           print >>hyperdb.DEBUG, 'create_class', (self, sql)
158         self.cursor.execute(sql)
160 class MysqlClass:
161     # we're overriding this method for ONE missing bit of functionality.
162     # look for "I can't believe it's not a toy RDBMS" below
163     def filter(self, search_matches, filterspec, sort=(None,None),
164             group=(None,None)):
165         ''' Return a list of the ids of the active nodes in this class that
166             match the 'filter' spec, sorted by the group spec and then the
167             sort spec
169             "filterspec" is {propname: value(s)}
170             "sort" and "group" are (dir, prop) where dir is '+', '-' or None
171                                and prop is a prop name or None
172             "search_matches" is {nodeid: marker}
174             The filter must match all properties specificed - but if the
175             property value to match is a list, any one of the values in the
176             list may match for that property to match.
177         '''
178         # just don't bother if the full-text search matched diddly
179         if search_matches == {}:
180             return []
182         cn = self.classname
184         timezone = self.db.getUserTimezone()
185         
186         # figure the WHERE clause from the filterspec
187         props = self.getprops()
188         frum = ['_'+cn]
189         where = []
190         args = []
191         a = self.db.arg
192         for k, v in filterspec.items():
193             propclass = props[k]
194             # now do other where clause stuff
195             if isinstance(propclass, Multilink):
196                 tn = '%s_%s'%(cn, k)
197                 if v in ('-1', ['-1']):
198                     # only match rows that have count(linkid)=0 in the
199                     # corresponding multilink table)
201                     # "I can't believe it's not a toy RDBMS"
202                     # see, even toy RDBMSes like gadfly and sqlite can do
203                     # sub-selects...
204                     self.db.sql('select nodeid from %s'%tn)
205                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
207                     where.append('id not in (%s)'%s)
208                 elif isinstance(v, type([])):
209                     frum.append(tn)
210                     s = ','.join([a for x in v])
211                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
212                     args = args + v
213                 else:
214                     frum.append(tn)
215                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
216                     args.append(v)
217             elif k == 'id':
218                 if isinstance(v, type([])):
219                     s = ','.join([a for x in v])
220                     where.append('%s in (%s)'%(k, s))
221                     args = args + v
222                 else:
223                     where.append('%s=%s'%(k, a))
224                     args.append(v)
225             elif isinstance(propclass, String):
226                 if not isinstance(v, type([])):
227                     v = [v]
229                 # Quote the bits in the string that need it and then embed
230                 # in a "substring" search. Note - need to quote the '%' so
231                 # they make it through the python layer happily
232                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
234                 # now add to the where clause
235                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
236                 # note: args are embedded in the query string now
237             elif isinstance(propclass, Link):
238                 if isinstance(v, type([])):
239                     if '-1' in v:
240                         v = v[:]
241                         v.remove('-1')
242                         xtra = ' or _%s is NULL'%k
243                     else:
244                         xtra = ''
245                     if v:
246                         s = ','.join([a for x in v])
247                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
248                         args = args + v
249                     else:
250                         where.append('_%s is NULL'%k)
251                 else:
252                     if v == '-1':
253                         v = None
254                         where.append('_%s is NULL'%k)
255                     else:
256                         where.append('_%s=%s'%(k, a))
257                         args.append(v)
258             elif isinstance(propclass, Date):
259                 if isinstance(v, type([])):
260                     s = ','.join([a for x in v])
261                     where.append('_%s in (%s)'%(k, s))
262                     args = args + [date.Date(x).serialise() for x in v]
263                 else:
264                     try:
265                         # Try to filter on range of dates
266                         date_rng = Range(v, date.Date, offset=timezone)
267                         if (date_rng.from_value):
268                             where.append('_%s >= %s'%(k, a))                            
269                             args.append(date_rng.from_value.serialise())
270                         if (date_rng.to_value):
271                             where.append('_%s <= %s'%(k, a))
272                             args.append(date_rng.to_value.serialise())
273                     except ValueError:
274                         # If range creation fails - ignore that search parameter
275                         pass                        
276             elif isinstance(propclass, Interval):
277                 if isinstance(v, type([])):
278                     s = ','.join([a for x in v])
279                     where.append('_%s in (%s)'%(k, s))
280                     args = args + [date.Interval(x).serialise() for x in v]
281                 else:
282                     try:
283                         # Try to filter on range of intervals
284                         date_rng = Range(v, date.Interval)
285                         if (date_rng.from_value):
286                             where.append('_%s >= %s'%(k, a))
287                             args.append(date_rng.from_value.serialise())
288                         if (date_rng.to_value):
289                             where.append('_%s <= %s'%(k, a))
290                             args.append(date_rng.to_value.serialise())
291                     except ValueError:
292                         # If range creation fails - ignore that search parameter
293                         pass                        
294                     #where.append('_%s=%s'%(k, a))
295                     #args.append(date.Interval(v).serialise())
296             else:
297                 if isinstance(v, type([])):
298                     s = ','.join([a for x in v])
299                     where.append('_%s in (%s)'%(k, s))
300                     args = args + v
301                 else:
302                     where.append('_%s=%s'%(k, a))
303                     args.append(v)
305         # don't match retired nodes
306         where.append('__retired__ <> 1')
308         # add results of full text search
309         if search_matches is not None:
310             v = search_matches.keys()
311             s = ','.join([a for x in v])
312             where.append('id in (%s)'%s)
313             args = args + v
315         # "grouping" is just the first-order sorting in the SQL fetch
316         # can modify it...)
317         orderby = []
318         ordercols = []
319         if group[0] is not None and group[1] is not None:
320             if group[0] != '-':
321                 orderby.append('_'+group[1])
322                 ordercols.append('_'+group[1])
323             else:
324                 orderby.append('_'+group[1]+' desc')
325                 ordercols.append('_'+group[1])
327         # now add in the sorting
328         group = ''
329         if sort[0] is not None and sort[1] is not None:
330             direction, colname = sort
331             if direction != '-':
332                 if colname == 'id':
333                     orderby.append(colname)
334                 else:
335                     orderby.append('_'+colname)
336                     ordercols.append('_'+colname)
337             else:
338                 if colname == 'id':
339                     orderby.append(colname+' desc')
340                     ordercols.append(colname)
341                 else:
342                     orderby.append('_'+colname+' desc')
343                     ordercols.append('_'+colname)
345         # construct the SQL
346         frum = ','.join(frum)
347         if where:
348             where = ' where ' + (' and '.join(where))
349         else:
350             where = ''
351         cols = ['id']
352         if orderby:
353             cols = cols + ordercols
354             order = ' order by %s'%(','.join(orderby))
355         else:
356             order = ''
357         cols = ','.join(cols)
358         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
359         args = tuple(args)
360         if __debug__:
361             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
362         self.db.cursor.execute(sql, args)
363         l = self.db.cursor.fetchall()
365         # return the IDs (the first column)
366         return [row[0] for row in l]
368 class Class(MysqlClass, rdbms_common.Class):
369     pass
370 class IssueClass(MysqlClass, rdbms_common.IssueClass):
371     pass
372 class FileClass(MysqlClass, rdbms_common.FileClass):
373     pass
375 #vim: set et