[req] New method: RequestSessionBase.find().
authorChristophe de Vienne <cdevienne@gmail.com>
Wed, 11 Dec 2013 12:22:52 +0100
changeset 9348 eacd02792332
parent 9347 bd841d6ae723
child 9358 1e0235478403
[req] New method: RequestSessionBase.find(). This method does what find_entities and find_one_entity did, except it returns the resultset itself. In addition, it accepts 'reverse_' arguments and check that the relations actually exists on the entity before executing the query. Also, reimplement find_one_entity and find_entity based on the new function so they benefit from the more complete implementation, and deprecate them. Note: List of values in kwargs are NOT supported in this initial implementation. Closes #3361290
req.py
test/unittest_req.py
--- a/req.py	Wed Dec 11 17:52:54 2013 +0100
+++ b/req.py	Wed Dec 11 12:22:52 2013 +0100
@@ -29,7 +29,10 @@
 from logilab.common.deprecation import deprecated
 from logilab.common.date import ustrftime, strptime, todate, todatetime
 
-from cubicweb import Unauthorized, NoSelectableObject, uilib
+from rql.utils import rqlvar_maker
+
+from cubicweb import (Unauthorized, NoSelectableObject, NoResultError,
+                      MultipleResultsError, uilib)
 from cubicweb.rset import ResultSet
 
 ONESECOND = timedelta(0, 1, 0)
@@ -152,16 +155,16 @@
         cls = self.vreg['etypes'].etype_class(etype)
         return cls.cw_instantiate(self.execute, **kwargs)
 
+    @deprecated('[3.18] use find(etype, **kwargs).entities()')
     def find_entities(self, etype, **kwargs):
         """find entities of the given type and attribute values.
 
         >>> users = find_entities('CWGroup', name=u'users')
         >>> groups = find_entities('CWGroup')
         """
-        parts = ['Any X WHERE X is %s' % etype]
-        parts.extend('X %(attr)s %%(%(attr)s)s' % {'attr': attr} for attr in kwargs)
-        return self.execute(', '.join(parts), kwargs).entities()
+        return self.find(etype, **kwargs).entities()
 
+    @deprecated('[3.18] use find(etype, **kwargs).one()')
     def find_one_entity(self, etype, **kwargs):
         """find one entity of the given type and attribute values.
         raise :exc:`FindEntityError` if can not return one and only one entity.
@@ -170,14 +173,43 @@
         >>> groups = find_one_entity('CWGroup')
         Exception()
         """
+        try:
+            return self.find(etype, **kwargs).one()
+        except (NoResultError, MultipleResultsError) as e:
+            raise FindEntityError("%s: (%s, %s)" % (str(e), etype, kwargs))
+
+    def find(self, etype, **kwargs):
+        """find entities of the given type and attribute values.
+
+        :returns: A :class:`ResultSet`
+
+        >>> users = find('CWGroup', name=u"users").one()
+        >>> groups = find('CWGroup').entities()
+        """
         parts = ['Any X WHERE X is %s' % etype]
-        parts.extend('X %(attr)s %%(%(attr)s)s' % {'attr': attr} for attr in kwargs)
+        varmaker = rqlvar_maker(defined='X')
+        eschema = self.vreg.schema[etype]
+        for attr, value in kwargs.items():
+            if isinstance(value, list) or isinstance(value, tuple):
+                raise NotImplementedError("List of values are not supported")
+            if hasattr(value, 'eid'):
+                kwargs[attr] = value.eid
+            if attr.startswith('reverse_'):
+                attr = attr[8:]
+                assert attr in eschema.objrels, \
+                    '%s not in %s object relations' % (attr, eschema)
+                parts.append(
+                    '%(varname)s %(attr)s X, '
+                    '%(varname)s eid %%(reverse_%(attr)s)s'
+                    % {'attr': attr, 'varname': varmaker.next()})
+            else:
+                assert attr in eschema.subjrels, \
+                    '%s not in %s subject relations' % (attr, eschema)
+                parts.append('X %(attr)s %%(%(attr)s)s' % {'attr': attr})
+
         rql = ', '.join(parts)
