server/sources/ldapuser.py
changeset 9015 65b8236e1bb4
parent 9014 dfa4da8a53a0
child 9016 0368b94921ed
child 9146 9b58a6406a64
equal deleted inserted replaced
9014:dfa4da8a53a0 9015:65b8236e1bb4
     1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
    18 """cubicweb ldap user source
       
    19 
       
    20 this source is for now limited to a read-only CWUser source
       
    21 """
       
    22 from __future__ import division, with_statement
       
    23 from base64 import b64decode
       
    24 
       
    25 import ldap
       
    26 from ldap.filter import escape_filter_chars
       
    27 
       
    28 from rql.nodes import Relation, VariableRef, Constant, Function
       
    29 
       
    30 import warnings
       
    31 from cubicweb import UnknownEid, RepositoryError
       
    32 from cubicweb.server import ldaputils
       
    33 from cubicweb.server.utils import cartesian_product
       
    34 from cubicweb.server.sources import (AbstractSource, TrFunc, GlobTrFunc,
       
    35                                      TimedCache)
       
    36 
       
    37 # search scopes
       
    38 BASE = ldap.SCOPE_BASE
       
    39 ONELEVEL = ldap.SCOPE_ONELEVEL
       
    40 SUBTREE = ldap.SCOPE_SUBTREE
       
    41 
       
    42 # map ldap protocol to their standard port
       
    43 PROTO_PORT = {'ldap': 389,
       
    44               'ldaps': 636,
       
    45               'ldapi': None,
       
    46               }
       
    47 
       
    48 
       
    49 # module is lazily imported
       
    50 warnings.warn('Imminent drop of ldapuser. Switch to ldapfeed now!',
       
    51               DeprecationWarning)
       
    52 
       
    53 
       
    54 class LDAPUserSource(ldaputils.LDAPSourceMixIn, AbstractSource):
       
    55     """LDAP read-only CWUser source"""
       
    56     support_entities = {'CWUser': False}
       
    57 
       
    58     options = ldaputils.LDAPSourceMixIn.options + (
       
    59 
       
    60         ('synchronization-interval',
       
    61          {'type' : 'time',
       
    62           'default': '1d',
       
    63           'help': 'interval between synchronization with the ldap \
       
    64 directory (default to once a day).',
       
    65           'group': 'ldap-source', 'level': 3,
       
    66           }),
       
    67         ('cache-life-time',
       
    68          {'type' : 'time',
       
    69           'default': '2h',
       
    70           'help': 'life time of query cache (default to two hours).',
       
    71           'group': 'ldap-source', 'level': 3,
       
    72           }),
       
    73 
       
    74     )
       
    75 
       
    76     def update_config(self, source_entity, typedconfig):
       
    77         """update configuration from source entity. `typedconfig` is config
       
    78         properly typed with defaults set
       
    79         """
       
    80         super(LDAPUserSource, self).update_config(source_entity, typedconfig)
       
    81         self._interval = typedconfig['synchronization-interval']
       
    82         self._cache_ttl = max(71, typedconfig['cache-life-time'])
       
    83         self.reset_caches()
       
    84         # XXX copy from datafeed source
       
    85         if source_entity is not None:
       
    86             self._entity_update(source_entity)
       
    87         self.config = typedconfig
       
    88         # /end XXX
       
    89 
       
    90     def reset_caches(self):
       
    91         """method called during test to reset potential source caches"""
       
    92         self._cache = {}
       
    93         self._query_cache = TimedCache(self._cache_ttl)
       
    94 
       
    95     def init(self, activated, source_entity):
       
    96         """method called by the repository once ready to handle request"""
       
    97         super(LDAPUserSource, self).init(activated, source_entity)
       
    98         if activated:
       
    99             self.info('ldap init')
       
   100             # set minimum period of 5min 1s (the additional second is to
       
   101             # minimize resonnance effet)
       
   102             if self.user_rev_attrs['email']:
       
   103                 self.repo.looping_task(max(301, self._interval), self.synchronize)
       
   104             self.repo.looping_task(self._cache_ttl // 10,
       
   105                                    self._query_cache.clear_expired)
       
   106 
       
   107     def synchronize(self):
       
   108         with self.repo.internal_session() as session:
       
   109             self.pull_data(session)
       
   110 
       
   111     def pull_data(self, session, force=False, raise_on_error=False):
       
   112         """synchronize content known by this repository with content in the
       
   113         external repository
       
   114         """
       
   115         self.info('synchronizing ldap source %s', self.uri)
       
   116         ldap_emailattr = self.user_rev_attrs['email']
       
   117         assert ldap_emailattr
       
   118         execute = session.execute
       
   119         cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
       
   120                                     "source='%s'" % self.uri)
       
   121         for eid, b64extid in cursor.fetchall():
       
   122             extid = b64decode(b64extid)
       
   123             self.debug('ldap eid %s', eid)
       
   124             # if no result found, _search automatically delete entity information
       
   125             res = self._search(session, extid, BASE)
       
   126             self.debug('ldap search %s', res)
       
   127             if res:
       
   128                 ldapemailaddr = res[0].get(ldap_emailattr)
       
   129                 if ldapemailaddr:
       
   130                     if isinstance(ldapemailaddr, list):
       
   131                         ldapemailaddr = ldapemailaddr[0] # XXX consider only the first email in the list
       
   132                     rset = execute('Any X,A WHERE '
       
   133                                    'X address A, U use_email X, U eid %(u)s',
       
   134                                    {'u': eid})
       
   135                     ldapemailaddr = unicode(ldapemailaddr)
       
   136                     for emaileid, emailaddr, in rset:
       
   137                         if emailaddr == ldapemailaddr:
       
   138                             break
       
   139                     else:
       
   140                         self.debug('updating email address of user %s to %s',
       
   141                                   extid, ldapemailaddr)
       
   142                         emailrset = execute('EmailAddress A WHERE A address %(addr)s',
       
   143                                             {'addr': ldapemailaddr})
       
   144                         if emailrset:
       
   145                             execute('SET U use_email X WHERE '
       
   146                                     'X eid %(x)s, U eid %(u)s',
       
   147                                     {'x': emailrset[0][0], 'u': eid})
       
   148                         elif rset:
       
   149                             if not execute('SET X address %(addr)s WHERE '
       
   150                                            'U primary_email X, U eid %(u)s',
       
   151                                            {'addr': ldapemailaddr, 'u': eid}):
       
   152                                 execute('SET X address %(addr)s WHERE '
       
   153                                         'X eid %(x)s',
       
   154                                         {'addr': ldapemailaddr, 'x': rset[0][0]})
       
   155                         else:
       
   156                             # no email found, create it
       
   157                             _insert_email(session, ldapemailaddr, eid)
       
   158         session.commit()
       
   159 
       
   160     def ldap_name(self, var):
       
   161         if var.stinfo['relations']:
       
   162             relname = iter(var.stinfo['relations']).next().r_type
       
   163             return self.user_rev_attrs.get(relname)
       
   164         return None
       
   165 
       
   166     def prepare_columns(self, mainvars, rqlst):
       
   167         """return two list describing how to build the final results
       
   168         from the result of an ldap search (ie a list of dictionary)
       
   169         """
       
   170         columns = []
       
   171         global_transforms = []
       
   172         for i, term in enumerate(rqlst.selection):
       
   173             if isinstance(term, Constant):
       
   174                 columns.append(term)
       
   175                 continue
       
   176             if isinstance(term, Function): # LOWER, UPPER, COUNT...
       
   177                 var = term.get_nodes(VariableRef)[0]
       
   178                 var = var.variable
       
   179                 try:
       
   180                     mainvar = var.stinfo['attrvar'].name
       
   181                 except AttributeError: # no attrvar set
       
   182                     mainvar = var.name
       
   183                 assert mainvar in mainvars
       
   184                 trname = term.name
       
   185                 ldapname = self.ldap_name(var)
       
   186                 if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
       
   187                     global_transforms.append(GlobTrFunc(trname, i, ldapname))
       
   188                     columns.append((mainvar, ldapname))
       
   189                     continue
       
   190                 if trname in ('LOWER', 'UPPER'):
       
   191                     columns.append((mainvar, TrFunc(trname, i, ldapname)))
       
   192                     continue
       
   193                 raise NotImplementedError('no support for %s function' % trname)
       
   194             if term.name in mainvars:
       
   195                 columns.append((term.name, 'dn'))
       
   196                 continue
       
   197             var = term.variable
       
   198             mainvar = var.stinfo['attrvar'].name
       
   199             columns.append((mainvar, self.ldap_name(var)))
       
   200             #else:
       
   201             #    # probably a bug in rql splitting if we arrive here
       
   202             #    raise NotImplementedError
       
   203         return columns, global_transforms
       
   204 
       
   205     def syntax_tree_search(self, session, union,
       
   206                            args=None, cachekey=None, varmap=None, debug=0):
       
   207         """return result from this source for a rql query (actually from a rql
       
   208         syntax tree and a solution dictionary mapping each used variable to a
       
   209         possible type). If cachekey is given, the query necessary to fetch the
       
   210         results (but not the results themselves) may be cached using this key.
       
   211         """
       
   212         self.debug('ldap syntax tree search')
       
   213         # XXX not handled : transform/aggregat function, join on multiple users...
       
   214         assert len(union.children) == 1, 'union not supported'
       
   215         rqlst = union.children[0]
       
   216         assert not rqlst.with_, 'subquery not supported'
       
   217         rqlkey = rqlst.as_string(kwargs=args)
       
   218         try:
       
   219             results = self._query_cache[rqlkey]
       
   220         except KeyError:
       
   221             try:
       
   222                 results = self.rqlst_search(session, rqlst, args)
       
   223                 self._query_cache[rqlkey] = results
       
   224             except ldap.SERVER_DOWN:
       
   225                 # cant connect to server
       
   226                 msg = session._("can't connect to source %s, some data may be missing")
       
   227                 session.set_shared_data('sources_error', msg % self.uri, txdata=True)
       
   228                 return []
       
   229         return results
       
   230 
       
   231     def rqlst_search(self, session, rqlst, args):
       
   232         mainvars = []
       
   233         for varname in rqlst.defined_vars:
       
   234             for sol in rqlst.solutions:
       
   235                 if sol[varname] == 'CWUser':
       
   236                     mainvars.append(varname)
       
   237                     break
       
   238         assert mainvars, rqlst
       
   239         columns, globtransforms = self.prepare_columns(mainvars, rqlst)
       
   240         eidfilters = [lambda x: x > 0]
       
   241         allresults = []
       
   242         generator = RQL2LDAPFilter(self, session, args, mainvars)
       
   243         for mainvar in mainvars:
       
   244             # handle restriction
       
   245             try:
       
   246                 eidfilters_, ldapfilter = generator.generate(rqlst, mainvar)
       
   247             except GotDN as ex:
       
   248                 assert ex.dn, 'no dn!'
       
   249                 try:
       
   250                     res = [self._cache[ex.dn]]
       
   251                 except KeyError:
       
   252                     res = self._search(session, ex.dn, BASE)
       
   253             except UnknownEid as ex:
       
   254                 # raised when we are looking for the dn of an eid which is not
       
   255                 # coming from this source
       
   256                 res = []
       
   257             else:
       
   258                 eidfilters += eidfilters_
       
   259                 res = self._search(session, self.user_base_dn,
       
   260                                    self.user_base_scope, ldapfilter)
       
   261             allresults.append(res)
       
   262         # 1. get eid for each dn and filter according to that eid if necessary
       
   263         for i, res in enumerate(allresults):
       
   264             filteredres = []
       
   265             for resdict in res:
       
   266                 # get sure the entity exists in the system table
       
   267                 eid = self.repo.extid2eid(self, resdict['dn'], 'CWUser', session)
       
   268                 for eidfilter in eidfilters:
       
   269                     if not eidfilter(eid):
       
   270                         break
       
   271                 else:
       
   272                     resdict['eid'] = eid
       
   273                     filteredres.append(resdict)
       
   274             allresults[i] = filteredres
       
   275         # 2. merge result for each "mainvar": cartesian product
       
   276         allresults = cartesian_product(allresults)
       
   277         # 3. build final result according to column definition
       
   278         result = []
       
   279         for rawline in allresults:
       
   280             rawline = dict(zip(mainvars, rawline))
       
   281             line = []
       
   282             for varname, ldapname in columns:
       
   283                 if ldapname is None:
       
   284                     value = None # no mapping available
       
   285                 elif ldapname == 'dn':
       
   286                     value = rawline[varname]['eid']
       
   287                 elif isinstance(ldapname, Constant):
       
   288                     if ldapname.type == 'Substitute':
       
   289                         value = args[ldapname.value]
       
   290                     else:
       
   291                         value = ldapname.value
       
   292                 elif isinstance(ldapname, TrFunc):
       
   293                     value = ldapname.apply(rawline[varname])
       
   294                 else:
       
   295                     value = rawline[varname].get(ldapname)
       
   296                 line.append(value)
       
   297             result.append(line)
       
   298         for trfunc in globtransforms:
       
   299             result = trfunc.apply(result)
       
   300         #print '--> ldap result', result
       
   301         return result
       
   302 
       
   303     def _process_ldap_item(self, dn, iterator):
       
   304         itemdict = super(LDAPUserSource, self)._process_ldap_item(dn, iterator)
       
   305         self._cache[dn] = itemdict
       
   306         return itemdict
       
   307 
       
   308     def _process_no_such_object(self, session, dn):
       
   309         eid = self.repo.extid2eid(self, dn, 'CWUser', session, insert=False)
       
   310         if eid:
       
   311             self.warning('deleting ldap user with eid %s and dn %s', eid, dn)
       
   312             entity = session.entity_from_eid(eid, 'CWUser')
       
   313             self.repo.delete_info(session, entity, self.uri)
       
   314             self.reset_caches()
       
   315 
       
   316     def before_entity_insertion(self, session, lid, etype, eid, sourceparams):
       
   317         """called by the repository when an eid has been attributed for an
       
   318         entity stored here but the entity has not been inserted in the system
       
   319         table yet.
       
   320 
       
   321         This method must return the an Entity instance representation of this
       
   322         entity.
       
   323         """
       
   324         self.debug('ldap before entity insertion')
       
   325         entity = super(LDAPUserSource, self).before_entity_insertion(
       
   326             session, lid, etype, eid, sourceparams)
       
   327         res = self._search(session, lid, BASE)[0]
       
   328         for attr in entity.e_schema.indexable_attributes():
       
   329             entity.cw_edited[attr] = res[self.user_rev_attrs[attr]]
       
   330         return entity
       
   331 
       
   332     def after_entity_insertion(self, session, lid, entity, sourceparams):
       
   333         """called by the repository after an entity stored here has been
       
   334         inserted in the system table.
       
   335         """
       
   336         self.debug('ldap after entity insertion')
       
   337         super(LDAPUserSource, self).after_entity_insertion(
       
   338             session, lid, entity, sourceparams)
       
   339         for group in self.user_default_groups:
       
   340             session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s',
       
   341                             {'x': entity.eid, 'group': group})
       
   342         # search for existant email first
       
   343         try:
       
   344             # lid = dn
       
   345             emailaddr = self._cache[lid][self.user_rev_attrs['email']]
       
   346         except KeyError:
       
   347             return
       
   348         if isinstance(emailaddr, list):
       
   349             emailaddr = emailaddr[0] # XXX consider only the first email in the list
       
   350         rset = session.execute('EmailAddress X WHERE X address %(addr)s',
       
   351                                {'addr': emailaddr})
       
   352         if rset:
       
   353             session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s',
       
   354                             {'x': rset[0][0], 'u': entity.eid})
       
   355         else:
       
   356             # not found, create it
       
   357             _insert_email(session, emailaddr, entity.eid)
       
   358 
       
   359     def update_entity(self, session, entity):
       
   360         """replace an entity in the source"""
       
   361         raise RepositoryError('this source is read only')
       
   362 
       
   363     def delete_entity(self, session, entity):
       
   364         """delete an entity from the source"""
       
   365         raise RepositoryError('this source is read only')
       
   366 
       
   367 
       
   368 def _insert_email(session, emailaddr, ueid):
       
   369     session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
       
   370                     'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid})
       
   371 
       
   372 class GotDN(Exception):
       
   373     """exception used when a dn localizing the searched user has been found"""
       
   374     def __init__(self, dn):
       
   375         self.dn = dn
       
   376 
       
   377 
       
   378 class RQL2LDAPFilter(object):
       
   379     """generate an LDAP filter for a rql query"""
       
   380     def __init__(self, source, session, args=None, mainvars=()):
       
   381         self.source = source
       
   382         self.repo = source.repo
       
   383         self._ldap_attrs = source.user_rev_attrs
       
   384         self._base_filters = source.base_filters
       
   385         self._session = session
       
   386         if args is None:
       
   387             args = {}
       
   388         self._args = args
       
   389         self.mainvars = mainvars
       
   390 
       
   391     def generate(self, selection, mainvarname):
       
   392         self._filters = res = self._base_filters[:]
       
   393         self._mainvarname = mainvarname
       
   394         self._eidfilters = []
       
   395         self._done_not = set()
       
   396         restriction = selection.where
       
   397         if isinstance(restriction, Relation):
       
   398             # only a single relation, need to append result here (no AND/OR)
       
   399             filter = restriction.accept(self)
       
   400             if filter is not None:
       
   401                 res.append(filter)
       
   402         elif restriction:
       
   403             restriction.accept(self)
       
   404         if len(res) > 1:
       
   405             return self._eidfilters, '(&%s)' % ''.join(res)
       
   406         return self._eidfilters, res[0]
       
   407 
       
   408     def visit_and(self, et):
       
   409         """generate filter for a AND subtree"""
       
   410         for c in et.children:
       
   411             part = c.accept(self)
       
   412             if part:
       
   413                 self._filters.append(part)
       
   414 
       
   415     def visit_or(self, ou):
       
   416         """generate filter for a OR subtree"""
       
   417         res = []
       
   418         for c in ou.children:
       
   419             part = c.accept(self)
       
   420             if part:
       
   421                 res.append(part)
       
   422         if res:
       
   423             if len(res) > 1:
       
   424                 part = '(|%s)' % ''.join(res)
       
   425             else:
       
   426                 part = res[0]
       
   427             self._filters.append(part)
       
   428 
       
   429     def visit_not(self, node):
       
   430         """generate filter for a OR subtree"""
       
   431         part = node.children[0].accept(self)
       
   432         if part:
       
   433             self._filters.append('(!(%s))'% part)
       
   434 
       
   435     def visit_relation(self, relation):
       
   436         """generate filter for a relation"""
       
   437         rtype = relation.r_type
       
   438         # don't care of type constraint statement (i.e. relation_type = 'is')
       
   439         if rtype == 'is':
       
   440             return ''
       
   441         lhs, rhs = relation.get_parts()
       
   442         # attribute relation
       
   443         if self.source.schema.rschema(rtype).final:
       
   444             # dunno what to do here, don't pretend anything else
       
   445             if lhs.name != self._mainvarname:
       
   446                 if lhs.name in self.mainvars:
       
   447                     # XXX check we don't have variable as rhs
       
   448                     return
       
   449                 raise NotImplementedError
       
   450             rhs_vars = rhs.get_nodes(VariableRef)
       
   451             if rhs_vars:
       
   452                 if len(rhs_vars) > 1:
       
   453                     raise NotImplementedError
       
   454                 # selected variable, nothing to do here
       
   455                 return
       
   456             # no variables in the RHS
       
   457             if isinstance(rhs.children[0], Function):
       
   458                 res = rhs.children[0].accept(self)
       
   459             elif rtype != 'has_text':
       
   460                 res = self._visit_attribute_relation(relation)
       
   461             else:
       
   462                 raise NotImplementedError(relation)
       
   463         # regular relation XXX todo: in_group
       
   464         else:
       
   465             raise NotImplementedError(relation)
       
   466         return res
       
   467 
       
   468     def _visit_attribute_relation(self, relation):
       
   469         """generate filter for an attribute relation"""
       
   470         lhs, rhs = relation.get_parts()
       
   471         lhsvar = lhs.variable
       
   472         if relation.r_type == 'eid':
       
   473             # XXX hack
       
   474             # skip comparison sign
       
   475             eid = int(rhs.children[0].accept(self))
       
   476             if relation.neged(strict=True):
       
   477                 self._done_not.add(relation.parent)
       
   478                 self._eidfilters.append(lambda x: not x == eid)
       
   479                 return
       
   480             if rhs.operator != '=':
       
   481                 filter = {'>': lambda x: x > eid,
       
   482                           '>=': lambda x: x >= eid,
       
   483                           '<': lambda x: x < eid,
       
   484                           '<=': lambda x: x <= eid,
       
   485                           }[rhs.operator]
       
   486                 self._eidfilters.append(filter)
       
   487                 return
       
   488             dn = self.repo.eid2extid(self.source, eid, self._session)
       
   489             raise GotDN(dn)
       
   490         try:
       
   491             filter = '(%s%s)' % (self._ldap_attrs[relation.r_type],
       
   492                                  rhs.accept(self))
       
   493         except KeyError:
       
   494             # unsupported attribute
       
   495             self.source.warning('%s source can\'t handle relation %s, no '
       
   496                                 'results will be returned from this source',
       
   497                                 self.source.uri, relation)
       
   498             raise UnknownEid # trick to return no result
       
   499         return filter
       
   500 
       
   501     def visit_comparison(self, cmp):
       
   502         """generate filter for a comparaison"""
       
   503         return '%s%s'% (cmp.operator, cmp.children[0].accept(self))
       
   504 
       
   505     def visit_mathexpression(self, mexpr):
       
   506         """generate filter for a mathematic expression"""
       
   507         raise NotImplementedError
       
   508 
       
   509     def visit_function(self, function):
       
   510         """generate filter name for a function"""
       
   511         if function.name == 'IN':
       
   512             return self.visit_in(function)
       
   513         raise NotImplementedError
       
   514 
       
   515     def visit_in(self, function):
       
   516         grandpapa = function.parent.parent
       
   517         ldapattr = self._ldap_attrs[grandpapa.r_type]
       
   518         res = []
       
   519         for c in function.children:
       
   520             part = c.accept(self)
       
   521             if part:
       
   522                 res.append(part)
       
   523         if res:
       
   524             if len(res) > 1:
       
   525                 part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
       
   526             else:
       
   527                 part = '(%s=%s)' % (ldapattr, res[0])
       
   528         return part
       
   529 
       
   530     def visit_constant(self, constant):
       
   531         """generate filter name for a constant"""
       
   532         value = constant.value
       
   533         if constant.type is None:
       
   534             raise NotImplementedError
       
   535         if constant.type == 'Date':
       
   536             raise NotImplementedError
       
   537             #value = self.keyword_map[value]()
       
   538         elif constant.type == 'Substitute':
       
   539             value = self._args[constant.value]
       
   540         else:
       
   541             value = constant.value
       
   542         if isinstance(value, unicode):
       
   543             value = value.encode('utf8')
       
   544         else:
       
   545             value = str(value)
       
   546         return escape_filter_chars(value)
       
   547 
       
   548     def visit_variableref(self, variableref):
       
   549         """get the sql name for a variable reference"""
       
   550         pass
       
   551