Code

more compliance testing, this time for find()
authorrichard <richard@57a73879-2fb5-44c3-a270-3262357dd7e2>
Tue, 20 Jan 2004 05:55:51 +0000 (05:55 +0000)
committerrichard <richard@57a73879-2fb5-44c3-a270-3262357dd7e2>
Tue, 20 Jan 2004 05:55:51 +0000 (05:55 +0000)
git-svn-id: http://svn.roundup-tracker.org/svnroot/roundup/trunk@2056 57a73879-2fb5-44c3-a270-3262357dd7e2

roundup/backends/back_mysql.py
roundup/backends/rdbms_common.py
test/db_test_base.py

index 7ead40200d1ca568ead60ad320728171f6593339..fa3c6d4ccb97f50d947083bd5159e05066885105 100644 (file)
@@ -411,92 +411,6 @@ class MysqlClass:
         # return the IDs (the first column)
         return [row[0] for row in l]
 
-    # mysql doesn't implement INTERSECT
-    def find(self, **propspec):
-        '''Get the ids of nodes in this class which link to the given nodes.
-
-        'propspec' consists of keyword args propname=nodeid or
-                   propname={nodeid:1, }
-        'propname' must be the name of a property in this class, or a
-                   KeyError is raised.  That property must be a Link or
-                   Multilink property, or a TypeError is raised.
-
-        Any node in this class whose 'propname' property links to any of the
-        nodeids will be returned. Used by the full text indexing, which knows
-        that "foo" occurs in msg1, msg3 and file7, so we have hits on these
-        issues:
-
-            db.issue.find(messages={'1':1,'3':1}, files={'7':1})
-        '''
-        if __debug__:
-            print >>hyperdb.DEBUG, 'find', (self, propspec)
-
-        # shortcut
-        if not propspec:
-            return []
-
-        # validate the args
-        props = self.getprops()
-        propspec = propspec.items()
-        for propname, nodeids in propspec:
-            # check the prop is OK
-            prop = props[propname]
-            if not isinstance(prop, Link) and not isinstance(prop, Multilink):
-                raise TypeError, "'%s' not a Link/Multilink property"%propname
-
-        # first, links
-        a = self.db.arg
-        where = ['__retired__ <> %s'%a]
-        allvalues = (1,)
-        for prop, values in propspec:
-            if not isinstance(props[prop], hyperdb.Link):
-                continue
-            if type(values) is type({}) and len(values) == 1:
-                values = values.keys()[0]
-            if type(values) is type(''):
-                allvalues += (values,)
-                where.append('_%s = %s'%(prop, a))
-            elif values is None:
-                where.append('_%s is NULL'%prop)
-            else:
-                allvalues += tuple(values.keys())
-                where.append('_%s in (%s)'%(prop, ','.join([a]*len(values))))
-        tables = []
-        if where:
-            tables.append('select id as nodeid from _%s where %s'%(
-                self.classname, ' and '.join(where)))
-
-        # now multilinks
-        for prop, values in propspec:
-            if not isinstance(props[prop], hyperdb.Multilink):
-                continue
-            if type(values) is type(''):
-                allvalues += (values,)
-                s = a
-            else:
-                allvalues += tuple(values.keys())
-                s = ','.join([a]*len(values))
-            tables.append('select nodeid from %s_%s where linkid in (%s)'%(
-                self.classname, prop, s))
-
-        raise NotImplemented, "XXX this code's farked"
-        d = {}
-        self.db.sql(sql, allvalues)
-        for result in self.db.sql_fetchall():
-            d[result[0]] = 1
-
-        for query in tables[1:]:
-            self.db.sql(sql, allvalues)
-            for result in self.db.sql_fetchall():
-                if not d.has_key(result[0]):
-                    continue
-
-        if __debug__:
-            print >>hyperdb.DEBUG, 'find ... ', l
-        l = d.keys()
-        l.sort()
-        return l
-
 class Class(MysqlClass, rdbms_common.Class):
     pass
 class IssueClass(MysqlClass, rdbms_common.IssueClass):
