entity.py
brancholdstable
changeset 7676 cc3987eb793c
parent 7509 c69dd872e5d7
child 7514 32081892850e
--- a/entity.py	Mon May 16 16:24:00 2011 +0200
+++ b/entity.py	Wed Jul 20 18:21:47 2011 +0200
@@ -28,7 +28,7 @@
 
 from rql.utils import rqlvar_maker
 
-from cubicweb import Unauthorized, typed_eid
+from cubicweb import Unauthorized, typed_eid, neg_role
 from cubicweb.rset import ResultSet
 from cubicweb.selectors import yes
 from cubicweb.appobject import AppObject
@@ -62,6 +62,23 @@
     return True
 
 
+def remove_ambiguous_rels(attr_set, subjtypes, schema):
+    '''remove from `attr_set` the relations of entity types `subjtypes` that have
+    different entity type sets as target'''
+    for attr in attr_set.copy():
+        rschema = schema.rschema(attr)
+        if rschema.final:
+            continue
+        ttypes = None
+        for subjtype in subjtypes:
+            cur_ttypes = rschema.objects(subjtype)
+            if ttypes is None:
+                ttypes = cur_ttypes
+            elif cur_ttypes != ttypes:
+                attr_set.remove(attr)
+                break
+
+
 class Entity(AppObject):
     """an entity instance has e_schema automagically set on
     the class and instances has access to their issuing cursor.
@@ -91,7 +108,7 @@
     # class attributes that must be set in class definition
     rest_attr = None
     fetch_attrs = None
-    skip_copy_for = ('in_state',)
+    skip_copy_for = ('in_state',) # XXX turn into a set
     # class attributes set automatically at registration time
     e_schema = None
 
@@ -157,6 +174,7 @@
     def fetch_rql(cls, user, restriction=None, fetchattrs=None, mainvar='X',
                   settype=True, ordermethod='fetch_order'):
         """return a rql to fetch all entities of the class type"""
+        # XXX update api and implementation to AST manipulation (see unrelated rql)
         restrictions = restriction or []
         if settype:
             restrictions.append('%s is %s' % (mainvar, cls.__regid__))
@@ -165,6 +183,7 @@
         selection = [mainvar]
         orderby = []
         # start from 26 to avoid possible conflicts with X
+        # XXX not enough to be sure it'll be no conflicts
         varmaker = rqlvar_maker(index=26)
         cls._fetch_restrictions(mainvar, varmaker, fetchattrs, selection,
                                 orderby, restrictions, user, ordermethod)
@@ -202,8 +221,6 @@
             restriction = '%s %s %s' % (mainvar, attr, var)
             restrictions.append(restriction)
             if not rschema.final:
-                # XXX this does not handle several destination types
-                desttype = rschema.objects(eschema.type)[0]
                 card = rdef.cardinality[0]
                 if card not in '?1':
                     cls.warning('bad relation %s specified in fetch attrs for %s',
@@ -216,11 +233,18 @@
                 # that case the relation may still be missing. As we miss this
                 # later information here, systematically add it.
                 restrictions[-1] += '?'
+                targettypes = rschema.objects(eschema.type)
                 # XXX user._cw.vreg iiiirk
-                destcls = user._cw.vreg['etypes'].etype_class(desttype)
-                destcls._fetch_restrictions(var, varmaker, destcls.fetch_attrs,
-                                            selection, orderby, restrictions,
-                                            user, ordermethod, visited=visited)
+                etypecls = user._cw.vreg['etypes'].etype_class(targettypes[0])
+                if len(targettypes) > 1:
+                    # find fetch_attrs common to all destination types
+                    fetchattrs = user._cw.vreg['etypes'].fetch_attrs(targettypes)
+                    remove_ambiguous_rels(fetchattrs, targettypes, user._cw.vreg.schema)
+                else:
+                    fetchattrs = etypecls.fetch_attrs
+                etypecls._fetch_restrictions(var, varmaker, fetchattrs,
+                                             selection, orderby, restrictions,
+                                             user, ordermethod, visited=visited)
             if ordermethod is not None:
                 orderterm = getattr(cls, ordermethod)(attr, var)
                 if orderterm:
@@ -264,6 +288,7 @@
         restrictions = set()
         pending_relations = []
         eschema = cls.e_schema
+        qargs = {}
         for attr, value in kwargs.items():
             if attr.startswith('reverse_'):
                 attr = attr[len('reverse_'):]
@@ -277,10 +302,11 @@
                     value = iter(value).next()
                 else:
                     # prepare IN clause
-                    del kwargs[attr]
-                    pending_relations.append( (attr, value) )
+                    pending_relations.append( (attr, role, value) )
                     continue
-            if hasattr(value, 'eid'): # non final relation
+            if rschema.final: # attribute
+                relations.append('X %s %%(%s)s' % (attr, attr))
+            else:
                 rvar = attr.upper()
                 if role == 'object':
                     relations.append('%s %s X' % (rvar, attr))
@@ -289,21 +315,21 @@
                 restriction = '%s eid %%(%s)s' % (rvar, attr)
                 if not restriction in restrictions:
                     restrictions.add(restriction)
-                kwargs[attr] = value.eid
-            else: # attribute
-                relations.append('X %s %%(%s)s' % (attr, attr))
+                if hasattr(value, 'eid'):
+                    value = value.eid
+            qargs[attr] = value
         if relations:
             rql = '%s: %s' % (rql, ', '.join(relations))
         if restrictions:
             rql = '%s WHERE %s' % (rql, ', '.join(restrictions))
-        created = execute(rql, kwargs).get_entity(0, 0)
-        for attr, values in pending_relations:
-            if attr.startswith('reverse_'):
-                restr = 'Y %s X' % attr[len('reverse_'):]
+        created = execute(rql, qargs).get_entity(0, 0)
+        for attr, role, values in pending_relations:
+            if role == 'object':
+                restr = 'Y %s X' % attr
             else:
                 restr = 'X %s Y' % attr
             execute('SET %s WHERE X eid %%(x)s, Y eid IN (%s)' % (
-                restr, ','.join(str(r.eid) for r in values)),
+                restr, ','.join(str(getattr(r, 'eid', r)) for r in values)),
                     {'x': created.eid}, build_descr=False)
         return created
 
@@ -728,17 +754,17 @@
             else:
                 restriction += ', X is IN (%s)' % ','.join(targettypes)
             card = greater_card(rschema, targettypes, (self.e_schema,), 1)
+        etypecls = self._cw.vreg['etypes'].etype_class(targettypes[0])
         if len(targettypes) > 1:
-            fetchattrs_list = []
-            for ttype in targettypes:
-                etypecls = self._cw.vreg['etypes'].etype_class(ttype)
-                fetchattrs_list.append(set(etypecls.fetch_attrs))
-            fetchattrs = reduce(set.intersection, fetchattrs_list)
-            rql = etypecls.fetch_rql(self._cw.user, [restriction], fetchattrs,
-                                     settype=False)
+            fetchattrs = self._cw.vreg['etypes'].fetch_attrs(targettypes)
+            # XXX we should fetch ambiguous relation objects too but not
+            # recurse on them in _fetch_restrictions; it is easier to remove
+            # them completely for now, as it would require an deeper api rewrite
+            remove_ambiguous_rels(fetchattrs, targettypes, self._cw.vreg.schema)
         else:
-            etypecls = self._cw.vreg['etypes'].etype_class(targettypes[0])
-            rql = etypecls.fetch_rql(self._cw.user, [restriction], settype=False)
+            fetchattrs = etypecls.fetch_attrs
+        rql = etypecls.fetch_rql(self._cw.user, [restriction], fetchattrs,
+                                 settype=False)
         # optimisation: remove ORDERBY if cardinality is 1 or ? (though
         # greater_card return 1 for those both cases)
         if card == '1':
@@ -762,7 +788,7 @@
     # generic vocabulary methods ##############################################
 
     def cw_unrelated_rql(self, rtype, targettype, role, ordermethod=None,
-                      vocabconstraints=True):
+                         vocabconstraints=True):
         """build a rql to fetch `targettype` entities unrelated to this entity
         using (rtype, role) relation.
 
