[json] fix json serialization for recent simplejson implementation, and test encoding of entities
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Fri, 11 Jun 2010 07:58:52 +0200
changeset 5726 c3b99606644d
parent 5725 b5d595b66c35
child 5727 29afb9e715bb
[json] fix json serialization for recent simplejson implementation, and test encoding of entities as with earlier simplejson implementation, iterencode internal stuff is a generated function, we can't anymore rely on the _iterencode overriding trick, so move on by stoping isinstance(Entity, dict). This is a much heavier change than expected but it was expected to be done at some point, so let's go that way.
entity.py
server/querier.py
server/repository.py
server/ssplanner.py
test/unittest_utils.py
utils.py
web/views/calendar.py
--- a/entity.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/entity.py	Fri Jun 11 07:58:52 2010 +0200
@@ -19,6 +19,7 @@
 
 __docformat__ = "restructuredtext en"
 
+from copy import copy
 from warnings import warn
 
 from logilab.common import interface
@@ -51,7 +52,7 @@
     return '1'
 
 
-class Entity(AppObject, dict):
+class Entity(AppObject):
     """an entity instance has e_schema automagically set on
     the class and instances has access to their issuing cursor.
 
@@ -287,17 +288,17 @@
 
     def __init__(self, req, rset=None, row=None, col=0):
         AppObject.__init__(self, req, rset=rset, row=row, col=col)
-        dict.__init__(self)
         self._cw_related_cache = {}
         if rset is not None:
             self.eid = rset[row][col]
         else:
             self.eid = None
         self._cw_is_saved = True
+        self.cw_attr_cache = {}
 
     def __repr__(self):
         return '<Entity %s %s %s at %s>' % (
-            self.e_schema, self.eid, self.keys(), id(self))
+            self.e_schema, self.eid, self.cw_attr_cache.keys(), id(self))
 
     def __json_encode__(self):
         """custom json dumps hook to dump the entity's eid
@@ -316,12 +317,18 @@
     def __cmp__(self, other):
         raise NotImplementedError('comparison not implemented for %s' % self.__class__)
 
+    def __contains__(self, key):
+        return key in self.cw_attr_cache
+
+    def __iter__(self):
+        return iter(self.cw_attr_cache)
+
     def __getitem__(self, key):
         if key == 'eid':
             warn('[3.7] entity["eid"] is deprecated, use entity.eid instead',
                  DeprecationWarning, stacklevel=2)
             return self.eid
-        return super(Entity, self).__getitem__(key)
+        return self.cw_attr_cache[key]
 
     def __setitem__(self, attr, value):
         """override __setitem__ to update self.edited_attributes.
@@ -339,7 +346,7 @@
                  DeprecationWarning, stacklevel=2)
             self.eid = value
         else:
-            super(Entity, self).__setitem__(attr, value)
+            self.cw_attr_cache[attr] = value
             # don't add attribute into skip_security if already in edited
             # attributes, else we may accidentaly skip a desired security check
             if hasattr(self, 'edited_attributes') and \
@@ -363,13 +370,16 @@
                 del self.entity['load_left']
 
         """
-        super(Entity, self).__delitem__(attr)
+        del self.cw_attr_cache[attr]
         if hasattr(self, 'edited_attributes'):
             self.edited_attributes.remove(attr)
 
+    def get(self, key, default=None):
+        return self.cw_attr_cache.get(key, default)
+
     def setdefault(self, attr, default):
         """override setdefault to update self.edited_attributes"""
-        super(Entity, self).setdefault(attr, default)
+        self.cw_attr_cache.setdefault(attr, default)
         # don't add attribute into skip_security if already in edited
         # attributes, else we may accidentaly skip a desired security check
         if hasattr(self, 'edited_attributes') and \
