changeset 9542 79b9bf88be28
parent 9540 43b4895a150f
parent 9541 e8040107b97e
child 9543 39f981482e34
child 9558 1a719ca9c585
equal deleted inserted replaced
9540:43b4895a150f 9542:79b9bf88be28
     1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
     3 #
     4 # This file is part of CubicWeb.
     5 #
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
     7 # terms of the GNU Lesser General Public License as published by the Free
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
     9 # any later version.
    10 #
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
    14 # details.
    15 #
    16 # You should have received a copy of the GNU Lesser General Public License along
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""cubicweb ldap user source

this source is for now limited to a read-only CWUser source.

This module is deprecated in favor of the ldapfeed source: importing it emits
a DeprecationWarning.
"""
    22 from __future__ import division, with_statement
    23 from base64 import b64decode
    25 import ldap
    26 from ldap.filter import escape_filter_chars
    28 from rql.nodes import Relation, VariableRef, Constant, Function
    30 import warnings
    31 from cubicweb import UnknownEid, RepositoryError
    32 from cubicweb.server import ldaputils
    33 from cubicweb.server.utils import cartesian_product
    34 from cubicweb.server.sources import (AbstractSource, TrFunc, GlobTrFunc,
    35                                      TimedCache)
    37 # search scopes
    38 BASE = ldap.SCOPE_BASE
    42 # map ldap protocol to their standard port
    43 PROTO_PORT = {'ldap': 389,
    44               'ldaps': 636,
    45               'ldapi': None,
    46               }
    49 # module is lazily imported
    50 warnings.warn('Imminent drop of ldapuser. Switch to ldapfeed now!',
    51               DeprecationWarning)
    54 class LDAPUserSource(ldaputils.LDAPSourceMixIn, AbstractSource):
    55     """LDAP read-only CWUser source"""
    56     support_entities = {'CWUser': False}
    58     options = ldaputils.LDAPSourceMixIn.options + (
    60         ('synchronization-interval',
    61          {'type' : 'time',
    62           'default': '1d',
    63           'help': 'interval between synchronization with the ldap \
    64 directory (default to once a day).',
    65           'group': 'ldap-source', 'level': 3,
    66           }),
    67         ('cache-life-time',
    68          {'type' : 'time',
    69           'default': '2h',
    70           'help': 'life time of query cache (default to two hours).',
    71           'group': 'ldap-source', 'level': 3,
    72           }),
    74     )
    76     def update_config(self, source_entity, typedconfig):
    77         """update configuration from source entity. `typedconfig` is config
    78         properly typed with defaults set
    79         """
    80         super(LDAPUserSource, self).update_config(source_entity, typedconfig)
    81         self._interval = typedconfig['synchronization-interval']
    82         self._cache_ttl = max(71, typedconfig['cache-life-time'])
    83         self.reset_caches()
    84         # XXX copy from datafeed source
    85         if source_entity is not None:
    86             self._entity_update(source_entity)
    87         self.config = typedconfig
    88         # /end XXX
    90     def reset_caches(self):
    91         """method called during test to reset potential source caches"""
    92         self._cache = {}
    93         self._query_cache = TimedCache(self._cache_ttl)
    95     def init(self, activated, source_entity):
    96         """method called by the repository once ready to handle request"""
    97         super(LDAPUserSource, self).init(activated, source_entity)
    98         if activated:
    99   'ldap init')
   100             # set minimum period of 5min 1s (the additional second is to
   101             # minimize resonnance effet)
   102             if self.user_rev_attrs['email']:
   103                 self.repo.looping_task(max(301, self._interval), self.synchronize)
   104             self.repo.looping_task(self._cache_ttl // 10,
   105                                    self._query_cache.clear_expired)
   107     def synchronize(self):
   108         with self.repo.internal_session() as session:
   109             self.pull_data(session)
   111     def pull_data(self, session, force=False, raise_on_error=False):
   112         """synchronize content known by this repository with content in the
   113         external repository
   114         """
   115'synchronizing ldap source %s', self.uri)
   116         ldap_emailattr = self.user_rev_attrs['email']
   117         assert ldap_emailattr
   118         execute = session.execute
   119         cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
   120                                     "source='%s'" % self.uri)
   121         for eid, b64extid in cursor.fetchall():
   122             extid = b64decode(b64extid)
   123             self.debug('ldap eid %s', eid)
   124             # if no result found, _search automatically delete entity information
   125             res = self._search(session, extid, BASE)
   126             self.debug('ldap search %s', res)
   127             if res:
   128                 ldapemailaddr = res[0].get(ldap_emailattr)
   129                 if ldapemailaddr:
   130                     if isinstance(ldapemailaddr, list):
   131                         ldapemailaddr = ldapemailaddr[0] # XXX consider only the first email in the list
   132                     rset = execute('Any X,A WHERE '
   133                                    'X address A, U use_email X, U eid %(u)s',
   134                                    {'u': eid})
   135                     ldapemailaddr = unicode(ldapemailaddr)
   136                     for emaileid, emailaddr, in rset:
   137                         if emailaddr == ldapemailaddr:
   138                             break
   139                     else:
   140                         self.debug('updating email address of user %s to %s',
   141                                   extid, ldapemailaddr)
   142                         emailrset = execute('EmailAddress A WHERE A address %(addr)s',
   143                                             {'addr': ldapemailaddr})
   144                         if emailrset:
   145                             execute('SET U use_email X WHERE '
   146                                     'X eid %(x)s, U eid %(u)s',
   147                                     {'x': emailrset[0][0], 'u': eid})
   148                         elif rset:
   149                             if not execute('SET X address %(addr)s WHERE '
   150                                            'U primary_email X, U eid %(u)s',
   151                                            {'addr': ldapemailaddr, 'u': eid}):
   152                                 execute('SET X address %(addr)s WHERE '
   153                                         'X eid %(x)s',
   154                                         {'addr': ldapemailaddr, 'x': rset[0][0]})
   155                         else:
   156                             # no email found, create it
   157                             _insert_email(session, ldapemailaddr, eid)
   158         session.commit()
   160     def ldap_name(self, var):
   161         if var.stinfo['relations']:
   162             relname = iter(var.stinfo['relations']).next().r_type
   163             return self.user_rev_attrs.get(relname)
   164         return None
   166     def prepare_columns(self, mainvars, rqlst):
   167         """return two list describing how to build the final results
   168         from the result of an ldap search (ie a list of dictionary)
   169         """
   170         columns = []
   171         global_transforms = []
   172         for i, term in enumerate(rqlst.selection):
   173             if isinstance(term, Constant):
   174                 columns.append(term)
   175                 continue
   176             if isinstance(term, Function): # LOWER, UPPER, COUNT...
   177                 var = term.get_nodes(VariableRef)[0]
   178                 var = var.variable
   179                 try:
   180                     mainvar = var.stinfo['attrvar'].name
   181                 except AttributeError: # no attrvar set
   182                     mainvar =
   183                 assert mainvar in mainvars
   184                 trname =
   185                 ldapname = self.ldap_name(var)
   186                 if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
   187                     global_transforms.append(GlobTrFunc(trname, i, ldapname))
   188                     columns.append((mainvar, ldapname))
   189                     continue
   190                 if trname in ('LOWER', 'UPPER'):
   191                     columns.append((mainvar, TrFunc(trname, i, ldapname)))
   192                     continue
   193                 raise NotImplementedError('no support for %s function' % trname)
   194             if in mainvars:
   195                 columns.append((, 'dn'))
   196                 continue
   197             var = term.variable
   198             mainvar = var.stinfo['attrvar'].name
   199             columns.append((mainvar, self.ldap_name(var)))
   200             #else:
   201             #    # probably a bug in rql splitting if we arrive here
   202             #    raise NotImplementedError
   203         return columns, global_transforms
   205     def syntax_tree_search(self, session, union,
   206                            args=None, cachekey=None, varmap=None, debug=0):
   207         """return result from this source for a rql query (actually from a rql
   208         syntax tree and a solution dictionary mapping each used variable to a
   209         possible type). If cachekey is given, the query necessary to fetch the
   210         results (but not the results themselves) may be cached using this key.
   211         """
   212         self.debug('ldap syntax tree search')
   213         # XXX not handled : transform/aggregat function, join on multiple users...
   214         assert len(union.children) == 1, 'union not supported'
   215         rqlst = union.children[0]
   216         assert not rqlst.with_, 'subquery not supported'
   217         rqlkey = rqlst.as_string(kwargs=args)
   218         try:
   219             results = self._query_cache[rqlkey]
   220         except KeyError:
   221             try:
   222                 results = self.rqlst_search(session, rqlst, args)
   223                 self._query_cache[rqlkey] = results
   224             except ldap.SERVER_DOWN:
   225                 # cant connect to server
   226                 msg = session._("can't connect to source %s, some data may be missing")
   227                 session.set_shared_data('sources_error', msg % self.uri, txdata=True)
   228                 return []
   229         return results
   231     def rqlst_search(self, session, rqlst, args):
   232         mainvars = []
   233         for varname in rqlst.defined_vars:
   234             for sol in
   235                 if sol[varname] == 'CWUser':
   236                     mainvars.append(varname)
   237                     break
   238         assert mainvars, rqlst
   239         columns, globtransforms = self.prepare_columns(mainvars, rqlst)
   240         eidfilters = [lambda x: x > 0]
   241         allresults = []
   242         generator = RQL2LDAPFilter(self, session, args, mainvars)
   243         for mainvar in mainvars:
   244             # handle restriction
   245             try:
   246                 eidfilters_, ldapfilter = generator.generate(rqlst, mainvar)
   247             except GotDN as ex:
   248                 assert ex.dn, 'no dn!'
   249                 try:
   250                     res = [self._cache[ex.dn]]
   251                 except KeyError:
   252                     res = self._search(session, ex.dn, BASE)
   253             except UnknownEid as ex:
   254                 # raised when we are looking for the dn of an eid which is not
   255                 # coming from this source
   256                 res = []
   257             else:
   258                 eidfilters += eidfilters_
   259                 res = self._search(session, self.user_base_dn,
   260                                    self.user_base_scope, ldapfilter)
   261             allresults.append(res)
   262         # 1. get eid for each dn and filter according to that eid if necessary
   263         for i, res in enumerate(allresults):
   264             filteredres = []
   265             for resdict in res:
   266                 # get sure the entity exists in the system table
   267                 eid = self.repo.extid2eid(self, resdict['dn'], 'CWUser', session)
   268                 for eidfilter in eidfilters:
   269                     if not eidfilter(eid):
   270                         break
   271                 else:
   272                     resdict['eid'] = eid
   273                     filteredres.append(resdict)
   274             allresults[i] = filteredres
   275         # 2. merge result for each "mainvar": cartesian product
   276         allresults = cartesian_product(allresults)
   277         # 3. build final result according to column definition
   278         result = []
   279         for rawline in allresults:
   280             rawline = dict(zip(mainvars, rawline))
   281             line = []
   282             for varname, ldapname in columns:
   283                 if ldapname is None:
   284                     value = None # no mapping available
   285                 elif ldapname == 'dn':
   286                     value = rawline[varname]['eid']
   287                 elif isinstance(ldapname, Constant):
   288                     if ldapname.type == 'Substitute':
   289                         value = args[ldapname.value]
   290                     else:
   291                         value = ldapname.value
   292                 elif isinstance(ldapname, TrFunc):
   293                     value = ldapname.apply(rawline[varname])
   294                 else:
   295                     value = rawline[varname].get(ldapname)
   296                 line.append(value)
   297             result.append(line)
   298         for trfunc in globtransforms:
   299             result = trfunc.apply(result)
   300         #print '--> ldap result', result
   301         return result
   303     def _process_ldap_item(self, dn, iterator):
   304         itemdict = super(LDAPUserSource, self)._process_ldap_item(dn, iterator)
   305         self._cache[dn] = itemdict
   306         return itemdict
   308     def _process_no_such_object(self, session, dn):
   309         eid = self.repo.extid2eid(self, dn, 'CWUser', session, insert=False)
   310         if eid:
   311             self.warning('deleting ldap user with eid %s and dn %s', eid, dn)
   312             entity = session.entity_from_eid(eid, 'CWUser')
   313             self.repo.delete_info(session, entity, self.uri)
   314             self.reset_caches()
   316     def before_entity_insertion(self, session, lid, etype, eid, sourceparams):
   317         """called by the repository when an eid has been attributed for an
   318         entity stored here but the entity has not been inserted in the system
   319         table yet.
   321         This method must return the an Entity instance representation of this
   322         entity.
   323         """
   324         self.debug('ldap before entity insertion')
   325         entity = super(LDAPUserSource, self).before_entity_insertion(
   326             session, lid, etype, eid, sourceparams)
   327         res = self._search(session, lid, BASE)[0]
   328         for attr in entity.e_schema.indexable_attributes():
   329             entity.cw_edited[attr] = res[self.user_rev_attrs[attr]]
   330         return entity
   332     def after_entity_insertion(self, session, lid, entity, sourceparams):
   333         """called by the repository after an entity stored here has been
   334         inserted in the system table.
   335         """
   336         self.debug('ldap after entity insertion')
   337         super(LDAPUserSource, self).after_entity_insertion(
   338             session, lid, entity, sourceparams)
   339         for group in self.user_default_groups:
   340             session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s',
   341                             {'x': entity.eid, 'group': group})
   342         # search for existant email first
   343         try:
   344             # lid = dn
   345             emailaddr = self._cache[lid][self.user_rev_attrs['email']]
   346         except KeyError:
   347             return
   348         if isinstance(emailaddr, list):
   349             emailaddr = emailaddr[0] # XXX consider only the first email in the list
   350         rset = session.execute('EmailAddress X WHERE X address %(addr)s',
   351                                {'addr': emailaddr})
   352         if rset:
   353             session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s',
   354                             {'x': rset[0][0], 'u': entity.eid})
   355         else:
   356             # not found, create it
   357             _insert_email(session, emailaddr, entity.eid)
   359     def update_entity(self, session, entity):
   360         """replace an entity in the source"""
   361         raise RepositoryError('this source is read only')
   363     def delete_entity(self, session, entity):
   364         """delete an entity from the source"""
   365         raise RepositoryError('this source is read only')
   368 def _insert_email(session, emailaddr, ueid):
   369     session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
   370                     'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid})
   372 class GotDN(Exception):
   373     """exception used when a dn localizing the searched user has been found"""
   374     def __init__(self, dn):
   375         self.dn = dn
   378 class RQL2LDAPFilter(object):
   379     """generate an LDAP filter for a rql query"""
   380     def __init__(self, source, session, args=None, mainvars=()):
   381         self.source = source
   382         self.repo = source.repo
   383         self._ldap_attrs = source.user_rev_attrs
   384         self._base_filters = source.base_filters
   385         self._session = session
   386         if args is None:
   387             args = {}
   388         self._args = args
   389         self.mainvars = mainvars
   391     def generate(self, selection, mainvarname):
   392         self._filters = res = self._base_filters[:]
   393         self._mainvarname = mainvarname
   394         self._eidfilters = []
   395         self._done_not = set()
   396         restriction = selection.where
   397         if isinstance(restriction, Relation):
   398             # only a single relation, need to append result here (no AND/OR)
   399             filter = restriction.accept(self)
   400             if filter is not None:
   401                 res.append(filter)
   402         elif restriction:
   403             restriction.accept(self)
   404         if len(res) > 1:
   405             return self._eidfilters, '(&%s)' % ''.join(res)
   406         return self._eidfilters, res[0]
   408     def visit_and(self, et):
   409         """generate filter for a AND subtree"""
   410         for c in et.children:
   411             part = c.accept(self)
   412             if part:
   413                 self._filters.append(part)
   415     def visit_or(self, ou):
   416         """generate filter for a OR subtree"""
   417         res = []
   418         for c in ou.children:
   419             part = c.accept(self)
   420             if part:
   421                 res.append(part)
   422         if res:
   423             if len(res) > 1:
   424                 part = '(|%s)' % ''.join(res)
   425             else:
   426                 part = res[0]
   427             self._filters.append(part)
   429     def visit_not(self, node):
   430         """generate filter for a OR subtree"""
   431         part = node.children[0].accept(self)
   432         if part:
   433             self._filters.append('(!(%s))'% part)
   435     def visit_relation(self, relation):
   436         """generate filter for a relation"""
   437         rtype = relation.r_type
   438         # don't care of type constraint statement (i.e. relation_type = 'is')
   439         if rtype == 'is':
   440             return ''
   441         lhs, rhs = relation.get_parts()
   442         # attribute relation
   443         if self.source.schema.rschema(rtype).final:
   444             # dunno what to do here, don't pretend anything else
   445             if != self._mainvarname:
   446                 if in self.mainvars:
   447                     # XXX check we don't have variable as rhs
   448                     return
   449                 raise NotImplementedError
   450             rhs_vars = rhs.get_nodes(VariableRef)
   451             if rhs_vars:
   452                 if len(rhs_vars) > 1:
   453                     raise NotImplementedError
   454                 # selected variable, nothing to do here
   455                 return
   456             # no variables in the RHS
   457             if isinstance(rhs.children[0], Function):
   458                 res = rhs.children[0].accept(self)
   459             elif rtype != 'has_text':
   460                 res = self._visit_attribute_relation(relation)
   461             else:
   462                 raise NotImplementedError(relation)
   463         # regular relation XXX todo: in_group
   464         else:
   465             raise NotImplementedError(relation)
   466         return res
   468     def _visit_attribute_relation(self, relation):
   469         """generate filter for an attribute relation"""
   470         lhs, rhs = relation.get_parts()
   471         lhsvar = lhs.variable
   472         if relation.r_type == 'eid':
   473             # XXX hack
   474             # skip comparison sign
   475             eid = int(rhs.children[0].accept(self))
   476             if relation.neged(strict=True):
   477                 self._done_not.add(relation.parent)
   478                 self._eidfilters.append(lambda x: not x == eid)
   479                 return
   480             if rhs.operator != '=':
   481                 filter = {'>': lambda x: x > eid,
   482                           '>=': lambda x: x >= eid,
   483                           '<': lambda x: x < eid,
   484                           '<=': lambda x: x <= eid,
   485                           }[rhs.operator]
   486                 self._eidfilters.append(filter)
   487                 return
   488             dn = self.repo.eid2extid(self.source, eid, self._session)
   489             raise GotDN(dn)
   490         try:
   491             filter = '(%s%s)' % (self._ldap_attrs[relation.r_type],
   492                                  rhs.accept(self))
   493         except KeyError:
   494             # unsupported attribute
   495             self.source.warning('%s source can\'t handle relation %s, no '
   496                                 'results will be returned from this source',
   497                                 self.source.uri, relation)
   498             raise UnknownEid # trick to return no result
   499         return filter
   501     def visit_comparison(self, cmp):
   502         """generate filter for a comparaison"""
   503         return '%s%s'% (cmp.operator, cmp.children[0].accept(self))
   505     def visit_mathexpression(self, mexpr):
   506         """generate filter for a mathematic expression"""
   507         raise NotImplementedError
   509     def visit_function(self, function):
   510         """generate filter name for a function"""
   511         if == 'IN':
   512             return self.visit_in(function)
   513         raise NotImplementedError
   515     def visit_in(self, function):
   516         grandpapa = function.parent.parent
   517         ldapattr = self._ldap_attrs[grandpapa.r_type]
   518         res = []
   519         for c in function.children:
   520             part = c.accept(self)
   521             if part:
   522                 res.append(part)
   523         if res:
   524             if len(res) > 1:
   525                 part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
   526             else:
   527                 part = '(%s=%s)' % (ldapattr, res[0])
   528         return part
   530     def visit_constant(self, constant):
   531         """generate filter name for a constant"""
   532         value = constant.value
   533         if constant.type is None:
   534             raise NotImplementedError
   535         if constant.type == 'Date':
   536             raise NotImplementedError
   537             #value = self.keyword_map[value]()
   538         elif constant.type == 'Substitute':
   539             value = self._args[constant.value]
   540         else:
   541             value = constant.value
   542         if isinstance(value, unicode):
   543             value = value.encode('utf8')
   544         else:
   545             value = str(value)
   546         return escape_filter_chars(value)
   548     def visit_variableref(self, variableref):
   549         """get the sql name for a variable reference"""
   550         pass