server/ssplanner.py
changeset 6142 8bc6eac1fac1
parent 5821 656c974961c4
child 6426 541659c39f6a
--- a/server/ssplanner.py	Wed Aug 25 10:29:07 2010 +0200
+++ b/server/ssplanner.py	Wed Aug 25 10:29:18 2010 +0200
@@ -21,6 +21,8 @@
 
 __docformat__ = "restructuredtext en"
 
+from copy import copy
+
 from rql.stmts import Union, Select
 from rql.nodes import Constant, Relation
 
@@ -55,11 +57,11 @@
             if isinstance(rhs, Constant) and not rhs.uid:
                 # add constant values to entity def
                 value = rhs.eval(plan.args)
-                eschema = edef.e_schema
+                eschema = edef.entity.e_schema
                 attrtype = eschema.subjrels[rtype].objects(eschema)[0]
                 if attrtype == 'Password' and isinstance(value, unicode):
                     value = value.encode('UTF8')
-                edef[rtype] = value
+                edef.edited_attribute(rtype, value)
             elif to_build.has_key(str(rhs)):
                 # create a relation between two newly created variables
                 plan.add_relation_def((edef, rtype, to_build[rhs.name]))
@@ -126,6 +128,132 @@
     return select
 
 
+_MARKER = object()
+
+class dict_protocol_catcher(object):
+    def __init__(self, entity):
+        self.__entity = entity
+    def __getitem__(self, attr):
+        return self.__entity.cw_edited[attr]
+    def __setitem__(self, attr, value):
+        self.__entity.cw_edited[attr] = value
+    def __getattr__(self, attr):
+        return getattr(self.__entity, attr)
+
+
+class EditedEntity(dict):
+    """encapsulate entities attributes being written by an RQL query"""
+    def __init__(self, entity, **kwargs):
+        dict.__init__(self, **kwargs)
+        self.entity = entity
+        self.skip_security = set()
+        self.querier_pending_relations = {}
+        self.saved = False
+
+    def __hash__(self):
+        # dict|set keyable
+        return hash(id(self))
+
+    def __cmp__(self, other):
+        # we don't want comparison by value inherited from dict
+        return cmp(id(self), id(other))
+
+    def __setitem__(self, attr, value):
+        assert attr != 'eid'
+        # don't add attribute into skip_security if already in edited
+        # attributes, else we may accidentaly skip a desired security check
+        if attr not in self:
+            self.skip_security.add(attr)
+        self.edited_attribute(attr, value)
+
+    def __delitem__(self, attr):
+        assert not self.saved, 'too late to modify edited attributes'
+        super(EditedEntity, self).__delitem__(attr)
+        self.entity.cw_attr_cache.pop(attr, None)
+
+    def pop(self, attr, *args):
+        # don't update skip_security by design (think to storage api)
+        assert not self.saved, 'too late to modify edited attributes'
+        value = super(EditedEntity, self).pop(attr, *args)
+        self.entity.cw_attr_cache.pop(attr, *args)
+        return value
+
+    def setdefault(self, attr, default):
+        assert attr != 'eid'
+        # don't add attribute into skip_security if already in edited
+        # attributes, else we may accidentaly skip a desired security check
+        if attr not in self:
+            self[attr] = default
+        return self[attr]
+
+    def update(self, values, skipsec=True):
+        if skipsec:
+            setitem = self.__setitem__
+        else:
+            setitem = self.edited_attribute
+        for attr, value in values.iteritems():
+            setitem(attr, value)
+
+    def edited_attribute(self, attr, value):
+        """attribute being edited by a rql query: should'nt be added to
+        skip_security
+        """
+        assert not self.saved, 'too late to modify edited attributes'
+        super(EditedEntity, self).__setitem__(attr, value)
+        self.entity.cw_attr_cache[attr] = value
+
+    def oldnewvalue(self, attr):
+        """returns the couple (old attr value, new attr value)
+
+        NOTE: will only work in a before_update_entity hook
+        """
+        assert not self.saved, 'too late to get the old value'
+        # get new value and remove from local dict to force a db query to
+        # fetch old value
+        newvalue = self.entity.cw_attr_cache.pop(attr, _MARKER)
+        oldvalue = getattr(self.entity, attr)
+        if newvalue is not _MARKER:
+            self.entity.cw_attr_cache[attr] = newvalue
+        else:
+            newvalue = oldvalue
+        return oldvalue, newvalue
+
+    def set_defaults(self):
+        """set default values according to the schema"""
+        for attr, value in self.entity.e_schema.defaults():
+            if not attr in self:
+                self[str(attr)] = value
+
+    def check(self, creation=False):
+        """check the entity edition against its schema. Only final relation
+        are checked here, constraint on actual relations are checked in hooks
+        """
+        entity = self.entity
+        if creation:
+            # on creations, we want to check all relations, especially
+            # required attributes
+            relations = [rschema for rschema in entity.e_schema.subject_relations()
+                         if rschema.final and rschema.type != 'eid']
+        else:
+            relations = [entity._cw.vreg.schema.rschema(rtype)
+                         for rtype in self]
+        from yams import ValidationError
+        try:
+            entity.e_schema.check(dict_protocol_catcher(entity),
+                                  creation=creation, _=entity._cw._,
+                                  relations=relations)
+        except ValidationError, ex:
+            ex.entity = self.entity
+            raise
+
+    def clone(self):
+        thecopy = EditedEntity(copy(self.entity))
+        thecopy.entity.cw_attr_cache = copy(self.entity.cw_attr_cache)
+        thecopy.entity._cw_related_cache = {}
+        thecopy.update(self, skipsec=False)
+        return thecopy
+
+
 class SSPlanner(object):
     """SingleSourcePlanner: build execution plan for rql queries
 
@@ -162,7 +290,7 @@
         etype_class = session.vreg['etypes'].etype_class
         for etype, var in rqlst.main_variables:
             # need to do this since entity class is shared w. web client code !
-            to_build[var.name] = etype_class(etype)(session)
+            to_build[var.name] = EditedEntity(etype_class(etype)(session))
             plan.add_entity_def(to_build[var.name])
         # add constant values to entity def, mark variables to be selected
         to_select = _extract_const_attributes(plan, rqlst, to_build)
@@ -177,7 +305,7 @@
         for edef, rdefs in to_select.items():
             # create a select rql st to fetch needed data
             select = Select()
-            eschema = edef.e_schema
+            eschema = edef.entity.e_schema
             for i, (rtype, term, reverse) in enumerate(rdefs):
                 if getattr(term, 'variable', None) in eidconsts:
                     value = eidconsts[term.variable]
@@ -284,10 +412,8 @@
                 rhsinfo = selectedidx[rhskey][:-1] + (None,)
             rschema = getrschema(relation.r_type)
             updatedefs.append( (lhsinfo, rhsinfo, rschema) )
-            if rschema.final or rschema.inlined:
-                attributes.add(relation.r_type)
         # the update step
-        step = UpdateStep(plan, updatedefs, attributes)
+        step = UpdateStep(plan, updatedefs)
         # when necessary add substep to fetch yet unknown values
         select = _build_substep_query(select, rqlst)
         if select is not None:
@@ -476,7 +602,7 @@
             result = [[]]
         for row in result:
             # get a new entity definition for this row
-            edef = base_edef.cw_copy()
+            edef = base_edef.clone()
             # complete this entity def using row values
             index = 0
             for rtype, rorder, value in self.rdefs:
@@ -484,7 +610,7 @@
                     value = row[index]
                     index += 1
                 if rorder == InsertRelationsStep.FINAL:
-                    edef._cw_rql_set_value(rtype, value)
+                    edef.edited_attribute(rtype, value)
                 elif rorder == InsertRelationsStep.RELATION:
                     self.plan.add_relation_def( (edef, rtype, value) )
                     edef.querier_pending_relations[(rtype, 'subject')] = value
@@ -495,6 +621,7 @@
         self.plan.substitute_entity_def(base_edef, edefs)
         return result
 
+
 class InsertStep(Step):
     """step consisting in inserting new entities / relations"""
 
@@ -555,10 +682,9 @@
     definitions and from results fetched in previous step
     """
 
-    def __init__(self, plan, updatedefs, attributes):
+    def __init__(self, plan, updatedefs):
         Step.__init__(self, plan)
         self.updatedefs = updatedefs
-        self.attributes = attributes
 
     def execute(self):
         """execute this step"""
@@ -578,16 +704,17 @@
                 if rschema.final or rschema.inlined:
                     eid = typed_eid(lhsval)
                     try:
-                        edef = edefs[eid]
+                        edited = edefs[eid]
                     except KeyError:
-                        edefs[eid] = edef = session.entity_from_eid(eid)
-                    edef._cw_rql_set_value(str(rschema), rhsval)
+                        edef = session.entity_from_eid(eid)
+                        edefs[eid] = edited = EditedEntity(edef)
+                    edited.edited_attribute(str(rschema), rhsval)
                 else:
                     repo.glob_add_relation(session, lhsval, str(rschema), rhsval)
             result[i] = newrow
         # update entities
-        for eid, edef in edefs.iteritems():
-            repo.glob_update_entity(session, edef, set(self.attributes))
+        for eid, edited in edefs.iteritems():
+            repo.glob_update_entity(session, edited)
         return result
 
 def _handle_relterm(info, row, newrow):