|
1 """cubicweb ldap user source |
|
2 |
|
3 this source is for now limited to a read-only EUser source |
|
4 |
|
5 :organization: Logilab |
|
6 :copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
|
7 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
8 |
|
9 |
|
10 Part of the code comes from Zope's LDAPUserFolder |
|
11 |
|
12 Copyright (c) 2004 Jens Vagelpohl. |
|
13 All Rights Reserved. |
|
14 |
|
15 This software is subject to the provisions of the Zope Public License, |
|
16 Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. |
|
17 THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED |
|
18 WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
|
19 WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS |
|
20 FOR A PARTICULAR PURPOSE. |
|
21 """ |
|
22 |
|
23 from mx.DateTime import now, DateTimeDelta |
|
24 |
|
25 from logilab.common.textutils import get_csv |
|
26 from rql.nodes import Relation, VariableRef, Constant, Function |
|
27 |
|
28 import ldap |
|
29 from ldap.ldapobject import ReconnectLDAPObject |
|
30 from ldap.filter import filter_format, escape_filter_chars |
|
31 from ldapurl import LDAPUrl |
|
32 |
|
33 from cubicweb.common import AuthenticationError, UnknownEid, RepositoryError |
|
34 from cubicweb.server.sources import AbstractSource, TrFunc, GlobTrFunc, ConnectionWrapper |
|
35 from cubicweb.server.utils import cartesian_product |
|
36 |
|
# ldap search scopes; the 'user-scope' option of the source selects one of
# them by name
BASE, ONELEVEL, SUBTREE = (ldap.SCOPE_BASE,
                           ldap.SCOPE_ONELEVEL,
                           ldap.SCOPE_SUBTREE)

# connection mode identifier -> (default port, url scheme)
# the port is ignored for 'ldapi', whose address is the bare host
MODES = dict([
    (0, (389, 'ldap')),
    (1, (636, 'ldaps')),
    (2, (0, 'ldapi')),
    ])
|
52 |
|
class TimedCache(dict):
    """dictionary remembering when each of its values has been stored

    Values are kept internally as `(timestamp, value)` pairs, but item access
    returns the bare value. Entries older than the time to live are not
    hidden on access: they are only dropped by `clear_expired`.
    """

    def __init__(self, ttlm, ttls=0):
        # time to live: `ttlm` minutes plus `ttls` seconds
        self.ttl = DateTimeDelta(0, 0, ttlm, ttls)

    def __setitem__(self, key, value):
        stamped = (now(), value)
        dict.__setitem__(self, key, stamped)

    def __getitem__(self, key):
        _, value = dict.__getitem__(self, key)
        return value

    def clear_expired(self):
        """remove every entry stored more than `self.ttl` ago"""
        current = now()
        expired = [key for key, (stamp, _) in self.items()
                   if current - stamp > self.ttl]
        for key in expired:
            del self[key]
