--- a/selectors.py Wed Feb 18 13:41:07 2009 +0100
+++ b/selectors.py Wed Feb 18 13:41:55 2009 +0100
@@ -54,7 +54,6 @@
from cubicweb import Unauthorized, NoSelectableObject, role
from cubicweb.vregistry import (NoSelectableObject, Selector,
chainall, chainfirst, objectify_selector)
-from cubicweb.cwvreg import DummyCursorError
from cubicweb.cwconfig import CubicWebConfiguration
from cubicweb.schema import split_expression
@@ -70,7 +69,7 @@
# /!\ lltrace decorates pure function or __call__ method, this
# means argument order may be different
if isinstance(cls, Selector):
- selname = cls.__class__.__name__
+ selname = str(cls)
vobj = args[0]
else:
selname = selector.__name__
@@ -128,8 +127,9 @@
- `once_is_enough` is False, in which case if score_class return 0, 0 is
returned
"""
- def __init__(self, once_is_enough=False):
+ def __init__(self, once_is_enough=False, sumscores=True):
self.once_is_enough = once_is_enough
+ self.sumscores = sumscores
@lltrace
def __call__(self, cls, req, rset, row=None, col=0, **kwargs):
@@ -150,7 +150,7 @@
etype = rset.description[row][col]
if etype is not None:
score = self.score(cls, req, etype)
- return score and (score + 1)
+ return score
def score(self, cls, req, etype):
if etype in BASE_TYPES:
@@ -197,7 +197,7 @@
etype = rset.description[row][col]
if etype is not None: # outer join
score = self.score(req, rset, row, col)
- return score and (score + 1)
+ return score
def score(self, req, rset, row, col):
try:
@@ -386,6 +386,10 @@
def __init__(self, *expected):
assert expected, self
self.expected = frozenset(expected)
+
+ def __str__(self):
+ return '%s(%s)' % (self.__class__.__name__,
+ ','.join(sorted(str(s) for s in self.expected)))
@lltrace
def __call__(self, cls, req, rset, row=None, col=0, **kwargs):
@@ -502,6 +506,10 @@
super(implements, self).__init__()
self.expected_ifaces = expected_ifaces
+ def __str__(self):
+ return '%s(%s)' % (self.__class__.__name__,
+ ','.join(str(s) for s in self.expected_ifaces))
+
def score_class(self, eclass, req):
score = 0
for iface in self.expected_ifaces:
@@ -509,21 +517,23 @@
# entity type
iface = eclass.vreg.etype_class(iface)
if implements_iface(eclass, iface):
- score += 1
if getattr(iface, '__registry__', None) == 'etypes':
- score += 1
# adjust score if the interface is an entity class
if iface is eclass:
- score += (len(eclass.e_schema.ancestors()) + 1)
-# print 'is majoration', len(eclass.e_schema.ancestors())
- else:
+ score += len(eclass.e_schema.ancestors()) + 4
+ else:
parents = [e.type for e in eclass.e_schema.ancestors()]
for index, etype in enumerate(reversed(parents)):
basecls = eclass.vreg.etype_class(etype)
if iface is basecls:
- score += index
-# print 'etype majoration', index
+ score += index + 3
break
+ else: # Any
+ score += 1
+ else:
+ # implenting an interface takes precedence other special Any
+ # interface
+ score += 2
return score
--- a/test/unittest_selectors.py Wed Feb 18 13:41:07 2009 +0100
+++ b/test/unittest_selectors.py Wed Feb 18 13:41:55 2009 +0100
@@ -8,6 +8,9 @@
from logilab.common.testlib import TestCase, unittest_main
from cubicweb.vregistry import Selector, AndSelector, OrSelector
+from cubicweb.selectors import implements
+
+from cubicweb.interfaces import IDownloadable
class _1_(Selector):
def __call__(self, *args, **kwargs):
@@ -74,7 +77,26 @@
self.assertEquals(len(selector.selectors), 2)
self.assertEquals(selector(None), 2)
+ def test_search_selectors(self):
+ sel = implements('something')
+ self.assertIs(sel.search_selector(implements), sel)
+ csel = AndSelector(sel, Selector())
+ self.assertIs(csel.search_selector(implements), sel)
+ csel = AndSelector(Selector(), sel)
+ self.assertIs(csel.search_selector(implements), sel)
+
+from cubicweb.devtools.testlib import EnvBasedTC
+class ImplementsSelectorTC(EnvBasedTC):
+ def test_etype_priority(self):
+ req = self.request()
+ cls = self.vreg.etype_class('File')
+ anyscore = implements('Any').score_class(cls, req)
+ idownscore = implements(IDownloadable).score_class(cls, req)
+ self.failUnless(idownscore > anyscore, (idownscore, anyscore))
+ filescore = implements('File').score_class(cls, req)
+ self.failUnless(filescore > idownscore, (filescore, idownscore))
+
if __name__ == '__main__':
unittest_main()