@@ -382,9 +392,9 @@
         undesired changes introduced in the entity's dict. See `__delitem__`
         """
         if default is _marker:
-            value = super(Entity, self).pop(attr)
+            value = self.cw_attr_cache.pop(attr)
         else:
-            value = super(Entity, self).pop(attr, default)
+            value = self.cw_attr_cache.pop(attr, default)
         if hasattr(self, 'edited_attributes') and attr in self.edited_attributes:
             self.edited_attributes.remove(attr)
         return value
@@ -556,6 +566,12 @@
 
     # entity cloning ##########################################################
 
+    def cw_copy(self):
+        thecopy = copy(self)
+        thecopy.cw_attr_cache = copy(self.cw_attr_cache)
+        thecopy._cw_related_cache = {}
+        return thecopy
+
     def copy_relations(self, ceid): # XXX cw_copy_relations
         """copy relations of the object with the given eid on this
         object (this method is called on the newly created copy, and
@@ -668,7 +684,7 @@
         selected = []
         for attr in (attributes or self._cw_to_complete_attributes(skip_bytes, skip_pwd)):
             # if attribute already in entity, nothing to do
-            if self.has_key(attr):
+            if self.cw_attr_cache.has_key(attr):
                 continue
             # case where attribute must be completed, but is not yet in entity
             var = varmaker.next()
@@ -727,7 +743,7 @@
         :param name: name of the attribute to get
         """
         try:
-            value = self[name]
+            value = self.cw_attr_cache[name]
         except KeyError:
             if not self.cw_is_saved():
                 return None
@@ -952,7 +968,7 @@
         # clear attributes cache
         haseid = 'eid' in self
         self._cw_completed = False
-        self.clear()
+        self.cw_attr_cache.clear()
         # clear relations cache
         self.cw_clear_relation_cache()
         # rest path unique cache
@@ -1020,7 +1036,7 @@
 
         This method is for internal use, you should not use it.
         """
-        super(Entity, self).__setitem__(attr, value)
+        self.cw_attr_cache[attr] = value
 
     def _cw_clear_local_perm_cache(self, action):
         for rqlexpr in self.e_schema.get_rqlexprs(action):
@@ -1037,7 +1053,7 @@
     def _cw_set_defaults(self):
         """set default values according to the schema"""
         for attr, value in self.e_schema.defaults():
-            if not self.has_key(attr):
+            if not self.cw_attr_cache.has_key(attr):
                 self[str(attr)] = value
 
     def _cw_check(self, creation=False):
--- a/server/querier.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/server/querier.py	Fri Jun 11 07:58:52 2010 +0200
@@ -17,8 +17,8 @@
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
 """Helper classes to execute RQL queries on a set of sources, performing
 security checking and data aggregation.
+"""
 
-"""
 from __future__ import with_statement
 
 __docformat__ = "restructuredtext en"
--- a/server/repository.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/server/repository.py	Fri Jun 11 07:58:52 2010 +0200
@@ -910,7 +910,7 @@
             self._extid_cache[cachekey] = eid
             self._type_source_cache[eid] = (etype, source.uri, extid)
             entity = source.before_entity_insertion(session, extid, etype, eid)
-            entity.edited_attributes = set(entity)
+            entity.edited_attributes = set(entity.cw_attr_cache)
             if source.should_call_hooks:
                 self.hm.call_hooks('before_add_entity', session, entity=entity)
             # XXX call add_info with complete=False ?
@@ -1021,7 +1021,7 @@
         """
         # init edited_attributes before calling before_add_entity hooks
         entity._cw_is_saved = False # entity has an eid but is not yet saved
-        entity.edited_attributes = set(entity) # XXX cw_edited_attributes
+        entity.edited_attributes = set(entity.cw_attr_cache) # XXX cw_edited_attributes
         eschema = entity.e_schema
         source = self.locate_etype_source(entity.__regid__)
         # allocate an eid to the entity before calling hooks
@@ -1036,7 +1036,7 @@
         # XXX use entity.keys here since edited_attributes is not updated for
         # inline relations XXX not true, right? (see edited_attributes
         # affectation above)
-        for attr in entity.iterkeys():
+        for attr in entity.cw_attr_cache.iterkeys():
             rschema = eschema.subjrels[attr]
             if not rschema.final: # inlined relation
                 relations.append((attr, entity[attr]))
--- a/server/ssplanner.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/server/ssplanner.py	Fri Jun 11 07:58:52 2010 +0200
@@ -22,8 +22,6 @@
 
 __docformat__ = "restructuredtext en"
 
-from copy import copy
-
 from rql.stmts import Union, Select
 from rql.nodes import Constant, Relation
 
@@ -479,7 +477,7 @@
             result = [[]]
         for row in result:
             # get a new entity definition for this row
-            edef = copy(base_edef)
+            edef = base_edef.cw_copy()
             # complete this entity def using row values
             index = 0
             for rtype, rorder, value in self.rdefs:
--- a/test/unittest_utils.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/test/unittest_utils.py	Fri Jun 11 07:58:52 2010 +0200
@@ -15,16 +15,16 @@
 #
 # You should have received a copy of the GNU Lesser General Public License along
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
-"""unit tests for module cubicweb.utils
-
-"""
+"""unit tests for module cubicweb.utils"""
 
 import re
 import decimal
 import datetime
 
 from logilab.common.testlib import TestCase, unittest_main
+
 from cubicweb.utils import make_uid, UStringIO, SizeConstrainedList, RepeatList
+from cubicweb.entity import Entity
 
 try:
     from cubicweb.utils import CubicWebJsonEncoder, json
@@ -99,6 +99,7 @@
         l.pop(2)
         self.assertEquals(l, [(1, 3)]*2)
 
+
 class SizeConstrainedListTC(TestCase):
 
     def test_append(self):
@@ -117,6 +118,7 @@
             l.extend(extension)
             yield self.assertEquals, l, expected
 
+
 class JSONEncoderTC(TestCase):
     def setUp(self):
         if json is None:
@@ -136,6 +138,20 @@
     def test_encoding_decimal(self):
         self.assertEquals(self.encode(decimal.Decimal('1.2')), '1.2')
 
+    def test_encoding_bare_entity(self):
+        e = Entity(None)
+        e['pouet'] = 'hop'
+        e.eid = 2
+        self.assertEquals(json.loads(self.encode(e)),
+                          {'pouet': 'hop', 'eid': 2})
+
+    def test_encoding_entity_in_list(self):
+        e = Entity(None)
+        e['pouet'] = 'hop'
+        e.eid = 2
+        self.assertEquals(json.loads(self.encode([e])),
+                          [{'pouet': 'hop', 'eid': 2}])
+
     def test_encoding_unknown_stuff(self):
         self.assertEquals(self.encode(TestCase), 'null')
 
--- a/utils.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/utils.py	Fri Jun 11 07:58:52 2010 +0200
@@ -335,21 +335,11 @@
     class CubicWebJsonEncoder(json.JSONEncoder):
         """define a json encoder to be able to encode yams std types"""
 
-        # _iterencode is the only entry point I've found to use a custom encode
-        # hook early enough: .default() is called if nothing else matched before,
-        # .iterencode() is called once on the main structure to encode and then
-        # never gets called again.
-        # For the record, our main use case is in FormValidateController with:
-        #   json.dumps((status, args, entity), cls=CubicWebJsonEncoder)
-        # where we want all the entity attributes, including eid, to be part
-        # of the json object dumped.
-        # This would have once more been easier if Entity didn't extend dict.
-        def _iterencode(self, obj, markers=None):
-            if hasattr(obj, '__json_encode__'):
-                obj = obj.__json_encode__()
-            return json.JSONEncoder._iterencode(self, obj, markers)
-
         def default(self, obj):
+            if hasattr(obj, 'eid'):
+                d = obj.cw_attr_cache.copy()
+                d['eid'] = obj.eid
+                return d
             if isinstance(obj, datetime.datetime):
                 return obj.strftime('%Y/%m/%d %H:%M:%S')
             elif isinstance(obj, datetime.date):
--- a/web/views/calendar.py	Fri Jun 11 07:58:49 2010 +0200
+++ b/web/views/calendar.py	Fri Jun 11 07:58:52 2010 +0200
@@ -395,12 +395,12 @@
         # colors here are class names defined in cubicweb.css
         colors = [ "col%x" % i for i in range(12) ]
         next_color_index = 0
-        done_tasks = []
+        done_tasks = set()
         for row in xrange(self.cw_rset.rowcount):
             task = self.cw_rset.get_entity(row, 0)
-            if task in done_tasks:
+            if task.eid in done_tasks:
                 continue
-            done_tasks.append(task)
+            done_tasks.add(task.eid)
             the_dates = []
             icalendarable = task.cw_adapt_to('ICalendarable')
             tstart = icalendarable.start