req.py
changeset 3674 387d51af966d
parent 3659 993997b4b41d
child 3720 5376aaadd16b
--- a/req.py	Tue Oct 13 18:21:24 2009 +0200
+++ b/req.py	Thu Oct 15 10:31:54 2009 +0200
@@ -108,31 +108,55 @@
 
     # XXX move to CWEntityManager or even better as factory method (unclear
     # where yet...)
-    def create_entity(self, etype, *args, **kwargs):
+    def create_entity(self, etype, **kwargs):
         """add a new entity of the given type
 
         Example (in a shell session):
 
-        c = create_entity('Company', name='Logilab')
-        create_entity('Person', ('works_for', 'Y'), Y=c.eid, firstname='John', lastname='Doe')
+        c = create_entity('Company', name=u'Logilab')
+        create_entity('Person', works_for=c, firstname=u'John', lastname=u'Doe')
+
         """
         rql = 'INSERT %s X' % etype
         relations = []
-        restrictions = []
+        restrictions = set()
         cachekey = []
-        for rtype, rvar in args:
-            relations.append('X %s %s' % (rtype, rvar))
-            restrictions.append('%s eid %%(%s)s' % (rvar, rvar))
-            cachekey.append(rvar)
-        for attr in kwargs:
-            if attr in cachekey:
-                continue
-            relations.append('X %s %%(%s)s' % (attr, attr))
+        pending_relations = []
+        for attr, value in kwargs.iteritems():
+            if isinstance(value, (tuple, list, set, frozenset)):
+                if len(value) == 1:
+                    value = iter(value).next()
+                else:
+                    pending_relations.append( (attr, value) )
+                    continue
+            if hasattr(value, 'eid'): # non final relation
+                rvar = attr.upper()
+                # XXX safer detection of object relation
+                if attr.startswith('reverse_'):
+                    relations.append('%s %s X' % (rvar, attr[len('reverse_'):]))
+                else:
+                    relations.append('X %s %s' % (attr, rvar))
+                restriction = '%s eid %%(%s)s' % (rvar, attr)
+                if not restriction in restrictions:
+                    restrictions.add(restriction)
+                cachekey.append(attr)
+                kwargs[attr] = value.eid
+            else: # attribute
+                relations.append('X %s %%(%s)s' % (attr, attr))
         if relations:
             rql = '%s: %s' % (rql, ', '.join(relations))
         if restrictions:
             rql = '%s WHERE %s' % (rql, ', '.join(restrictions))
-        return self.execute(rql, kwargs, cachekey).get_entity(0, 0)
+        created = self.execute(rql, kwargs, cachekey).get_entity(0, 0)
+        for attr, values in pending_relations:
+            if attr.startswith('reverse_'):
+                restr = 'Y %s X' % attr[len('reverse_'):]
+            else:
+                restr = 'X %s Y' % attr
+            self.execute('SET %s WHERE X eid %%(x)s, Y eid IN (%s)' % (
+                restr, ','.join(str(r.eid) for r in values)),
+                         {'x': created.eid}, 'x')
+        return created
 
     def ensure_ro_rql(self, rql):
         """raise an exception if the given rql is not a select query"""