-        rset = self.execute(rql, kwargs)
-        if len(rset) != 1:
-            raise FindEntityError('Found %i entitie(s) when 1 was expected (rql=%s ; %s)'
-                                  % (len(rset), rql, repr(kwargs)))
-        return rset.get_entity(0,0)
+
+        return self.execute(rql, kwargs)
 
     def ensure_ro_rql(self, rql):
         """raise an exception if the given rql is not a select query"""
--- a/test/unittest_req.py	Wed Dec 11 17:52:54 2013 +0100
+++ b/test/unittest_req.py	Wed Dec 11 12:22:52 2013 +0100
@@ -18,7 +18,7 @@
 
 from logilab.common.testlib import TestCase, unittest_main
 from cubicweb import ObjectNotFound
-from cubicweb.req import RequestSessionBase
+from cubicweb.req import RequestSessionBase, FindEntityError
 from cubicweb.devtools.testlib import CubicWebTC
 from cubicweb import Unauthorized
 
@@ -59,5 +59,81 @@
         self.assertEqual(req.view('oneline', rset, 'null'), '')
         self.assertRaises(ObjectNotFound, req.view, 'onelinee', rset, 'null')
 
+    def test_find_one_entity(self):
+        self.request().create_entity(
+            'CWUser', login=u'cdevienne', upassword=u'cdevienne',
+            surname=u'de Vienne', firstname=u'Christophe',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        self.request().create_entity(
+            'CWUser', login=u'adim', upassword='adim', surname=u'di mascio',
+            firstname=u'adrien',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        u = self.request().find_one_entity('CWUser', login=u'cdevienne')
+        self.assertEqual(u.firstname, u"Christophe")
+
+        with self.assertRaises(FindEntityError):
+            self.request().find_one_entity('CWUser', login=u'patanok')
+
+        with self.assertRaises(FindEntityError):
+            self.request().find_one_entity('CWUser')
+
+    def test_find_entities(self):
+        self.request().create_entity(
+            'CWUser', login=u'cdevienne', upassword=u'cdevienne',
+            surname=u'de Vienne', firstname=u'Christophe',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        self.request().create_entity(
+            'CWUser', login=u'adim', upassword='adim', surname=u'di mascio',
+            firstname=u'adrien',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        l = list(self.request().find_entities('CWUser', login=u'cdevienne'))
+        self.assertEqual(1, len(l))
+        self.assertEqual(l[0].firstname, u"Christophe")
+
+        l = list(self.request().find_entities('CWUser', login=u'patanok'))
+        self.assertEqual(0, len(l))
+
+        l = list(self.request().find_entities('CWUser'))
+        self.assertEqual(4, len(l))
+
+    def test_find(self):
+        self.request().create_entity(
+            'CWUser', login=u'cdevienne', upassword=u'cdevienne',
+            surname=u'de Vienne', firstname=u'Christophe',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        self.request().create_entity(
+            'CWUser', login=u'adim', upassword='adim', surname=u'di mascio',
+            firstname=u'adrien',
+            in_group=self.request().find('CWGroup', name=u'users').one())
+
+        u = self.request().find('CWUser', login=u'cdevienne').one()
+        self.assertEqual(u.firstname, u"Christophe")
+
+        users = list(self.request().find('CWUser').entities())
+        self.assertEqual(len(users), 4)
+
+        groups = list(
+            self.request().find('CWGroup', reverse_in_group=u).entities())
+        self.assertEqual(len(groups), 1)
+        self.assertEqual(groups[0].name, u'users')
+
+        users = self.request().find('CWUser', in_group=groups[0]).entities()
+        users = list(users)
+        self.assertEqual(len(users), 2)
+
+        with self.assertRaises(AssertionError):
+            self.request().find('CWUser', chapeau=u"melon")
+
+        with self.assertRaises(AssertionError):
+            self.request().find('CWUser', reverse_buddy=users[0])
+
+        with self.assertRaises(NotImplementedError):
+            self.request().find('CWUser', in_group=[1, 2])
+
 if __name__ == '__main__':
     unittest_main()