index d9e2d582b349ce80246c6fabf125a0caa81c3358..aebd2594fcde4a48d6f10051b35eae18346d73e4 100644 (file)
@@ -1,4 +1,4 @@
-# $Id: rdbms_common.py,v 1.73 2004-01-20 03:58:38 richard Exp $
+# $Id: rdbms_common.py,v 1.74 2004-01-20 05:55:51 richard Exp $
 ''' Relational database (SQL) backend common code.
 
 Basics:
@@ -1789,8 +1789,9 @@ class Class(hyperdb.Class):
 
         # first, links
         a = self.db.arg
-        where = ['__retired__ <> %s'%a]
         allvalues = (1,)
+        o = []
+        where = []
         for prop, values in propspec:
             if not isinstance(props[prop], hyperdb.Link):
                 continue
@@ -1804,25 +1805,34 @@ class Class(hyperdb.Class):
             else:
                 allvalues += tuple(values.keys())
                 where.append('_%s in (%s)'%(prop, ','.join([a]*len(values))))
-        tables = []
+        tables = ['_%s'%self.classname]
         if where:
-            tables.append('select id as nodeid from _%s where %s'%(
-                self.classname, ' and '.join(where)))
+            o.append('(' + ' and '.join(where) + ')')
 
         # now multilinks
         for prop, values in propspec:
             if not isinstance(props[prop], hyperdb.Multilink):
                 continue
+            if not values:
+                continue
             if type(values) is type(''):
                 allvalues += (values,)
                 s = a
             else:
                 allvalues += tuple(values.keys())
                 s = ','.join([a]*len(values))
-            tables.append('select nodeid from %s_%s where linkid in (%s)'%(
-                self.classname, prop, s))
+            tn = '%s_%s'%(self.classname, prop)
+            tables.append(tn)
+            o.append('(id=%s.nodeid and %s.linkid in (%s))'%(tn, tn, s))
 
-        sql = '\nintersect\n'.join(tables)
+        if not o:
+            return []
+        elif len(o) > 1:
+            o = '(' + ' or '.join(['(%s)'%i for i in o]) + ')'
+        else:
+            o = o[0]
+        t = ', '.join(tables)
+        sql = 'select distinct(id) from %s where __retired__ <> %s and %s'%(t, a, o)
         self.db.sql(sql, allvalues)
         l = [x[0] for x in self.db.sql_fetchall()]
         if __debug__:
index 9945ef4ded0568b2f9312f56af28bc7c529324d5..ed021c408aabb84bf422f2469895dadd62fe0929 100644 (file)
@@ -15,7 +15,7 @@
 # BASIS, AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
 # SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
 # 
-# $Id: db_test_base.py,v 1.13 2004-01-20 03:58:38 richard Exp $ 
+# $Id: db_test_base.py,v 1.14 2004-01-20 05:55:51 richard Exp $ 
 
 import unittest, os, shutil, errno, imp, sys, time, pprint
 
@@ -621,54 +621,88 @@ class DBTest(MyTestCase):
     #
     # searching tests follow
     #
-    def testFind(self):
+    def testFindIncorrectProperty(self):
         self.assertRaises(TypeError, self.db.issue.find, title='fubar')
 
-        self.db.user.create(username='test')
-        ids = []
-        ids.append(self.db.issue.create(status="1", nosy=['1']))
-        oddid = self.db.issue.create(status="2", nosy=['2'], assignedto='2')
-        ids.append(self.db.issue.create(status="1", nosy=['1','2']))
-        self.db.issue.create(status="3", nosy=['1'], assignedto='1')
-        ids.sort()
-
-        # should match first and third
+    def _find_test_setup(self):
+        self.db.file.create(content='')
+        self.db.file.create(content='')
+        self.db.user.create(username='')
+        one = self.db.issue.create(status="1", nosy=['1'])
+        two = self.db.issue.create(status="2", nosy=['2'], files=['1'],
+            assignedto='2')
+        three = self.db.issue.create(status="1", nosy=['1','2'])
+        four = self.db.issue.create(status="3", assignedto='1',
+            files=['1','2'])
+        return one, two, three, four
+
+    def testFindLink(self):
+        one, two, three, four = self._find_test_setup()
         got = self.db.issue.find(status='1')
         got.sort()
-        self.assertEqual(got, ids)
+        self.assertEqual(got, [one, three])
         got = self.db.issue.find(status={'1':1})
         got.sort()
-        self.assertEqual(got, ids)
+        self.assertEqual(got, [one, three])
 
-        # none
+    def testFindLinkFail(self):
+        self._find_test_setup()
         self.assertEqual(self.db.issue.find(status='4'), [])
         self.assertEqual(self.db.issue.find(status={'4':1}), [])
 
-        # should match first and third
+    def testFindLinkUnset(self):
+        one, two, three, four = self._find_test_setup()
         got = self.db.issue.find(assignedto=None)
         got.sort()
-        self.assertEqual(got, ids)
+        self.assertEqual(got, [one, three])
         got = self.db.issue.find(assignedto={None:1})
         got.sort()
-        self.assertEqual(got, ids)
+        self.assertEqual(got, [one, three])
+
+    def testFindMultilink(self):
+        one, two, three, four = self._find_test_setup()
+        got = self.db.issue.find(nosy='2')
+        got.sort()
+        self.assertEqual(got, [two, three])
+        got = self.db.issue.find(nosy={'2':1})
+        got.sort()
+        self.assertEqual(got, [two, three])
+        got = self.db.issue.find(nosy={'2':1}, files={})
+        got.sort()
+        self.assertEqual(got, [two, three])
 
-        # should match first three
+    def testFindMultiMultilink(self):
+        one, two, three, four = self._find_test_setup()
+        got = self.db.issue.find(nosy='2', files='1')
+        got.sort()
+        self.assertEqual(got, [two, three, four])
+        got = self.db.issue.find(nosy={'2':1}, files={'1':1})
+        got.sort()
+        self.assertEqual(got, [two, three, four])
+
+    def testFindMultilinkFail(self):
+        self._find_test_setup()
+        self.assertEqual(self.db.issue.find(nosy='3'), [])
+        self.assertEqual(self.db.issue.find(nosy={'3':1}), [])
+
+    def testFindMultilinkUnset(self):
+        self._find_test_setup()
+        self.assertEqual(self.db.issue.find(nosy={}), [])
+
+    def testFindLinkAndMultilink(self):
+        one, two, three, four = self._find_test_setup()
         got = self.db.issue.find(status='1', nosy='2')
         got.sort()
-        ids.append(oddid)
-        ids.sort()
-        self.assertEqual(got, ids)
+        self.assertEqual(got, [one, two, three])
         got = self.db.issue.find(status={'1':1}, nosy={'2':1})
         got.sort()
-        self.assertEqual(got, ids)
-
-        # none
-        self.assertEqual(self.db.issue.find(status='4', nosy='3'), [])
-        self.assertEqual(self.db.issue.find(status={'4':1}, nosy={'3':1}), [])
+        self.assertEqual(got, [one, two, three])
 
-        # test retiring a node
-        self.db.issue.retire(ids[0])
-        self.assertEqual(len(self.db.issue.find(status='1', nosy='2')), 2)
+    def testFindRetired(self):
+        one, two, three, four = self._find_test_setup()
+        self.assertEqual(len(self.db.issue.find(status='1')), 2)
+        self.db.issue.retire(one)
+        self.assertEqual(len(self.db.issue.find(status='1')), 1)
 
     def testStringFind(self):
         self.assertRaises(TypeError, self.db.issue.stringFind, status='1')