Code

- that's the postgresql backend in (cleaned up doc, unit testing harness and
[roundup.git] / roundup / backends / back_postgresql.py
1 #
2 # Copyright (c) 2003 Martynas Sklyzmantas, Andrey Lebedev <andrey@micro.lt>
3 #
4 # This module is free software, and you may redistribute it and/or modify
5 # under the same terms as Python, so long as this copyright message and
6 # disclaimer are retained in their original form.
7 #
8 # psycopg backend for roundup
9 #
11 from roundup.backends.rdbms_common import *
12 from roundup.backends import rdbms_common
13 import psycopg
14 import os, shutil, popen2
16 class Database(Database):
17     arg = '%s'
19     def open_connection(self):
20         db = getattr(self.config, 'POSTGRESQL_DATABASE')
21         try:
22             self.conn = psycopg.connect(**db)
23         except psycopg.OperationalError, message:
24             raise DatabaseError, message
26         self.cursor = self.conn.cursor()
28         try:
29             self.database_schema = self.load_dbschema()
30         except:
31             self.rollback()
32             self.database_schema = {}
33             self.sql("CREATE TABLE schema (schema TEXT)")
34             self.sql("CREATE TABLE ids (name VARCHAR(255), num INT4)")
36     def close(self):
37         self.conn.close()
39     def __repr__(self):
40         return '<roundpsycopgsql 0x%x>' % id(self)
42     def sql_fetchone(self):
43         return self.cursor.fetchone()
45     def sql_fetchall(self):
46         return self.cursor.fetchall()
48     def sql_stringquote(self, value):
49         ''' psycopg.QuotedString returns a "buffer" object with the
50             single-quotes around it... '''
51         return str(psycopg.QuotedString(str(value)))[1:-1]
53     def sql_index_exists(self, table_name, index_name):
54         sql = 'select count(*) from pg_indexes where ' \
55             'tablename=%s and indexname=%s'%(self.arg, self.arg)
56         self.cursor.execute(sql, (table_name, index_name))
57         return self.cursor.fetchone()[0]
59     def save_dbschema(self, schema):
60         s = repr(self.database_schema)
61         self.sql('INSERT INTO schema VALUES (%s)', (s,))
62     
63     def load_dbschema(self):
64         self.cursor.execute('SELECT schema FROM schema')
65         schema = self.cursor.fetchone()
66         if schema:
67             return eval(schema[0])
69     def save_journal(self, classname, cols, nodeid, journaldate,
70                      journaltag, action, params):
71         params = repr(params)
72         entry = (nodeid, journaldate, journaltag, action, params)
74         a = self.arg
75         sql = 'INSERT INTO %s__journal (%s) values (%s, %s, %s, %s, %s)'%(
76             classname, cols, a, a, a, a, a)
78         if __debug__:
79           print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
81         self.cursor.execute(sql, entry)
83     def load_journal(self, classname, cols, nodeid):
84         sql = 'SELECT %s FROM %s__journal WHERE nodeid = %s' % (
85             cols, classname, self.arg)
86         
87         if __debug__:
88             print >>hyperdb.DEBUG, 'getjournal', (self, sql, nodeid)
90         self.cursor.execute(sql, (nodeid,))
91         res = []
92         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
93             params = eval(params)
94             res.append((nodeid, date.Date(date_stamp), user, action, params))
95         return res
97     def create_class_table(self, spec):
98         cols, mls = self.determine_columns(spec.properties.items())
99         cols.append('id')
100         cols.append('__retired__')
101         scols = ',' . join(['"%s" VARCHAR(255)' % x for x in cols])
102         sql = 'CREATE TABLE "_%s" (%s)' % (spec.classname, scols)
104         if __debug__:
105             print >>hyperdb.DEBUG, 'create_class', (self, sql)
107         self.cursor.execute(sql)
108         return cols, mls
110     def create_journal_table(self, spec):
111         cols = ',' . join(['"%s" VARCHAR(255)' % x
112                            for x in 'nodeid date tag action params' . split()])
113         sql  = 'CREATE TABLE "%s__journal" (%s)'%(spec.classname, cols)
114         
115         if __debug__:
116             print >>hyperdb.DEBUG, 'create_class', (self, sql)
118         self.cursor.execute(sql)
120     def create_multilink_table(self, spec, ml):
121         sql = '''CREATE TABLE "%s_%s" (linkid VARCHAR(255),
122                    nodeid VARCHAR(255))''' % (spec.classname, ml)
124         if __debug__:
125             print >>hyperdb.DEBUG, 'create_class', (self, sql)
127         self.cursor.execute(sql)
129 class PsycopgClass:
130     def find(self, **propspec):
131         """Get the ids of nodes in this class which link to the given nodes."""
132         
133         if __debug__:
134             print >>hyperdb.DEBUG, 'find', (self, propspec)
136         # shortcut
137         if not propspec:
138             return []
140         # validate the args
141         props = self.getprops()
142         propspec = propspec.items()
143         for propname, nodeids in propspec:
144             # check the prop is OK
145             prop = props[propname]
146             if not isinstance(prop, Link) and not isinstance(prop, Multilink):
147                 raise TypeError, "'%s' not a Link/Multilink property"%propname
149         # first, links
150         l = []
151         where = []
152         allvalues = ()
153         a = self.db.arg
154         for prop, values in propspec:
155             if not isinstance(props[prop], hyperdb.Link):
156                 continue
157             if type(values) is type(''):
158                 allvalues += (values,)
159                 where.append('_%s = %s' % (prop, a))
160             elif values is None:
161                 where.append('_%s is NULL'%prop)
162             else:
163                 allvalues += tuple(values.keys())
164                 where.append('_%s in (%s)' % (prop, ','.join([a]*len(values))))
165         tables = []
166         if where:
167             self.db.sql('SELECT id AS nodeid FROM _%s WHERE %s' % (
168                 self.classname, ' and '.join(where)), allvalues)
169             l += [x[0] for x in self.db.sql_fetchall()]
171         # now multilinks
172         for prop, values in propspec:
173             vals = ()
174             if not isinstance(props[prop], hyperdb.Multilink):
175                 continue
176             if type(values) is type(''):
177                 vals = (values,)
178                 s = a
179             else:
180                 vals = tuple(values.keys())
181                 s = ','.join([a]*len(values))
182             query = 'SELECT nodeid FROM %s_%s WHERE linkid IN (%s)'%(
183                 self.classname, prop, s)
184             self.db.sql(query, vals)
185             l += [x[0] for x in self.db.sql_fetchall()]
186             
187         if __debug__:
188             print >>hyperdb.DEBUG, 'find ... ', l
190         # Remove duplicated ids
191         d = {}
192         for k in l:
193             d[k] = 1
194         return d.keys()
196         return l
198 class Class(PsycopgClass, rdbms_common.Class):
199     pass
200 class IssueClass(PsycopgClass, rdbms_common.IssueClass):
201     pass
202 class FileClass(PsycopgClass, rdbms_common.FileClass):
203     pass