|
70 |
|
71 class LDAPUserSource(AbstractSource): |
|
72 """LDAP read-only EUser source""" |
|
73 support_entities = {'EUser': False} |
|
74 |
|
75 port = None |
|
76 |
|
77 cnx_mode = 0 |
|
78 cnx_dn = '' |
|
79 cnx_pwd = '' |
|
80 |
|
81 options = ( |
|
82 ('host', |
|
83 {'type' : 'string', |
|
84 'default': 'ldap', |
|
85 'help': 'ldap host', |
|
86 'group': 'ldap-source', 'inputlevel': 1, |
|
87 }), |
|
88 ('user-base-dn', |
|
89 {'type' : 'string', |
|
90 'default': 'ou=People,dc=logilab,dc=fr', |
|
91 'help': 'base DN to lookup for users', |
|
92 'group': 'ldap-source', 'inputlevel': 0, |
|
93 }), |
|
94 ('user-scope', |
|
95 {'type' : 'choice', |
|
96 'default': 'ONELEVEL', |
|
97 'choices': ('BASE', 'ONELEVEL', 'SUBTREE'), |
|
98 'help': 'user search scope', |
|
99 'group': 'ldap-source', 'inputlevel': 1, |
|
100 }), |
|
101 ('user-classes', |
|
102 {'type' : 'csv', |
|
103 'default': ('top', 'posixAccount'), |
|
104 'help': 'classes of user', |
|
105 'group': 'ldap-source', 'inputlevel': 1, |
|
106 }), |
|
107 ('user-login-attr', |
|
108 {'type' : 'string', |
|
109 'default': 'uid', |
|
110 'help': 'attribute used as login on authentication', |
|
111 'group': 'ldap-source', 'inputlevel': 1, |
|
112 }), |
|
113 ('user-default-group', |
|
114 {'type' : 'csv', |
|
115 'default': ('users',), |
|
116 'help': 'name of a group in which ldap users will be by default. \ |
|
117 You can set multiple groups by separating them by a comma.', |
|
118 'group': 'ldap-source', 'inputlevel': 1, |
|
119 }), |
|
120 ('user-attrs-map', |
|
121 {'type' : 'named', |
|
122 'default': {'uid': 'login', 'gecos': 'email'}, |
|
123 'help': 'map from ldap user attributes to cubicweb attributes', |
|
124 'group': 'ldap-source', 'inputlevel': 1, |
|
125 }), |
|
126 |
|
127 ('synchronization-interval', |
|
128 {'type' : 'int', |
|
129 'default': 24*60*60, |
|
130 'help': 'interval between synchronization with the ldap \ |
|
131 directory (default to once a day).', |
|
132 'group': 'ldap-source', 'inputlevel': 2, |
|
133 }), |
|
134 ('cache-life-time', |
|
135 {'type' : 'int', |
|
136 'default': 2*60, |
|
137 'help': 'life time of query cache in minutes (default to two hours).', |
|
138 'group': 'ldap-source', 'inputlevel': 2, |
|
139 }), |
|
140 |
|
141 ) |
|
142 |
|
143 def __init__(self, repo, appschema, source_config, *args, **kwargs): |
|
144 AbstractSource.__init__(self, repo, appschema, source_config, |
|
145 *args, **kwargs) |
|
146 self.host = source_config['host'] |
|
147 self.user_base_dn = source_config['user-base-dn'] |
|
148 self.user_base_scope = globals()[source_config['user-scope']] |
|
149 self.user_classes = get_csv(source_config['user-classes']) |
|
150 self.user_login_attr = source_config['user-login-attr'] |
|
151 self.user_default_groups = get_csv(source_config['user-default-group']) |
|
152 self.user_attrs = dict(v.split(':', 1) for v in get_csv(source_config['user-attrs-map'])) |
|
153 self.user_rev_attrs = {'eid': 'dn'} |
|
154 for ldapattr, cwattr in self.user_attrs.items(): |
|
155 self.user_rev_attrs[cwattr] = ldapattr |
|
156 self.base_filters = [filter_format('(%s=%s)', ('objectClass', o)) |
|
157 for o in self.user_classes] |
|
158 self._conn = None |
|
159 self._cache = {} |
|
160 ttlm = int(source_config.get('cache-life-type', 2*60)) |
|
161 self._query_cache = TimedCache(ttlm) |
|
162 self._interval = int(source_config.get('synchronization-interval', |
|
163 24*60*60)) |
|
164 |
|
165 def reset_caches(self): |
|
166 """method called during test to reset potential source caches""" |
|
167 self._query_cache = TimedCache(2*60) |
|
168 |
|
169 def init(self): |
|
170 """method called by the repository once ready to handle request""" |
|
171 self.repo.looping_task(self._interval, self.synchronize) |
|
172 self.repo.looping_task(self._query_cache.ttl.seconds/10, self._query_cache.clear_expired) |
|
173 |
|
174 def synchronize(self): |
|
175 """synchronize content known by this repository with content in the |
|
176 external repository |
|
177 """ |
|
178 self.info('synchronizing ldap source %s', self.uri) |
|
179 session = self.repo.internal_session() |
|
180 try: |
|
181 cursor = session.system_sql("SELECT eid, extid FROM entities WHERE " |
|
182 "source='%s'" % self.uri) |
|
183 for eid, extid in cursor.fetchall(): |
|
184 # if no result found, _search automatically delete entity information |
|
185 res = self._search(session, extid, BASE) |
|
186 if res: |
|
187 ldapemailaddr = res[0].get(self.user_rev_attrs['email']) |
|
188 if ldapemailaddr: |
|
189 rset = session.execute('EmailAddress X,A WHERE ' |
|
190 'U use_email X, U eid %(u)s', |
|
191 {'u': eid}) |
|
192 ldapemailaddr = unicode(ldapemailaddr) |
|
193 for emaileid, emailaddr in rset: |
|
194 if emailaddr == ldapemailaddr: |
|
195 break |
|
196 else: |
|
197 self.info('updating email address of user %s to %s', |
|
198 extid, ldapemailaddr) |
|
199 if rset: |
|
200 session.execute('SET X address %(addr)s WHERE ' |
|
201 'U primary_email X, U eid %(u)s', |
|
202 {'addr': ldapemailaddr, 'u': eid}) |
|
203 else: |
|
204 # no email found, create it |
|
205 _insert_email(session, ldapemailaddr, eid) |
|
206 finally: |
|
207 session.commit() |
|
208 session.close() |
|
209 |
|
210 def get_connection(self): |
|
211 """open and return a connection to the source""" |
|
212 if self._conn is None: |
|
213 self._connect() |
|
214 return ConnectionWrapper(self._conn) |
|
215 |
|
216 def authenticate(self, session, login, password): |
|
217 """return EUser eid for the given login/password if this account is |
|
218 defined in this source, else raise `AuthenticationError` |
|
219 |
|
220 two queries are needed since passwords are stored crypted, so we have |
|
221 to fetch the salt first |
|
222 """ |
|
223 assert login, 'no login!' |
|
224 searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))] |
|
225 searchfilter.extend([filter_format('(%s=%s)', ('objectClass', o)) |
|
226 for o in self.user_classes]) |
|
227 searchstr = '(&%s)' % ''.join(searchfilter) |
|
228 # first search the user |
|
229 try: |
|
230 user = self._search(session, self.user_base_dn, |
|
231 self.user_base_scope, searchstr)[0] |
|
232 except IndexError: |
|
233 # no such user |
|
234 raise AuthenticationError() |
|
235 # check password by establishing a (unused) connection |
|
236 try: |
|
237 self._connect(user['dn'], password) |
|
238 except: |
|
239 # Something went wrong, most likely bad credentials |
|
240 raise AuthenticationError() |
|
241 return self.extid2eid(user['dn'], 'EUser', session) |
|
242 |
|
243 def ldap_name(self, var): |
|
244 if var.stinfo['relations']: |
|
245 relname = iter(var.stinfo['relations']).next().r_type |
|
246 return self.user_rev_attrs.get(relname) |
|
247 return None |
|
248 |
|
249 def prepare_columns(self, mainvars, rqlst): |
|
250 """return two list describin how to build the final results |
|
251 from the result of an ldap search (ie a list of dictionnary) |
|
252 """ |
|
253 columns = [] |
|
254 global_transforms = [] |
|
255 for i, term in enumerate(rqlst.selection): |
|
256 if isinstance(term, Constant): |
|
257 columns.append(term) |
|
258 continue |
|
259 if isinstance(term, Function): # LOWER, UPPER, COUNT... |
|
260 var = term.get_nodes(VariableRef)[0] |
|
261 var = var.variable |
|
262 try: |
|
263 mainvar = var.stinfo['attrvar'].name |
|
264 except AttributeError: # no attrvar set |
|
265 mainvar = var.name |
|
266 assert mainvar in mainvars |
|
267 trname = term.name |
|
268 ldapname = self.ldap_name(var) |
|
269 if trname in ('COUNT', 'MIN', 'MAX', 'SUM'): |
|
270 global_transforms.append(GlobTrFunc(trname, i, ldapname)) |
|
271 columns.append((mainvar, ldapname)) |
|
272 continue |
|
273 if trname in ('LOWER', 'UPPER'): |
|
274 columns.append((mainvar, TrFunc(trname, i, ldapname))) |
|
275 continue |
|
276 raise NotImplementedError('no support for %s function' % trname) |
|
277 if term.name in mainvars: |
|
278 columns.append((term.name, 'dn')) |
|
279 continue |
|
280 var = term.variable |
|
281 mainvar = var.stinfo['attrvar'].name |
|
282 columns.append((mainvar, self.ldap_name(var))) |
|
283 #else: |
|
284 # # probably a bug in rql splitting if we arrive here |
|
285 # raise NotImplementedError |
|
286 return columns, global_transforms |
|
287 |
|
288 def syntax_tree_search(self, session, union, |
|
289 args=None, cachekey=None, varmap=None, debug=0): |
|
290 """return result from this source for a rql query (actually from a rql |
|
291 syntax tree and a solution dictionary mapping each used variable to a |
|
292 possible type). If cachekey is given, the query necessary to fetch the |
|
293 results (but not the results themselves) may be cached using this key. |
|
294 """ |
|
295 # XXX not handled : transform/aggregat function, join on multiple users... |
|
296 assert len(union.children) == 1, 'union not supported' |
|
297 rqlst = union.children[0] |
|
298 assert not rqlst.with_, 'subquery not supported' |
|
299 rqlkey = rqlst.as_string(kwargs=args) |
|
300 try: |
|
301 results = self._query_cache[rqlkey] |
|
302 except KeyError: |
|
303 results = self.rqlst_search(session, rqlst, args) |
|
304 self._query_cache[rqlkey] = results |
|
305 return results |
|
306 |
|
307 def rqlst_search(self, session, rqlst, args): |
|
308 mainvars = [] |
|
309 for varname in rqlst.defined_vars: |
|
310 for sol in rqlst.solutions: |
|
311 if sol[varname] == 'EUser': |
|
312 mainvars.append(varname) |
|
313 break |
|
314 assert mainvars |
|
315 columns, globtransforms = self.prepare_columns(mainvars, rqlst) |
|
316 eidfilters = [] |
|
317 allresults = [] |
|
318 generator = RQL2LDAPFilter(self, session, args, mainvars) |
|
319 for mainvar in mainvars: |
|
320 # handle restriction |
|
321 try: |
|
322 eidfilters_, ldapfilter = generator.generate(rqlst, mainvar) |
|
323 except GotDN, ex: |
|
324 assert ex.dn, 'no dn!' |
|
325 try: |
|
326 res = [self._cache[ex.dn]] |
|
327 except KeyError: |
|
328 res = self._search(session, ex.dn, BASE) |
|
329 except UnknownEid, ex: |
|
330 # raised when we are looking for the dn of an eid which is not |
|
331 # coming from this source |
|
332 res = [] |
|
333 else: |
|
334 eidfilters += eidfilters_ |
|
335 res = self._search(session, self.user_base_dn, |
|
336 self.user_base_scope, ldapfilter) |
|
337 allresults.append(res) |
|
338 # 1. get eid for each dn and filter according to that eid if necessary |
|
339 for i, res in enumerate(allresults): |
|
340 filteredres = [] |
|
341 for resdict in res: |
|
342 # get sure the entity exists in the system table |
|
343 eid = self.extid2eid(resdict['dn'], 'EUser', session) |
|
344 for eidfilter in eidfilters: |
|
345 if not eidfilter(eid): |
|
346 break |
|
347 else: |
|
348 resdict['eid'] = eid |
|
349 filteredres.append(resdict) |
|
350 allresults[i] = filteredres |
|
351 # 2. merge result for each "mainvar": cartesian product |
|
352 allresults = cartesian_product(allresults) |
|
353 # 3. build final result according to column definition |
|
354 result = [] |
|
355 for rawline in allresults: |
|
356 rawline = dict(zip(mainvars, rawline)) |
|
357 line = [] |
|
358 for varname, ldapname in columns: |
|
359 if ldapname is None: |
|
360 value = None # no mapping available |
|
361 elif ldapname == 'dn': |
|
362 value = rawline[varname]['eid'] |
|
363 elif isinstance(ldapname, Constant): |
|
364 if ldapname.type == 'Substitute': |
|
365 value = args[ldapname.value] |
|
366 else: |
|
367 value = ldapname.value |
|
368 elif isinstance(ldapname, TrFunc): |
|
369 value = ldapname.apply(rawline[varname]) |
|
370 else: |
|
371 value = rawline[varname].get(ldapname) |
|
372 line.append(value) |
|
373 result.append(line) |
|
374 for trfunc in globtransforms: |
|
375 result = trfunc.apply(result) |
|
376 #print '--> ldap result', result |
|
377 return result |
|
378 |
|
379 |
|
380 def _connect(self, userdn=None, userpwd=None): |
|
381 port, protocol = MODES[self.cnx_mode] |
|
382 if protocol == 'ldapi': |
|
383 hostport = self.host |
|
384 else: |
|
385 hostport = '%s:%s' % (self.host, self.port or port) |
|
386 self.info('connecting %s://%s as %s', protocol, hostport, |
|
387 userdn or 'anonymous') |
|
388 url = LDAPUrl(urlscheme=protocol, hostport=hostport) |
|
389 conn = ReconnectLDAPObject(url.initializeUrl()) |
|
390 # Set the protocol version - version 3 is preferred |
|
391 try: |
|
392 conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3) |
|
393 except ldap.LDAPError: # Invalid protocol version, fall back safely |
|
394 conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2) |
|
395 # Deny auto-chasing of referrals to be safe, we handle them instead |
|
396 #try: |
|
397 # connection.set_option(ldap.OPT_REFERRALS, 0) |
|
398 #except ldap.LDAPError: # Cannot set referrals, so do nothing |
|
399 # pass |
|
400 #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout) |
|
401 #conn.timeout = op_timeout |
|
402 # Now bind with the credentials given. Let exceptions propagate out. |
|
403 if userdn is None: |
|
404 assert self._conn is None |
|
405 self._conn = conn |
|
406 userdn = self.cnx_dn |
|
407 userpwd = self.cnx_pwd |
|
408 conn.simple_bind_s(userdn, userpwd) |
|
409 return conn |
|
410 |
|
411 def _search(self, session, base, scope, |
|
412 searchstr='(objectClass=*)', attrs=()): |
|
413 """make an ldap query""" |
|
414 cnx = session.pool.connection(self.uri).cnx |
|
415 try: |
|
416 res = cnx.search_s(base, scope, searchstr, attrs) |
|
417 except ldap.PARTIAL_RESULTS: |
|
418 res = cnx.result(all=0)[1] |
|
419 except ldap.NO_SUCH_OBJECT: |
|
420 eid = self.extid2eid(base, 'EUser', session, insert=False) |
|
421 if eid: |
|
422 self.warning('deleting ldap user with eid %s and dn %s', |
|
423 eid, base) |
|
424 self.repo.delete_info(session, eid) |
|
425 self._cache.pop(base, None) |
|
426 return [] |
|
427 ## except ldap.REFERRAL, e: |
|
428 ## cnx = self.handle_referral(e) |
|
429 ## try: |
|
430 ## res = cnx.search_s(base, scope, searchstr, attrs) |
|
431 ## except ldap.PARTIAL_RESULTS: |
|
432 ## res_type, res = cnx.result(all=0) |
|
433 result = [] |
|
434 for rec_dn, rec_dict in res: |
|
435 # When used against Active Directory, "rec_dict" may not be |
|
436 # be a dictionary in some cases (instead, it can be a list) |
|
437 # An example of a useless "res" entry that can be ignored |
|
438 # from AD is |
|
439 # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL']) |
|
440 # This appears to be some sort of internal referral, but |
|
441 # we can't handle it, so we need to skip over it. |
|
442 try: |
|
443 items = rec_dict.items() |
|
444 except AttributeError: |
|
445 # 'items' not found on rec_dict, skip |
|
446 continue |
|
447 for key, value in items: # XXX syt: huuum ? |
|
448 if not isinstance(value, str): |
|
449 try: |
|
450 for i in range(len(value)): |
|
451 value[i] = unicode(value[i], 'utf8') |
|
452 except: |
|
453 pass |
|
454 if isinstance(value, list) and len(value) == 1: |
|
455 rec_dict[key] = value = value[0] |
|
456 rec_dict['dn'] = rec_dn |
|
457 self._cache[rec_dn] = rec_dict |
|
458 result.append(rec_dict) |
|
459 #print '--->', result |
|
460 return result |
|
461 |
|
462 def before_entity_insertion(self, session, lid, etype, eid): |
|
463 """called by the repository when an eid has been attributed for an |
|
464 entity stored here but the entity has not been inserted in the system |
|
465 table yet. |
|
466 |
|
467 This method must return the an Entity instance representation of this |
|
468 entity. |
|
469 """ |
|
470 entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid) |
|
471 res = self._search(session, lid, BASE)[0] |
|
472 for attr in entity.e_schema.indexable_attributes(): |
|
473 entity[attr] = res[self.user_rev_attrs[attr]] |
|
474 return entity |
|
475 |
|
476 def after_entity_insertion(self, session, dn, entity): |
|
477 """called by the repository after an entity stored here has been |
|
478 inserted in the system table. |
|
479 """ |
|
480 super(LDAPUserSource, self).after_entity_insertion(session, dn, entity) |
|
481 for group in self.user_default_groups: |
|
482 session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s', |
|
483 {'x': entity.eid, 'group': group}, 'x') |
|
484 # search for existant email first |
|
485 try: |
|
486 emailaddr = self._cache[dn][self.user_rev_attrs['email']] |
|
487 except KeyError: |
|
488 return |
|
489 rset = session.execute('EmailAddress X WHERE X address %(addr)s', |
|
490 {'addr': emailaddr}) |
|
491 if rset: |
|
492 session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s', |
|
493 {'x': rset[0][0], 'u': entity.eid}, 'u') |
|
494 else: |
|
495 # not found, create it |
|
496 _insert_email(session, emailaddr, entity.eid) |
|
497 |
|
498 def update_entity(self, session, entity): |
|
499 """replace an entity in the source""" |
|
500 raise RepositoryError('this source is read only') |
|
501 |
|
502 def delete_entity(self, session, etype, eid): |
|
503 """delete an entity from the source""" |
|
504 raise RepositoryError('this source is read only') |
|
505 |
|
506 def _insert_email(session, emailaddr, ueid): |
|
507 session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X ' |
|
508 'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x') |
|
509 |
|
class GotDN(Exception):
    """exception used when a dn localizing the searched user has been found

    Raised while generating an ldap filter when the restriction pins a single
    eid, so that the caller can fetch that dn directly instead of searching.
    """
    def __init__(self, dn):
        # pass dn to Exception so that `args`, `str()` and pickling work
        Exception.__init__(self, dn)
        self.dn = dn
