server/sources/ldapuser.py
changeset 257 4c7d3af7e94d
child 938 a69188963ccb
equal deleted inserted replaced
256:3dbee583526c 257:4c7d3af7e94d
       
     1 """cubicweb ldap user source
       
     2 
       
     3 this source is for now limited to a read-only EUser source
       
     4 
       
     5 :organization: Logilab
       
     6 :copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     7 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     8 
       
     9 
       
    10 Part of the code comes from Zope's LDAPUserFolder
       
    11 
       
    12 Copyright (c) 2004 Jens Vagelpohl.
       
    13 All Rights Reserved.
       
    14 
       
    15 This software is subject to the provisions of the Zope Public License,
       
    16 Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
       
    17 THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
       
    18 WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
       
    19 WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
       
    20 FOR A PARTICULAR PURPOSE.
       
    21 """
       
    22 
       
    23 from mx.DateTime import now, DateTimeDelta
       
    24 
       
    25 from logilab.common.textutils import get_csv
       
    26 from rql.nodes import Relation, VariableRef, Constant, Function
       
    27 
       
    28 import ldap
       
    29 from ldap.ldapobject import ReconnectLDAPObject
       
    30 from ldap.filter import filter_format, escape_filter_chars
       
    31 from ldapurl import LDAPUrl
       
    32 
       
    33 from cubicweb.common import AuthenticationError, UnknownEid, RepositoryError
       
    34 from cubicweb.server.sources import AbstractSource, TrFunc, GlobTrFunc, ConnectionWrapper
       
    35 from cubicweb.server.utils import cartesian_product
       
    36 
       
    37 # search scopes
       
    38 BASE = ldap.SCOPE_BASE
       
    39 ONELEVEL = ldap.SCOPE_ONELEVEL
       
    40 SUBTREE = ldap.SCOPE_SUBTREE
       
    41 
       
    42 # XXX only for edition ??
       
    43 ## password encryption possibilities
       
    44 #ENCRYPTIONS = ('SHA', 'CRYPT', 'MD5', 'CLEAR') # , 'SSHA'
       
    45 
       
    46 # mode identifier : (port, protocol)
       
    47 MODES = {
       
    48     0: (389, 'ldap'),
       
    49     1: (636, 'ldaps'),
       
    50     2: (0,   'ldapi'),
       
    51     }
       
    52 
       
    53 class TimedCache(dict):
       
    54     def __init__(self, ttlm, ttls=0):
       
    55         # time to live in minutes
       
    56         self.ttl = DateTimeDelta(0, 0, ttlm, ttls)
       
    57         
       
    58     def __setitem__(self, key, value):
       
    59         dict.__setitem__(self, key, (now(), value))
       
    60         
       
    61     def __getitem__(self, key):
       
    62         return dict.__getitem__(self, key)[1]
       
    63     
       
    64     def clear_expired(self):
       
    65         now_ = now()
       
    66         ttl = self.ttl
       
    67         for key, (timestamp, value) in self.items():
       
    68             if now_ - timestamp > ttl:
       
    69                 del self[key]
       
    70                 
       
    71 class LDAPUserSource(AbstractSource):
       
    72     """LDAP read-only EUser source"""
       
    73     support_entities = {'EUser': False} 
       
    74 
       
    75     port = None
       
    76     
       
    77     cnx_mode = 0
       
    78     cnx_dn = ''
       
    79     cnx_pwd = ''
       
    80     
       
    81     options = (
       
    82         ('host',
       
    83          {'type' : 'string',
       
    84           'default': 'ldap',
       
    85           'help': 'ldap host',
       
    86           'group': 'ldap-source', 'inputlevel': 1,
       
    87           }),
       
    88         ('user-base-dn',
       
    89          {'type' : 'string',
       
    90           'default': 'ou=People,dc=logilab,dc=fr',
       
    91           'help': 'base DN to lookup for users',
       
    92           'group': 'ldap-source', 'inputlevel': 0,
       
    93           }),
       
    94         ('user-scope',
       
    95          {'type' : 'choice',
       
    96           'default': 'ONELEVEL',
       
    97           'choices': ('BASE', 'ONELEVEL', 'SUBTREE'),
       
    98           'help': 'user search scope',
       
    99           'group': 'ldap-source', 'inputlevel': 1,
       
   100           }),
       
   101         ('user-classes',
       
   102          {'type' : 'csv',
       
   103           'default': ('top', 'posixAccount'),
       
   104           'help': 'classes of user',
       
   105           'group': 'ldap-source', 'inputlevel': 1,
       
   106           }),
       
   107         ('user-login-attr',
       
   108          {'type' : 'string',
       
   109           'default': 'uid',
       
   110           'help': 'attribute used as login on authentication',
       
   111           'group': 'ldap-source', 'inputlevel': 1,
       
   112           }),
       
   113         ('user-default-group',
       
   114          {'type' : 'csv',
       
   115           'default': ('users',),
       
   116           'help': 'name of a group in which ldap users will be by default. \
       
   117 You can set multiple groups by separating them by a comma.',
       
   118           'group': 'ldap-source', 'inputlevel': 1,
       
   119           }),
       
   120         ('user-attrs-map',
       
   121          {'type' : 'named',
       
   122           'default': {'uid': 'login', 'gecos': 'email'},
       
   123           'help': 'map from ldap user attributes to cubicweb attributes',
       
   124           'group': 'ldap-source', 'inputlevel': 1,
       
   125           }),
       
   126 
       
   127         ('synchronization-interval',
       
   128          {'type' : 'int',
       
   129           'default': 24*60*60,
       
   130           'help': 'interval between synchronization with the ldap \
       
   131 directory (default to once a day).',
       
   132           'group': 'ldap-source', 'inputlevel': 2,
       
   133           }),
       
   134         ('cache-life-time',
       
   135          {'type' : 'int',
       
   136           'default': 2*60,
       
   137           'help': 'life time of query cache in minutes (default to two hours).',
       
   138           'group': 'ldap-source', 'inputlevel': 2,
       
   139           }),
       
   140         
       
   141     )
       
   142             
       
   143     def __init__(self, repo, appschema, source_config, *args, **kwargs):
       
   144         AbstractSource.__init__(self, repo, appschema, source_config,
       
   145                                 *args, **kwargs)
       
   146         self.host = source_config['host']
       
   147         self.user_base_dn = source_config['user-base-dn']
       
   148         self.user_base_scope = globals()[source_config['user-scope']]
       
   149         self.user_classes = get_csv(source_config['user-classes'])
       
   150         self.user_login_attr = source_config['user-login-attr']
       
   151         self.user_default_groups = get_csv(source_config['user-default-group'])
       
   152         self.user_attrs = dict(v.split(':', 1) for v in get_csv(source_config['user-attrs-map']))
       
   153         self.user_rev_attrs = {'eid': 'dn'}
       
   154         for ldapattr, cwattr in self.user_attrs.items():
       
   155             self.user_rev_attrs[cwattr] = ldapattr
       
   156         self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
       
   157                               for o in self.user_classes]
       
   158         self._conn = None
       
   159         self._cache = {}
       
   160         ttlm = int(source_config.get('cache-life-type', 2*60))
       
   161         self._query_cache = TimedCache(ttlm)
       
   162         self._interval = int(source_config.get('synchronization-interval',
       
   163                                                24*60*60))
       
   164 
       
   165     def reset_caches(self):
       
   166         """method called during test to reset potential source caches"""
       
   167         self._query_cache = TimedCache(2*60)
       
   168 
       
   169     def init(self):
       
   170         """method called by the repository once ready to handle request"""
       
   171         self.repo.looping_task(self._interval, self.synchronize) 
       
   172         self.repo.looping_task(self._query_cache.ttl.seconds/10, self._query_cache.clear_expired) 
       
   173 
       
   174     def synchronize(self):
       
   175         """synchronize content known by this repository with content in the
       
   176         external repository
       
   177         """
       
   178         self.info('synchronizing ldap source %s', self.uri)
       
   179         session = self.repo.internal_session()
       
   180         try:
       
   181             cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
       
   182                                         "source='%s'" % self.uri)
       
   183             for eid, extid in cursor.fetchall():
       
   184                 # if no result found, _search automatically delete entity information
       
   185                 res = self._search(session, extid, BASE)
       
   186                 if res: 
       
   187                     ldapemailaddr = res[0].get(self.user_rev_attrs['email'])
       
   188                     if ldapemailaddr:
       
   189                         rset = session.execute('EmailAddress X,A WHERE '
       
   190                                                'U use_email X, U eid %(u)s',
       
   191                                                {'u': eid})
       
   192                         ldapemailaddr = unicode(ldapemailaddr)
       
   193                         for emaileid, emailaddr in rset:
       
   194                             if emailaddr == ldapemailaddr:
       
   195                                 break
       
   196                         else:
       
   197                             self.info('updating email address of user %s to %s',
       
   198                                       extid, ldapemailaddr)
       
   199                             if rset:
       
   200                                 session.execute('SET X address %(addr)s WHERE '
       
   201                                                 'U primary_email X, U eid %(u)s',
       
   202                                                 {'addr': ldapemailaddr, 'u': eid})
       
   203                             else:
       
   204                                 # no email found, create it
       
   205                                 _insert_email(session, ldapemailaddr, eid)
       
   206         finally:
       
   207             session.commit()
       
   208             session.close()
       
   209             
       
   210     def get_connection(self):
       
   211         """open and return a connection to the source"""
       
   212         if self._conn is None:
       
   213             self._connect()
       
   214         return ConnectionWrapper(self._conn)
       
   215     
       
   216     def authenticate(self, session, login, password):
       
   217         """return EUser eid for the given login/password if this account is
       
   218         defined in this source, else raise `AuthenticationError`
       
   219 
       
   220         two queries are needed since passwords are stored crypted, so we have
       
   221         to fetch the salt first
       
   222         """
       
   223         assert login, 'no login!'
       
   224         searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
       
   225         searchfilter.extend([filter_format('(%s=%s)', ('objectClass', o))
       
   226                              for o in self.user_classes])
       
   227         searchstr = '(&%s)' % ''.join(searchfilter)
       
   228         # first search the user
       
   229         try:
       
   230             user = self._search(session, self.user_base_dn,
       
   231                                 self.user_base_scope, searchstr)[0]
       
   232         except IndexError:
       
   233             # no such user
       
   234             raise AuthenticationError()
       
   235         # check password by establishing a (unused) connection
       
   236         try:
       
   237             self._connect(user['dn'], password)
       
   238         except:
       
   239             # Something went wrong, most likely bad credentials
       
   240             raise AuthenticationError()
       
   241         return self.extid2eid(user['dn'], 'EUser', session)
       
   242 
       
   243     def ldap_name(self, var):
       
   244         if var.stinfo['relations']:
       
   245             relname = iter(var.stinfo['relations']).next().r_type
       
   246             return self.user_rev_attrs.get(relname)
       
   247         return None
       
   248         
       
   249     def prepare_columns(self, mainvars, rqlst):
       
   250         """return two list describin how to build the final results
       
   251         from the result of an ldap search (ie a list of dictionnary)
       
   252         """
       
   253         columns = []
       
   254         global_transforms = []
       
   255         for i, term in enumerate(rqlst.selection):
       
   256             if isinstance(term, Constant):
       
   257                 columns.append(term)
       
   258                 continue
       
   259             if isinstance(term, Function): # LOWER, UPPER, COUNT...
       
   260                 var = term.get_nodes(VariableRef)[0]
       
   261                 var = var.variable
       
   262                 try:
       
   263                     mainvar = var.stinfo['attrvar'].name
       
   264                 except AttributeError: # no attrvar set
       
   265                     mainvar = var.name
       
   266                 assert mainvar in mainvars
       
   267                 trname = term.name
       
   268                 ldapname = self.ldap_name(var)
       
   269                 if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
       
   270                     global_transforms.append(GlobTrFunc(trname, i, ldapname))
       
   271                     columns.append((mainvar, ldapname))
       
   272                     continue
       
   273                 if trname in ('LOWER', 'UPPER'):
       
   274                     columns.append((mainvar, TrFunc(trname, i, ldapname)))
       
   275                     continue
       
   276                 raise NotImplementedError('no support for %s function' % trname)
       
   277             if term.name in mainvars:
       
   278                 columns.append((term.name, 'dn'))
       
   279                 continue
       
   280             var = term.variable
       
   281             mainvar = var.stinfo['attrvar'].name
       
   282             columns.append((mainvar, self.ldap_name(var)))
       
   283             #else:
       
   284             #    # probably a bug in rql splitting if we arrive here
       
   285             #    raise NotImplementedError
       
   286         return columns, global_transforms
       
   287     
       
   288     def syntax_tree_search(self, session, union,
       
   289                            args=None, cachekey=None, varmap=None, debug=0):
       
   290         """return result from this source for a rql query (actually from a rql 
       
   291         syntax tree and a solution dictionary mapping each used variable to a 
       
   292         possible type). If cachekey is given, the query necessary to fetch the
       
   293         results (but not the results themselves) may be cached using this key.
       
   294         """
       
   295         # XXX not handled : transform/aggregat function, join on multiple users...
       
   296         assert len(union.children) == 1, 'union not supported'
       
   297         rqlst = union.children[0]
       
   298         assert not rqlst.with_, 'subquery not supported'
       
   299         rqlkey = rqlst.as_string(kwargs=args)
       
   300         try:
       
   301             results = self._query_cache[rqlkey]
       
   302         except KeyError:
       
   303             results = self.rqlst_search(session, rqlst, args)
       
   304             self._query_cache[rqlkey] = results
       
   305         return results
       
   306 
       
   307     def rqlst_search(self, session, rqlst, args):
       
   308         mainvars = []
       
   309         for varname in rqlst.defined_vars:
       
   310             for sol in rqlst.solutions:
       
   311                 if sol[varname] == 'EUser':
       
   312                     mainvars.append(varname)
       
   313                     break
       
   314         assert mainvars
       
   315         columns, globtransforms = self.prepare_columns(mainvars, rqlst)
       
   316         eidfilters = []
       
   317         allresults = []
       
   318         generator = RQL2LDAPFilter(self, session, args, mainvars)
       
   319         for mainvar in mainvars:
       
   320             # handle restriction
       
   321             try:
       
   322                 eidfilters_, ldapfilter = generator.generate(rqlst, mainvar)
       
   323             except GotDN, ex:
       
   324                 assert ex.dn, 'no dn!'
       
   325                 try:
       
   326                     res = [self._cache[ex.dn]]
       
   327                 except KeyError:
       
   328                     res = self._search(session, ex.dn, BASE)
       
   329             except UnknownEid, ex:
       
   330                 # raised when we are looking for the dn of an eid which is not
       
   331                 # coming from this source
       
   332                 res = []
       
   333             else:
       
   334                 eidfilters += eidfilters_
       
   335                 res = self._search(session, self.user_base_dn,
       
   336                                    self.user_base_scope, ldapfilter)
       
   337             allresults.append(res)
       
   338         # 1. get eid for each dn and filter according to that eid if necessary
       
   339         for i, res in enumerate(allresults):
       
   340             filteredres = []
       
   341             for resdict in res:
       
   342                 # get sure the entity exists in the system table
       
   343                 eid = self.extid2eid(resdict['dn'], 'EUser', session)
       
   344                 for eidfilter in eidfilters:
       
   345                     if not eidfilter(eid):
       
   346                         break
       
   347                 else:
       
   348                     resdict['eid'] = eid
       
   349                     filteredres.append(resdict)
       
   350             allresults[i] = filteredres
       
   351         # 2. merge result for each "mainvar": cartesian product
       
   352         allresults = cartesian_product(allresults)
       
   353         # 3. build final result according to column definition
       
   354         result = []
       
   355         for rawline in allresults:
       
   356             rawline = dict(zip(mainvars, rawline))
       
   357             line = []
       
   358             for varname, ldapname in columns:
       
   359                 if ldapname is None:
       
   360                     value = None # no mapping available
       
   361                 elif ldapname == 'dn':
       
   362                     value = rawline[varname]['eid']
       
   363                 elif isinstance(ldapname, Constant):
       
   364                     if ldapname.type == 'Substitute':
       
   365                         value = args[ldapname.value]
       
   366                     else:
       
   367                         value = ldapname.value
       
   368                 elif isinstance(ldapname, TrFunc):
       
   369                     value = ldapname.apply(rawline[varname])
       
   370                 else:
       
   371                     value = rawline[varname].get(ldapname)
       
   372                 line.append(value)
       
   373             result.append(line)
       
   374         for trfunc in globtransforms:
       
   375             result = trfunc.apply(result)
       
   376         #print '--> ldap result', result
       
   377         return result
       
   378                 
       
   379     
       
   380     def _connect(self, userdn=None, userpwd=None):
       
   381         port, protocol = MODES[self.cnx_mode]
       
   382         if protocol == 'ldapi':
       
   383             hostport = self.host
       
   384         else:
       
   385             hostport = '%s:%s' % (self.host, self.port or port)
       
   386         self.info('connecting %s://%s as %s', protocol, hostport,
       
   387                   userdn or 'anonymous')
       
   388         url = LDAPUrl(urlscheme=protocol, hostport=hostport)
       
   389         conn = ReconnectLDAPObject(url.initializeUrl())
       
   390         # Set the protocol version - version 3 is preferred
       
   391         try:
       
   392             conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
       
   393         except ldap.LDAPError: # Invalid protocol version, fall back safely
       
   394             conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2)
       
   395         # Deny auto-chasing of referrals to be safe, we handle them instead
       
   396         #try:
       
   397         #    connection.set_option(ldap.OPT_REFERRALS, 0)
       
   398         #except ldap.LDAPError: # Cannot set referrals, so do nothing
       
   399         #    pass
       
   400         #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout)
       
   401         #conn.timeout = op_timeout
       
   402         # Now bind with the credentials given. Let exceptions propagate out.
       
   403         if userdn is None:
       
   404             assert self._conn is None
       
   405             self._conn = conn
       
   406             userdn = self.cnx_dn
       
   407             userpwd = self.cnx_pwd
       
   408         conn.simple_bind_s(userdn, userpwd)
       
   409         return conn
       
   410 
       
   411     def _search(self, session, base, scope,
       
   412                 searchstr='(objectClass=*)', attrs=()):
       
   413         """make an ldap query"""
       
   414         cnx = session.pool.connection(self.uri).cnx
       
   415         try:
       
   416             res = cnx.search_s(base, scope, searchstr, attrs)
       
   417         except ldap.PARTIAL_RESULTS:
       
   418             res = cnx.result(all=0)[1]
       
   419         except ldap.NO_SUCH_OBJECT:
       
   420             eid = self.extid2eid(base, 'EUser', session, insert=False)
       
   421             if eid:
       
   422                 self.warning('deleting ldap user with eid %s and dn %s',
       
   423                              eid, base)
       
   424                 self.repo.delete_info(session, eid)
       
   425                 self._cache.pop(base, None)
       
   426             return []
       
   427 ##         except ldap.REFERRAL, e:
       
   428 ##             cnx = self.handle_referral(e)
       
   429 ##             try:
       
   430 ##                 res = cnx.search_s(base, scope, searchstr, attrs)
       
   431 ##             except ldap.PARTIAL_RESULTS:
       
   432 ##                 res_type, res = cnx.result(all=0)
       
   433         result = []
       
   434         for rec_dn, rec_dict in res:
       
   435             # When used against Active Directory, "rec_dict" may not be
       
   436             # be a dictionary in some cases (instead, it can be a list)
       
   437             # An example of a useless "res" entry that can be ignored
       
   438             # from AD is
       
   439             # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL'])
       
   440             # This appears to be some sort of internal referral, but
       
   441             # we can't handle it, so we need to skip over it.
       
   442             try:
       
   443                 items =  rec_dict.items()
       
   444             except AttributeError:
       
   445                 # 'items' not found on rec_dict, skip
       
   446                 continue
       
   447             for key, value in items: # XXX syt: huuum ?
       
   448                 if not isinstance(value, str):
       
   449                     try:
       
   450                         for i in range(len(value)):
       
   451                             value[i] = unicode(value[i], 'utf8')
       
   452                     except:
       
   453                         pass
       
   454                 if isinstance(value, list) and len(value) == 1:
       
   455                     rec_dict[key] = value = value[0]
       
   456             rec_dict['dn'] = rec_dn
       
   457             self._cache[rec_dn] = rec_dict
       
   458             result.append(rec_dict)
       
   459         #print '--->', result
       
   460         return result
       
   461     
       
   462     def before_entity_insertion(self, session, lid, etype, eid):
       
   463         """called by the repository when an eid has been attributed for an
       
   464         entity stored here but the entity has not been inserted in the system
       
   465         table yet.
       
   466         
       
   467         This method must return the an Entity instance representation of this
       
   468         entity.
       
   469         """
       
   470         entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid)
       
   471         res = self._search(session, lid, BASE)[0]
       
   472         for attr in entity.e_schema.indexable_attributes():
       
   473             entity[attr] = res[self.user_rev_attrs[attr]]
       
   474         return entity
       
   475     
       
   476     def after_entity_insertion(self, session, dn, entity):
       
   477         """called by the repository after an entity stored here has been
       
   478         inserted in the system table.
       
   479         """
       
   480         super(LDAPUserSource, self).after_entity_insertion(session, dn, entity)
       
   481         for group in self.user_default_groups:
       
   482             session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s',
       
   483                             {'x': entity.eid, 'group': group}, 'x')
       
   484         # search for existant email first
       
   485         try:
       
   486             emailaddr = self._cache[dn][self.user_rev_attrs['email']]
       
   487         except KeyError:
       
   488             return
       
   489         rset = session.execute('EmailAddress X WHERE X address %(addr)s',
       
   490                                {'addr': emailaddr})
       
   491         if rset:
       
   492             session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s',
       
   493                             {'x': rset[0][0], 'u': entity.eid}, 'u')
       
   494         else:
       
   495             # not found, create it
       
   496             _insert_email(session, emailaddr, entity.eid)
       
   497 
       
   498     def update_entity(self, session, entity):
       
   499         """replace an entity in the source"""
       
   500         raise RepositoryError('this source is read only')
       
   501 
       
   502     def delete_entity(self, session, etype, eid):
       
   503         """delete an entity from the source"""
       
   504         raise RepositoryError('this source is read only')
       
   505 
       
   506 def _insert_email(session, emailaddr, ueid):
       
   507     session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
       
   508                     'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x')
       
   509     
       
   510 class GotDN(Exception):
       
   511     """exception used when a dn localizing the searched user has been found"""
       
   512     def __init__(self, dn):
       
   513         self.dn = dn
       
   514 
       
   515         
       
   516 class RQL2LDAPFilter(object):
       
   517     """generate an LDAP filter for a rql query"""
       
   518     def __init__(self, source, session, args=None, mainvars=()):
       
   519         self.source = source
       
   520         self._ldap_attrs = source.user_rev_attrs
       
   521         self._base_filters = source.base_filters
       
   522         self._session = session
       
   523         if args is None:
       
   524             args = {}
       
   525         self._args = args
       
   526         self.mainvars = mainvars
       
   527         
       
   528     def generate(self, selection, mainvarname):
       
   529         self._filters = res = self._base_filters[:]
       
   530         self._mainvarname = mainvarname
       
   531         self._eidfilters = []
       
   532         self._done_not = set()
       
   533         restriction = selection.where
       
   534         if isinstance(restriction, Relation):
       
   535             # only a single relation, need to append result here (no AND/OR)
       
   536             filter = restriction.accept(self)
       
   537             if filter is not None:
       
   538                 res.append(filter)
       
   539         elif restriction:
       
   540             restriction.accept(self)
       
   541         if len(res) > 1:
       
   542             return self._eidfilters, '(&%s)' % ''.join(res)
       
   543         return self._eidfilters, res[0]
       
   544     
       
   545     def visit_and(self, et):
       
   546         """generate filter for a AND subtree"""
       
   547         for c in et.children:
       
   548             part = c.accept(self)
       
   549             if part:
       
   550                 self._filters.append(part)
       
   551 
       
   552     def visit_or(self, ou):
       
   553         """generate filter for a OR subtree"""
       
   554         res = []
       
   555         for c in ou.children:
       
   556             part = c.accept(self)
       
   557             if part:
       
   558                 res.append(part)
       
   559         if res:
       
   560             if len(res) > 1:
       
   561                 part = '(|%s)' % ''.join(res)
       
   562             else:
       
   563                 part = res[0]
       
   564             self._filters.append(part)
       
   565 
       
   566     def visit_not(self, node):
       
   567         """generate filter for a OR subtree"""
       
   568         part = node.children[0].accept(self)
       
   569         if part:
       
   570             self._filters.append('(!(%s))'% part)
       
   571 
       
   572     def visit_relation(self, relation):
       
   573         """generate filter for a relation"""
       
   574         rtype = relation.r_type
       
   575         # don't care of type constraint statement (i.e. relation_type = 'is')
       
   576         if rtype == 'is':
       
   577             return ''
       
   578         lhs, rhs = relation.get_parts()
       
   579         # attribute relation
       
   580         if self.source.schema.rschema(rtype).is_final():
       
   581             # dunno what to do here, don't pretend anything else
       
   582             if lhs.name != self._mainvarname:
       
   583                 if lhs.name in self.mainvars:
       
   584                     # XXX check we don't have variable as rhs
       
   585                     return
       
   586                 raise NotImplementedError
       
   587             rhs_vars = rhs.get_nodes(VariableRef)
       
   588             if rhs_vars:
       
   589                 if len(rhs_vars) > 1:
       
   590                     raise NotImplementedError
       
   591                 # selected variable, nothing to do here
       
   592                 return
       
   593             # no variables in the RHS
       
   594             if isinstance(rhs.children[0], Function):
       
   595                 res = rhs.children[0].accept(self)
       
   596             elif rtype != 'has_text':
       
   597                 res = self._visit_attribute_relation(relation)
       
   598             else:
       
   599                 raise NotImplementedError(relation)
       
   600         # regular relation XXX todo: in_group
       
   601         else:
       
   602             raise NotImplementedError(relation)
       
   603         return res
       
   604         
       
   605     def _visit_attribute_relation(self, relation):
       
   606         """generate filter for an attribute relation"""
       
   607         lhs, rhs = relation.get_parts()
       
   608         lhsvar = lhs.variable
       
   609         if relation.r_type == 'eid':
       
   610             # XXX hack
       
   611             # skip comparison sign
       
   612             eid = int(rhs.children[0].accept(self))
       
   613             if relation.neged(strict=True):
       
   614                 self._done_not.add(relation.parent)
       
   615                 self._eidfilters.append(lambda x: not x == eid)
       
   616                 return
       
   617             if rhs.operator != '=':
       
   618                 filter = {'>': lambda x: x > eid,
       
   619                           '>=': lambda x: x >= eid,
       
   620                           '<': lambda x: x < eid,
       
   621                           '<=': lambda x: x <= eid,
       
   622                           }[rhs.operator]
       
   623                 self._eidfilters.append(filter)
       
   624                 return
       
   625             dn = self.source.eid2extid(eid, self._session)
       
   626             raise GotDN(dn)
       
   627         try:
       
   628             filter = '(%s%s)' % (self._ldap_attrs[relation.r_type],
       
   629                                  rhs.accept(self))
       
   630         except KeyError:
       
   631             assert relation.r_type == 'password' # 2.38 migration
       
   632             raise UnknownEid # trick to return no result
       
   633         return filter
       
   634 
       
   635     def visit_comparison(self, cmp):
       
   636         """generate filter for a comparaison"""
       
   637         return '%s%s'% (cmp.operator, cmp.children[0].accept(self))            
       
   638 
       
   639     def visit_mathexpression(self, mexpr):
       
   640         """generate filter for a mathematic expression"""
       
   641         raise NotImplementedError
       
   642         
       
   643     def visit_function(self, function):
       
   644         """generate filter name for a function"""
       
   645         if function.name == 'IN':
       
   646             return self.visit_in(function)
       
   647         raise NotImplementedError
       
   648         
       
   649     def visit_in(self, function):
       
   650         grandpapa = function.parent.parent
       
   651         ldapattr = self._ldap_attrs[grandpapa.r_type]
       
   652         res = []
       
   653         for c in function.children:
       
   654             part = c.accept(self)
       
   655             if part:
       
   656                 res.append(part)
       
   657         if res:
       
   658             if len(res) > 1:
       
   659                 part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
       
   660             else:
       
   661                 part = '(%s=%s)' % (ldapattr, res[0])
       
   662         return part
       
   663         
       
   664     def visit_constant(self, constant):
       
   665         """generate filter name for a constant"""
       
   666         value = constant.value
       
   667         if constant.type is None:
       
   668             raise NotImplementedError
       
   669         if constant.type == 'Date':
       
   670             raise NotImplementedError
       
   671             #value = self.keyword_map[value]()
       
   672         elif constant.type == 'Substitute':
       
   673             value = self._args[constant.value]
       
   674         else:
       
   675             value = constant.value
       
   676         if isinstance(value, unicode):
       
   677             value = value.encode('utf8')
       
   678         else:
       
   679             value = str(value)
       
   680         return escape_filter_chars(value)
       
   681         
       
   682     def visit_variableref(self, variableref):
       
   683         """get the sql name for a variable reference"""
       
   684         pass
       
   685