--- a/req.py Tue Oct 13 18:21:24 2009 +0200
+++ b/req.py Thu Oct 15 10:31:54 2009 +0200
@@ -108,31 +108,55 @@
# XXX move to CWEntityManager or even better as factory method (unclear
# where yet...)
- def create_entity(self, etype, *args, **kwargs):
+ def create_entity(self, etype, **kwargs):
"""add a new entity of the given type
Example (in a shell session):
- c = create_entity('Company', name='Logilab')
- create_entity('Person', ('works_for', 'Y'), Y=c.eid, firstname='John', lastname='Doe')
+ c = create_entity('Company', name=u'Logilab')
+ create_entity('Person', works_for=c, firstname=u'John', lastname=u'Doe')
+
"""
rql = 'INSERT %s X' % etype
relations = []
- restrictions = []
+ restrictions = set()
cachekey = []
- for rtype, rvar in args:
- relations.append('X %s %s' % (rtype, rvar))
- restrictions.append('%s eid %%(%s)s' % (rvar, rvar))
- cachekey.append(rvar)
- for attr in kwargs:
- if attr in cachekey:
- continue
- relations.append('X %s %%(%s)s' % (attr, attr))
+ pending_relations = []
+ for attr, value in kwargs.iteritems():
+ if isinstance(value, (tuple, list, set, frozenset)):
+ if len(value) == 1:
+ value = iter(value).next()
+ else:
+ pending_relations.append( (attr, value) )
+ continue
+ if hasattr(value, 'eid'): # non final relation
+ rvar = attr.upper()
+ # XXX safer detection of object relation
+ if attr.startswith('reverse_'):
+ relations.append('%s %s X' % (rvar, attr[len('reverse_'):]))
+ else:
+ relations.append('X %s %s' % (attr, rvar))
+ restriction = '%s eid %%(%s)s' % (rvar, attr)
+ if not restriction in restrictions:
+ restrictions.add(restriction)
+ cachekey.append(attr)
+ kwargs[attr] = value.eid
+ else: # attribute
+ relations.append('X %s %%(%s)s' % (attr, attr))
if relations:
rql = '%s: %s' % (rql, ', '.join(relations))
if restrictions:
rql = '%s WHERE %s' % (rql, ', '.join(restrictions))
- return self.execute(rql, kwargs, cachekey).get_entity(0, 0)
+ created = self.execute(rql, kwargs, cachekey).get_entity(0, 0)
+ for attr, values in pending_relations:
+ if attr.startswith('reverse_'):
+ restr = 'Y %s X' % attr[len('reverse_'):]
+ else:
+ restr = 'X %s Y' % attr
+ self.execute('SET %s WHERE X eid %%(x)s, Y eid IN (%s)' % (
+ restr, ','.join(str(r.eid) for r in values)),
+ {'x': created.eid}, 'x')
+ return created
def ensure_ro_rql(self, rql):
"""raise an exception if the given rql is not a select query"""