@@ -772,58 +798,83 @@
         ordermethod = ordermethod or 'fetch_unrelated_order'
         if isinstance(rtype, basestring):
             rtype = self._cw.vreg.schema.rschema(rtype)
+        rdef = rtype.role_rdef(self.e_schema, targettype, role)
+        rewriter = RQLRewriter(self._cw)
+        # initialize some variables according to the `role` of `self` in the
+        # relation:
+        # * variable for myself (`evar`) and searched entities (`searchvedvar`)
+        # * entity type of the subject (`subjtype`) and of the object
+        #   (`objtype`) of the relation
         if role == 'subject':
             evar, searchedvar = 'S', 'O'
             subjtype, objtype = self.e_schema, targettype
         else:
             searchedvar, evar = 'S', 'O'
             objtype, subjtype = self.e_schema, targettype
+        # initialize some variables according to `self` existance
+        if rdef.role_cardinality(neg_role(role)) in '?1':
+            # if cardinality in '1?', we want a target entity which isn't
+            # already linked using this relation
+            if searchedvar == 'S':
+                restriction = ['NOT S %s ZZ' % rtype]
+            else:
+                restriction = ['NOT ZZ %s O' % rtype]
+        elif self.has_eid():
+            # elif we have an eid, we don't want a target entity which is
+            # already linked to ourself through this relation
+            restriction = ['NOT S %s O' % rtype]
+        else:
+            restriction = []
         if self.has_eid():
