EntitySelector base class now understand 'entity' in kwargs, new entity_implements selector tls-sprint
authorsylvain.thenault@logilab.fr
Wed, 08 Apr 2009 20:30:25 +0200
branchtls-sprint
changeset 1301 4596ce9bb4dc
parent 1300 62d2b890a980
child 1302 dd984d682ab0
EntitySelector base class now understand 'entity' in kwargs, new entity_implements selector
selectors.py
--- a/selectors.py	Wed Apr 08 14:11:34 2009 +0200
+++ b/selectors.py	Wed Apr 08 20:30:25 2009 +0200
@@ -113,6 +113,7 @@
 
 
 # abstract selectors ##########################################################
+
 class PartialSelectorMixIn(object):
     """convenience mix-in for selectors that will look into the containing
     class to find missing information.
@@ -123,6 +124,49 @@
         self.complete(cls)
         return super(PartialSelectorMixIn, self).__call__(cls, *args, **kwargs)
 
+
+class ImplementsMixIn(object):
+    """mix-in class for selectors checking implemented interfaces of something
+    """
+    def __init__(self, *expected_ifaces):
+        super(ImplementsMixIn, self).__init__()
+        self.expected_ifaces = expected_ifaces
+
+    def __str__(self):
+        return '%s(%s)' % (self.__class__.__name__,
+                           ','.join(str(s) for s in self.expected_ifaces))
+    
+    def score_interfaces(self, cls_or_inst, cls):
+        score = 0
+        vreg, eschema = cls_or_inst.vreg, cls_or_inst.e_schema
+        for iface in self.expected_ifaces:
+            if isinstance(iface, basestring):
+                # entity type
+                try:
+                    iface = vreg.etype_class(iface)
+                except KeyError:
+                    continue # entity type not in the schema
+            if implements_iface(cls_or_inst, iface):
+                if getattr(iface, '__registry__', None) == 'etypes':
+                    # adjust score if the interface is an entity class
+                    if iface is cls:
+                        score += len(eschema.ancestors()) + 4
+                    else: 
+                        parents = [e.type for e in eschema.ancestors()]
+                        for index, etype in enumerate(reversed(parents)):
+                            basecls = vreg.etype_class(etype)
+                            if iface is basecls:
+                                score += index + 3
+                                break
+                        else: # Any
+                            score += 1
+                else:
+                    # implenting an interface takes precedence other special Any
+                    # interface
+                    score += 2
+        return score
+
+
 class EClassSelector(Selector):
     """abstract class for selectors working on the entity classes of the result
     set. Its __call__ method has the following behaviour:
@@ -173,6 +217,8 @@
     """abstract class for selectors working on the entity instances of the
     result set. Its __call__ method has the following behaviour:
 
+    * if 'entity' find in kwargs, return the score returned by the score_entity
+      method for this entity
     * if row is specified, return the score returned by the score_entity method
       called with the entity instance found in the specified cell
     * else return the sum of score returned by the score_entity method for each
@@ -188,10 +234,12 @@
     
     @lltrace
     def __call__(self, cls, req, rset, row=None, col=0, **kwargs):
-        if not rset:
+        if not rset and not kwargs.get('entity'):
             return 0
         score = 0
-        if row is None:
+        if kwargs.get('entity'):
+            score = self.score_entity(kwargs['entity'])
+        elif row is None:
             for row, rowvalue in enumerate(rset.rows):
                 if rowvalue[col] is None: # outer join
                     continue
@@ -508,8 +556,8 @@
 
 # not so basic selectors ######################################################
 
-class implements(EClassSelector):
-    """accept if entity class found in the result set implements at least one
+class implements(ImplementsMixIn, EClassSelector):
+    """accept if entity classes found in the result set implements at least one
     of the interfaces given as argument. Returned score is the number of
     implemented interfaces.
 
@@ -523,42 +571,8 @@
     note: when interface is an entity class, the score will reflect class
           proximity so the most specific object'll be selected
     """
-    def __init__(self, *expected_ifaces):
-        super(implements, self).__init__()
-        self.expected_ifaces = expected_ifaces
-
-    def __str__(self):
-        return '%s(%s)' % (self.__class__.__name__,
-                           ','.join(str(s) for s in self.expected_ifaces))
-    
     def score_class(self, eclass, req):
-        score = 0
-        for iface in self.expected_ifaces:
-            if isinstance(iface, basestring):
-                # entity type
-                try:
-                    iface = eclass.vreg.etype_class(iface)
-                except KeyError:
-                    continue # entity type not in the schema
-            if implements_iface(eclass, iface):
-                if getattr(iface, '__registry__', None) == 'etypes':
-                    # adjust score if the interface is an entity class
-                    if iface is eclass:
-                        score += len(eclass.e_schema.ancestors()) + 4
-                    else: 
-                        parents = [e.type for e in eclass.e_schema.ancestors()]
-                        for index, etype in enumerate(reversed(parents)):
-                            basecls = eclass.vreg.etype_class(etype)
-                            if iface is basecls:
-                                score += index + 3
-                                break
-                        else: # Any
-                            score += 1
-                else:
-                    # implenting an interface takes precedence other special Any
-                    # interface
-                    score += 2
-        return score
+        return self.score_interfaces(eclass, eclass)
 
 
 class specified_etype_implements(implements):
@@ -587,6 +601,25 @@
         return self.score_class(cls.vreg.etype_class(etype), req)
 
 
+class entity_implements(ImplementsMixIn, EntitySelector):
+    """accept if entity instances found in the result set implements at least one
+    of the interfaces given as argument. Returned score is the number of
+    implemented interfaces.
+
+    See `EntitySelector` documentation for behaviour when row is not specified.
+
+    :param *expected_ifaces: expected interfaces. An interface may be a class
+                             or an entity type (e.g. `basestring`) in which case
+                             the associated class will be searched in the
+                             registry (at selection time)
+                             
+    note: when interface is an entity class, the score will reflect class
+          proximity so the most specific object'll be selected
+    """    
+    def score_entity(self, entity):
+        return self.score_interfaces(entity, entity.__class__)
+
+
 class relation_possible(EClassSelector):
     """accept if entity class found in the result set support the relation.