|
514 |
|
515 |
|
class RQL2LDAPFilter(object):
    """generate an LDAP filter for a rql query

    The `visit_*` methods return an ldap filter string (RFC 4515) for the
    visited node, or None/'' when the node doesn't restrict the ldap search.
    """
    def __init__(self, source, session, args=None, mainvars=()):
        self.source = source
        self._ldap_attrs = source.user_rev_attrs
        self._base_filters = source.base_filters
        self._session = session
        if args is None:
            args = {}
        self._args = args
        self.mainvars = mainvars

    def generate(self, selection, mainvarname):
        """return a 2-uple (eid filters, ldap filter string) for the
        restriction of `selection`, regarding variable `mainvarname`

        Eid filters are predicates on eids which can't be expressed as ldap
        filters and have to be applied on the search results.

        :raise GotDN: when the restriction designates a single eid
        :raise UnknownEid: when the query can't have results from this source
        """
        self._filters = res = self._base_filters[:]
        self._mainvarname = mainvarname
        self._eidfilters = []
        self._done_not = set()
        restriction = selection.where
        if restriction is not None:
            part = restriction.accept(self)
            if part:
                res.append(part)
        if not res:
            # no object class and no restriction: match everything
            return self._eidfilters, '(objectClass=*)'
        if len(res) > 1:
            return self._eidfilters, '(&%s)' % ''.join(res)
        return self._eidfilters, res[0]

    @staticmethod
    def _combine(operator, parts):
        """combine ldap filter `parts` with `operator` ('&' or '|'), ignoring
        empty parts; return None if there is no part left
        """
        parts = [part for part in parts if part]
        if not parts:
            return None
        if len(parts) == 1:
            return parts[0]
        return '(%s%s)' % (operator, ''.join(parts))

    def visit_and(self, et):
        """generate filter for a AND subtree"""
        return self._combine('&', [c.accept(self) for c in et.children])

    def visit_or(self, ou):
        """generate filter for a OR subtree"""
        return self._combine('|', [c.accept(self) for c in ou.children])

    def visit_not(self, node):
        """generate filter for a NOT subtree"""
        part = node.children[0].accept(self)
        if part:
            # `part` already is a parenthesized filter: '(!' filter ')'
            return '(!%s)' % part
        return None

    def visit_relation(self, relation):
        """generate filter for a relation"""
        rtype = relation.r_type
        # don't care of type constraint statement (i.e. relation_type = 'is')
        if rtype == 'is':
            return ''
        lhs, rhs = relation.get_parts()
        # regular relation XXX todo: in_group
        if not self.source.schema.rschema(rtype).is_final():
            raise NotImplementedError(relation)
        # attribute relation
        # dunno what to do here, don't pretend anything else
        if lhs.name != self._mainvarname:
            if lhs.name in self.mainvars:
                # XXX check we don't have variable as rhs
                return None
            raise NotImplementedError
        rhs_vars = rhs.get_nodes(VariableRef)
        if rhs_vars:
            if len(rhs_vars) > 1:
                raise NotImplementedError
            # selected variable, nothing to do here
            return None
        # no variables in the RHS
        if isinstance(rhs.children[0], Function):
            return rhs.children[0].accept(self)
        if rtype == 'has_text':
            raise NotImplementedError(relation)
        return self._visit_attribute_relation(relation)

    def _visit_attribute_relation(self, relation):
        """generate filter for an attribute relation"""
        rhs = relation.get_parts()[1]
        if relation.r_type == 'eid':
            # XXX hack
            # skip comparison sign
            eid = int(rhs.children[0].accept(self))
            if relation.neged(strict=True):
                self._done_not.add(relation.parent)
                self._eidfilters.append(lambda x: not x == eid)
                return None
            if rhs.operator != '=':
                eidfilter = {'>': lambda x: x > eid,
                             '>=': lambda x: x >= eid,
                             '<': lambda x: x < eid,
                             '<=': lambda x: x <= eid,
                             }[rhs.operator]
                self._eidfilters.append(eidfilter)
                return None
            dn = self.source.eid2extid(eid, self._session)
            raise GotDN(dn)
        # only the attribute lookup may legitimately fail here
        try:
            ldapattr = self._ldap_attrs[relation.r_type]
        except KeyError:
            assert relation.r_type == 'password' # 2.38 migration
            raise UnknownEid # trick to return no result
        return '(%s%s)' % (ldapattr, rhs.accept(self))

    def visit_comparison(self, cmp):
        """generate filter for a comparaison"""
        return '%s%s'% (cmp.operator, cmp.children[0].accept(self))

    def visit_mathexpression(self, mexpr):
        """generate filter for a mathematic expression"""
        raise NotImplementedError

    def visit_function(self, function):
        """generate filter name for a function"""
        if function.name == 'IN':
            return self.visit_in(function)
        raise NotImplementedError

    def visit_in(self, function):
        """generate filter for an IN function: an OR of equality filters on
        the ldap attribute mapped to the enclosing relation, or None if there
        is no value
        """
        # function.parent is the comparison, its parent the relation
        grandpapa = function.parent.parent
        ldapattr = self._ldap_attrs[grandpapa.r_type]
        values = [value for value in
                  (c.accept(self) for c in function.children) if value]
        return self._combine('|', ['(%s=%s)' % (ldapattr, value)
                                   for value in values])

    def visit_constant(self, constant):
        """generate filter name for a constant"""
        if constant.type is None:
            raise NotImplementedError
        if constant.type == 'Date':
            raise NotImplementedError
        if constant.type == 'Substitute':
            value = self._args[constant.value]
        else:
            value = constant.value
        if isinstance(value, unicode):
            value = value.encode('utf8')
        else:
            value = str(value)
        return escape_filter_chars(value)

    def visit_variableref(self, variableref):
        """get the sql name for a variable reference"""
        pass
685 |