-            restriction = ['NOT S %s O' % rtype, '%s eid %%(x)s' % evar]
+            restriction += ['%s eid %%(x)s' % evar]
             args = {'x': self.eid}
             if role == 'subject':
-                securitycheck_args = {'fromeid': self.eid}
+                sec_check_args = {'fromeid': self.eid}
             else:
-                securitycheck_args = {'toeid': self.eid}
+                sec_check_args = {'toeid': self.eid}
+            existant = None # instead of 'SO', improve perfs
         else:
-            restriction = []
             args = {}
-            securitycheck_args = {}
-        rdef = rtype.role_rdef(self.e_schema, targettype, role)
-        insertsecurity = (rdef.has_local_role('add') and not
-                          rdef.has_perm(self._cw, 'add', **securitycheck_args))
-        # XXX consider constraint.mainvars to check if constraint apply
+            sec_check_args = {}
+            existant = searchedvar
+        # retreive entity class for targettype to compute base rql
+        etypecls = self._cw.vreg['etypes'].etype_class(targettype)
+        rql = etypecls.fetch_rql(self._cw.user, restriction,
+                                 mainvar=searchedvar, ordermethod=ordermethod)
+        select = self._cw.vreg.parse(self._cw, rql, args).children[0]
+        # insert RQL expressions for schema constraints into the rql syntax tree
         if vocabconstraints:
             # RQLConstraint is a subclass for RQLVocabularyConstraint, so they
             # will be included as well
-            restriction += [cstr.restriction for cstr in rdef.constraints
-                            if isinstance(cstr, RQLVocabularyConstraint)]
+            cstrcls = RQLVocabularyConstraint
         else:
