Code

fixed ZRoundup - mostly changes to classic template
[roundup.git] / roundup / backends / back_mysql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # Mysql backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import MySQLdb
14 import os, shutil
15 from MySQLdb.constants import ER
17 # Database maintenance functions
18 def db_nuke(config):
19     """Clear all database contents and drop database itself"""
20     db = Database(config, 'admin')
21     try:
22         db.sql_commit()
23         db.sql("DROP DATABASE %s" % config.MYSQL_DBNAME)
24         db.sql("CREATE DATABASE %s" % config.MYSQL_DBNAME)
25     finally:
26         db.close()
27     if os.path.exists(config.DATABASE):
28         shutil.rmtree(config.DATABASE)
30 def db_exists(config):
31     """Check if database already exists"""
32     # Yes, this is a hack, but we must must open connection without
33     # selecting a database to prevent creation of some tables
34     config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
35         config.MYSQL_DBPASSWORD)        
36     db = Database(config, 'admin')
37     try:
38         db.conn.select_db(config.MYSQL_DBNAME)
39         config.MYSQL_DATABASE = (config.MYSQL_DBHOST, config.MYSQL_DBUSER,
40             config.MYSQL_DBPASSWORD, config.MYSQL_DBNAME)
41         db.sql("SHOW TABLES")
42         tables = db.sql_fetchall()
43     finally:
44         db.close()
45     if tables or os.path.exists(config.DATABASE):
46         return 1
47     return 0        
49 class Database(Database):
50     arg = '%s'
52     # backend for MySQL to use
53     mysql_backend = 'InnoDB'
54     #mysql_backend = 'BDB'    # much slower, only use if you have no choice
55     
56     def sql_open_connection(self):
57         db = getattr(self.config, 'MYSQL_DATABASE')
58         try:
59             self.conn = MySQLdb.connect(*db)
60         except MySQLdb.OperationalError, message:
61             raise DatabaseError, message
63         self.cursor = self.conn.cursor()
64         # start transaction
65         self.sql("SET AUTOCOMMIT=0")
66         self.sql("BEGIN")
67         try:
68             self.database_schema = self.load_dbschema()
69         except MySQLdb.OperationalError, message:
70             if message[0] != ER.NO_DB_ERROR:
71                 raise
72         except MySQLdb.ProgrammingError, message:
73             if message[0] != ER.NO_SUCH_TABLE:
74                 raise DatabaseError, message
75             self.database_schema = {}
76             self.sql("CREATE TABLE schema (schema TEXT) TYPE=%s"%
77                 self.mysql_backend)
78             # TODO: use AUTO_INCREMENT for generating ids:
79             #       http://www.mysql.com/doc/en/CREATE_TABLE.html
80             self.sql("CREATE TABLE ids (name varchar(255), num INT) TYPE=%s"%
81                 self.mysql_backend)
82             self.sql("CREATE INDEX ids_name_idx on ids(name)")
84     def __repr__(self):
85         return '<myroundsql 0x%x>'%id(self)
87     def sql_fetchone(self):
88         return self.cursor.fetchone()
90     def sql_fetchall(self):
91         return self.cursor.fetchall()
93     def sql_index_exists(self, table_name, index_name):
94         self.cursor.execute('show index from %s'%table_name)
95         for index in self.cursor.fetchall():
96             if index[2] == index_name:
97                 return 1
98         return 0
100     def save_dbschema(self, schema):
101         s = repr(self.database_schema)
102         self.sql('INSERT INTO schema VALUES (%s)', (s,))
103     
104     def load_dbschema(self):
105         self.cursor.execute('SELECT schema FROM schema')
106         schema = self.cursor.fetchone()
107         if schema:
108             return eval(schema[0])
109         return None
111     def save_journal(self, classname, cols, nodeid, journaldate,
112                 journaltag, action, params):
113         params = repr(params)
114         entry = (nodeid, journaldate, journaltag, action, params)
116         a = self.arg
117         sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
118                 cols, a, a, a, a, a)
119         if __debug__:
120           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
121         self.cursor.execute(sql, entry)
123     def load_journal(self, classname, cols, nodeid):
124         sql = 'select %s from %s__journal where nodeid=%s'%(cols, classname,
125                 self.arg)
126         if __debug__:
127             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
128         self.cursor.execute(sql, (nodeid,))
129         res = []
130         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
131           params = eval(params)
132           res.append((nodeid, date.Date(date_stamp), user, action, params))
133         return res
135     def create_class_table(self, spec):
136         cols, mls = self.determine_columns(spec.properties.items())
137         cols.append('id')
138         cols.append('__retired__')
139         scols = ',' . join(['`%s` VARCHAR(255)'%x for x in cols])
140         sql = 'CREATE TABLE `_%s` (%s) TYPE=%s'%(spec.classname, scols,
141             self.mysql_backend)
142         if __debug__:
143           print >>hyperdb.DEBUG, 'create_class', (self, sql)
144         self.cursor.execute(sql)
145         self.create_class_table_indexes(spec)
146         return cols, mls
148     def drop_class_table_indexes(self, cn, key):
149         # drop the old table indexes first
150         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
151         if key:
152             l.append('_%s_%s_idx'%(cn, key))
154         table_name = '_%s'%cn
155         for index_name in l:
156             if not self.sql_index_exists(table_name, index_name):
157                 continue
158             index_sql = 'drop index %s on %s'%(index_name, table_name)
159             if __debug__:
160                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
161             self.cursor.execute(index_sql)
163     def create_journal_table(self, spec):
164         cols = ',' . join(['`%s` VARCHAR(255)'%x
165           for x in 'nodeid date tag action params' . split()])
166         sql  = 'CREATE TABLE `%s__journal` (%s) TYPE=%s'%(spec.classname,
167             cols, self.mysql_backend)
168         if __debug__:
169             print >>hyperdb.DEBUG, 'create_class', (self, sql)
170         self.cursor.execute(sql)
171         self.create_journal_table_indexes(spec)
173     def drop_journal_table_indexes(self, classname):
174         index_name = '%s_journ_idx'%classname
175         if not self.sql_index_exists('%s__journal'%classname, index_name):
176             return
177         index_sql = 'drop index %s on %s__journal'%(index_name, classname)
178         if __debug__:
179             print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
180         self.cursor.execute(index_sql)
182     def create_multilink_table(self, spec, ml):
183         sql = '''CREATE TABLE `%s_%s` (linkid VARCHAR(255),
184             nodeid VARCHAR(255)) TYPE=%s'''%(spec.classname, ml,
185                 self.mysql_backend)
186         if __debug__:
187           print >>hyperdb.DEBUG, 'create_class', (self, sql)
188         self.cursor.execute(sql)
189         self.create_multilink_table_indexes(spec, ml)
191     def drop_multilink_table_indexes(self, classname, ml):
192         l = [
193             '%s_%s_l_idx'%(classname, ml),
194             '%s_%s_n_idx'%(classname, ml)
195         ]
196         for index_name in l:
197             if not self.sql_index_exists(table_name, index_name):
198                 continue
199             index_sql = 'drop index %s on %s'%(index_name, table_name)
200             if __debug__:
201                 print >>hyperdb.DEBUG, 'drop_index', (self, index_sql)
202             self.cursor.execute(index_sql)
204 class MysqlClass:
205     # we're overriding this method for ONE missing bit of functionality.
206     # look for "I can't believe it's not a toy RDBMS" below
207     def filter(self, search_matches, filterspec, sort=(None,None),
208             group=(None,None)):
209         ''' Return a list of the ids of the active nodes in this class that
210             match the 'filter' spec, sorted by the group spec and then the
211             sort spec
213             "filterspec" is {propname: value(s)}
214             "sort" and "group" are (dir, prop) where dir is '+', '-' or None
215                                and prop is a prop name or None
216             "search_matches" is {nodeid: marker}
218             The filter must match all properties specificed - but if the
219             property value to match is a list, any one of the values in the
220             list may match for that property to match.
221         '''
222         # just don't bother if the full-text search matched diddly
223         if search_matches == {}:
224             return []
226         cn = self.classname
228         timezone = self.db.getUserTimezone()
229         
230         # figure the WHERE clause from the filterspec
231         props = self.getprops()
232         frum = ['_'+cn]
233         where = []
234         args = []
235         a = self.db.arg
236         for k, v in filterspec.items():
237             propclass = props[k]
238             # now do other where clause stuff
239             if isinstance(propclass, Multilink):
240                 tn = '%s_%s'%(cn, k)
241                 if v in ('-1', ['-1']):
242                     # only match rows that have count(linkid)=0 in the
243                     # corresponding multilink table)
245                     # "I can't believe it's not a toy RDBMS"
246                     # see, even toy RDBMSes like gadfly and sqlite can do
247                     # sub-selects...
248                     self.db.sql('select nodeid from %s'%tn)
249                     s = ','.join([x[0] for x in self.db.sql_fetchall()])
251                     where.append('id not in (%s)'%s)
252                 elif isinstance(v, type([])):
253                     frum.append(tn)
254                     s = ','.join([a for x in v])
255                     where.append('id=%s.nodeid and %s.linkid in (%s)'%(tn,tn,s))
256                     args = args + v
257                 else:
258                     frum.append(tn)
259                     where.append('id=%s.nodeid and %s.linkid=%s'%(tn, tn, a))
260                     args.append(v)
261             elif k == 'id':
262                 if isinstance(v, type([])):
263                     s = ','.join([a for x in v])
264                     where.append('%s in (%s)'%(k, s))
265                     args = args + v
266                 else:
267                     where.append('%s=%s'%(k, a))
268                     args.append(v)
269             elif isinstance(propclass, String):
270                 if not isinstance(v, type([])):
271                     v = [v]
273                 # Quote the bits in the string that need it and then embed
274                 # in a "substring" search. Note - need to quote the '%' so
275                 # they make it through the python layer happily
276                 v = ['%%'+self.db.sql_stringquote(s)+'%%' for s in v]
278                 # now add to the where clause
279                 where.append(' or '.join(["_%s LIKE '%s'"%(k, s) for s in v]))
280                 # note: args are embedded in the query string now
281             elif isinstance(propclass, Link):
282                 if isinstance(v, type([])):
283                     if '-1' in v:
284                         v = v[:]
285                         v.remove('-1')
286                         xtra = ' or _%s is NULL'%k
287                     else:
288                         xtra = ''
289                     if v:
290                         s = ','.join([a for x in v])
291                         where.append('(_%s in (%s)%s)'%(k, s, xtra))
292                         args = args + v
293                     else:
294                         where.append('_%s is NULL'%k)
295                 else:
296                     if v == '-1':
297                         v = None
298                         where.append('_%s is NULL'%k)
299                     else:
300                         where.append('_%s=%s'%(k, a))
301                         args.append(v)
302             elif isinstance(propclass, Date):
303                 if isinstance(v, type([])):
304                     s = ','.join([a for x in v])
305                     where.append('_%s in (%s)'%(k, s))
306                     args = args + [date.Date(x).serialise() for x in v]
307                 else:
308                     try:
309                         # Try to filter on range of dates
310                         date_rng = Range(v, date.Date, offset=timezone)
311                         if (date_rng.from_value):
312                             where.append('_%s >= %s'%(k, a))                            
313                             args.append(date_rng.from_value.serialise())
314                         if (date_rng.to_value):
315                             where.append('_%s <= %s'%(k, a))
316                             args.append(date_rng.to_value.serialise())
317                     except ValueError:
318                         # If range creation fails - ignore that search parameter
319                         pass                        
320             elif isinstance(propclass, Interval):
321                 if isinstance(v, type([])):
322                     s = ','.join([a for x in v])
323                     where.append('_%s in (%s)'%(k, s))
324                     args = args + [date.Interval(x).serialise() for x in v]
325                 else:
326                     try:
327                         # Try to filter on range of intervals
328                         date_rng = Range(v, date.Interval)
329                         if (date_rng.from_value):
330                             where.append('_%s >= %s'%(k, a))
331                             args.append(date_rng.from_value.serialise())
332                         if (date_rng.to_value):
333                             where.append('_%s <= %s'%(k, a))
334                             args.append(date_rng.to_value.serialise())
335                     except ValueError:
336                         # If range creation fails - ignore that search parameter
337                         pass                        
338                     #where.append('_%s=%s'%(k, a))
339                     #args.append(date.Interval(v).serialise())
340             else:
341                 if isinstance(v, type([])):
342                     s = ','.join([a for x in v])
343                     where.append('_%s in (%s)'%(k, s))
344                     args = args + v
345                 else:
346                     where.append('_%s=%s'%(k, a))
347                     args.append(v)
349         # don't match retired nodes
350         where.append('__retired__ <> 1')
352         # add results of full text search
353         if search_matches is not None:
354             v = search_matches.keys()
355             s = ','.join([a for x in v])
356             where.append('id in (%s)'%s)
357             args = args + v
359         # "grouping" is just the first-order sorting in the SQL fetch
360         # can modify it...)
361         orderby = []
362         ordercols = []
363         if group[0] is not None and group[1] is not None:
364             if group[0] != '-':
365                 orderby.append('_'+group[1])
366                 ordercols.append('_'+group[1])
367             else:
368                 orderby.append('_'+group[1]+' desc')
369                 ordercols.append('_'+group[1])
371         # now add in the sorting
372         group = ''
373         if sort[0] is not None and sort[1] is not None:
374             direction, colname = sort
375             if direction != '-':
376                 if colname == 'id':
377                     orderby.append(colname)
378                 else:
379                     orderby.append('_'+colname)
380                     ordercols.append('_'+colname)
381             else:
382                 if colname == 'id':
383                     orderby.append(colname+' desc')
384                     ordercols.append(colname)
385                 else:
386                     orderby.append('_'+colname+' desc')
387                     ordercols.append('_'+colname)
389         # construct the SQL
390         frum = ','.join(frum)
391         if where:
392             where = ' where ' + (' and '.join(where))
393         else:
394             where = ''
395         cols = ['id']
396         if orderby:
397             cols = cols + ordercols
398             order = ' order by %s'%(','.join(orderby))
399         else:
400             order = ''
401         cols = ','.join(cols)
402         sql = 'select %s from %s %s%s%s'%(cols, frum, where, group, order)
403         args = tuple(args)
404         if __debug__:
405             print >>hyperdb.DEBUG, 'filter', (self, sql, args)
406         self.db.cursor.execute(sql, args)
407         l = self.db.cursor.fetchall()
409         # return the IDs (the first column)
410         return [row[0] for row in l]
412 class Class(MysqlClass, rdbms_common.Class):
413     pass
414 class IssueClass(MysqlClass, rdbms_common.IssueClass):
415     pass
416 class FileClass(MysqlClass, rdbms_common.FileClass):
417     pass
419 #vim: set et