-            restriction += [cstr.restriction for cstr in rdef.constraints
-                            if isinstance(cstr, RQLConstraint)]
-        etypecls = self._cw.vreg['etypes'].etype_class(targettype)
-        rql = etypecls.fetch_rql(self._cw.user, restriction,
-                                 mainvar=searchedvar, ordermethod=ordermethod)
+            cstrcls = RQLConstraint
+        for cstr in rdef.constraints:
+            # consider constraint.mainvars to check if constraint apply
+            if isinstance(cstr, cstrcls) and searchedvar in cstr.mainvars:
+                if not self.has_eid() and evar in cstr.mainvars:
+                    continue
+                # compute a varmap suitable to RQLRewriter.rewrite argument
+                varmap = dict((v, v) for v in 'SO' if v in select.defined_vars
+                              and v in cstr.mainvars)
+                # rewrite constraint by constraint since we want a AND between
+                # expressions.
+                rewriter.rewrite(select, [(varmap, (cstr,))], select.solutions,
+                                 args, existant)
+        # insert security RQL expressions granting the permission to 'add' the
+        # relation into the rql syntax tree, if necessary
+        rqlexprs = rdef.get_rqlexprs('add')
+        if rqlexprs and not rdef.has_perm(self._cw, 'add', **sec_check_args):
+            # compute a varmap suitable to RQLRewriter.rewrite argument
+            varmap = dict((v, v) for v in 'SO' if v in select.defined_vars)
+            # rewrite all expressions at once since we want a OR between them.
+            rewriter.rewrite(select, [(varmap, rqlexprs)], select.solutions,
+                             args, existant)
         # ensure we have an order defined
-        if not ' ORDERBY ' in rql:
-            before, after = rql.split(' WHERE ', 1)
-            rql = '%s ORDERBY %s WHERE %s' % (before, searchedvar, after)
-        if insertsecurity:
-            rqlexprs = rdef.get_rqlexprs('add')
-            rewriter = RQLRewriter(self._cw)
-            rqlst = self._cw.vreg.parse(self._cw, rql, args)
-            if not self.has_eid():
-                existant = searchedvar
-            else:
-                existant = None # instead of 'SO', improve perfs
-            for select in rqlst.children:
-                varmap = {}
-                for var in 'SO':
-                    if var in select.defined_vars:
-                        varmap[var] = var
-                rewriter.rewrite(select, [(varmap, rqlexprs)],
-                                 select.solutions, args, existant)
-            rql = rqlst.as_string()
+        if not select.orderby:
+            select.add_sort_var(select.defined_vars[searchedvar])
+        # we're done, turn the rql syntax tree as a string
+        rql = select.as_string()
         return rql, args
 
     def unrelated(self, rtype, targettype, role='subject', limit=None,
@@ -835,6 +886,7 @@
             rql, args = self.cw_unrelated_rql(rtype, targettype, role, ordermethod)
         except Unauthorized:
             return self._cw.empty_rset()
+        # XXX should be set in unrelated rql when manipulating the AST
         if limit is not None:
             before, after = rql.split(' WHERE ', 1)
             rql = '%s LIMIT %s WHERE %s' % (before, limit, after)
@@ -930,8 +982,9 @@
         """add relations to the given object. To set a relation where this entity
         is the object of the relation, use 'reverse_'<relation> as argument name.
 
-        Values may be an entity, a list of entities, or None (meaning that all
-        relations of the given type from or to this object should be deleted).
+        Values may be an entity or eid, a list of entities or eids, or None
+        (meaning that all relations of the given type from or to this object
+        should be deleted).
         """
         # XXX update cache
         _check_cw_unsafe(kwargs)
@@ -946,9 +999,17 @@
                 continue
             if not isinstance(values, (tuple, list, set, frozenset)):
                 values = (values,)
+            eids = []
+            for val in values:
+                try:
+                    eids.append(str(val.eid))
+                except AttributeError:
+                    try:
+                        eids.append(str(typed_eid(val)))
+                    except (ValueError, TypeError):
+                        raise Exception('expected an Entity or eid, got %s' % val)
             self._cw.execute('SET %s WHERE X eid %%(x)s, Y eid IN (%s)' % (
-                restr, ','.join(str(r.eid) for r in values)),
-                             {'x': self.eid})
+                    restr, ','.join(eids)), {'x': self.eid})
 
     def cw_delete(self, **kwargs):
         assert self.has_eid(), self.eid