"""cubicweb ldap user source
this source is for now limited to a read-only EUser source
:organization: Logilab
:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
Part of the code is coming from Zope's LDAPUserFolder
Copyright (c) 2004 Jens Vagelpohl.
All Rights Reserved.
This software is subject to the provisions of the Zope Public License,
Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
FOR A PARTICULAR PURPOSE.
"""
from mx.DateTime import now, DateTimeDelta
from logilab.common.textutils import get_csv
from rql.nodes import Relation, VariableRef, Constant, Function
import ldap
from ldap.ldapobject import ReconnectLDAPObject
from ldap.filter import filter_format, escape_filter_chars
from ldapurl import LDAPUrl
from cubicweb.common import AuthenticationError, UnknownEid, RepositoryError
from cubicweb.server.sources import AbstractSource, TrFunc, GlobTrFunc, ConnectionWrapper
from cubicweb.server.utils import cartesian_product
# search scopes: short aliases for the python-ldap scope constants. Their
# names match the choices offered by the 'user-scope' source option.
BASE, ONELEVEL, SUBTREE = (ldap.SCOPE_BASE,
                           ldap.SCOPE_ONELEVEL,
                           ldap.SCOPE_SUBTREE)

# XXX only for edition ??
## password encryption possibilities
#ENCRYPTIONS = ('SHA', 'CRYPT', 'MD5', 'CLEAR') # , 'SSHA'

# connection modes, as `mode identifier: (default port, url scheme)`
MODES = {
    0: (389, 'ldap'),    # plain ldap
    1: (636, 'ldaps'),   # ldap over ssl
    2: (0, 'ldapi'),     # ldapi: the port is not used in that mode
}
class TimedCache(dict):
    """dictionary remembering when each of its values has been stored

    Values are kept internally as `(timestamp, value)` pairs: item assignment
    adds the timestamp, item access strips it, and `clear_expired` relies on
    it to drop the entries older than the cache's time to live.
    """

    def __init__(self, ttlm, ttls=0):
        # time to live: `ttlm` is given in minutes, `ttls` is passed as the
        # seconds argument of the delta
        self.ttl = DateTimeDelta(0, 0, ttlm, ttls)

    def __setitem__(self, key, value):
        """store `value` under `key` together with the current time"""
        stamped = (now(), value)
        dict.__setitem__(self, key, stamped)

    def __getitem__(self, key):
        """return the value stored under `key`, without its timestamp"""
        _stamp, value = dict.__getitem__(self, key)
        return value

    def clear_expired(self):
        """remove every entry stored for longer than `self.ttl`"""
        current = now()
        lifetime = self.ttl
        # collect the expired keys first, then delete them
        expired = [key for key, (stamp, _value) in self.items()
                   if current - stamp > lifetime]
        for key in expired:
            del self[key]
class LDAPUserSource(AbstractSource):
    """LDAP read-only EUser source

    Users are searched in the directory below the 'user-base-dn' option and
    exposed as EUser entities. Update and deletion requests are rejected
    with a `RepositoryError`.
    """
    # entity types this source declares support for.
    # NOTE(review): the False flag presumably means "not writable" --
    # confirm against AbstractSource
    support_entities = {'EUser': False}

    # port to connect to; when None, the default port of the connection mode
    # (see MODES) is used
    port = None

    # key of MODES selecting the protocol used to reach the server
    cnx_mode = 0
    # dn and password used to bind the source's own connection (see _connect)
    cnx_dn = ''
    cnx_pwd = ''

    options = (
        ('host',
         {'type': 'string',
          'default': 'ldap',
          'help': 'ldap host',
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('user-base-dn',
         {'type': 'string',
          'default': 'ou=People,dc=logilab,dc=fr',
          'help': 'base DN to lookup for users',
          'group': 'ldap-source', 'inputlevel': 0,
          }),
        ('user-scope',
         {'type': 'choice',
          'default': 'ONELEVEL',
          'choices': ('BASE', 'ONELEVEL', 'SUBTREE'),
          'help': 'user search scope',
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('user-classes',
         {'type': 'csv',
          'default': ('top', 'posixAccount'),
          'help': 'classes of user',
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('user-login-attr',
         {'type': 'string',
          'default': 'uid',
          'help': 'attribute used as login on authentication',
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('user-default-group',
         {'type': 'csv',
          'default': ('users',),
          'help': ('name of a group in which ldap users will be by default. '
                   'You can set multiple groups by separating them by a comma.'),
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('user-attrs-map',
         {'type': 'named',
          'default': {'uid': 'login', 'gecos': 'email'},
          'help': 'map from ldap user attributes to cubicweb attributes',
          'group': 'ldap-source', 'inputlevel': 1,
          }),
        ('synchronization-interval',
         {'type': 'int',
          'default': 24*60*60,
          'help': ('interval between synchronization with the ldap '
                   'directory (default to once a day).'),
          'group': 'ldap-source', 'inputlevel': 2,
          }),
        ('cache-life-time',
         {'type': 'int',
          'default': 2*60,
          'help': 'life time of query cache in minutes (default to two hours).',
          'group': 'ldap-source', 'inputlevel': 2,
          }),
    )

    def __init__(self, repo, appschema, source_config, *args, **kwargs):
        """initialize the source from its configuration dictionary

        `source_config` is read using the option names declared in
        `options`; the remaining arguments are handed to `AbstractSource`.
        """
        AbstractSource.__init__(self, repo, appschema, source_config,
                                *args, **kwargs)
        self.host = source_config['host']
        self.user_base_dn = source_config['user-base-dn']
        # explicit mapping rather than a globals() lookup, so that the option
        # value can only ever select one of the three search scopes
        scopes = {'BASE': BASE, 'ONELEVEL': ONELEVEL, 'SUBTREE': SUBTREE}
        self.user_base_scope = scopes[source_config['user-scope']]
        self.user_classes = get_csv(source_config['user-classes'])
        self.user_login_attr = source_config['user-login-attr']
        self.user_default_groups = get_csv(source_config['user-default-group'])
        # ldap attribute -> cubicweb attribute, parsed from 'ldap:cw' items
        self.user_attrs = dict(v.split(':', 1)
                               for v in get_csv(source_config['user-attrs-map']))
        # cubicweb attribute -> ldap attribute, the eid being mapped on the dn
        self.user_rev_attrs = {'eid': 'dn'}
        for ldapattr, cwattr in self.user_attrs.items():
            self.user_rev_attrs[cwattr] = ldapattr
        # one objectClass filter per user class, shared by every search
        self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
                             for o in self.user_classes]
        # connection owned by the source, opened on demand by get_connection
        self._conn = None
        # dn -> last record fetched for that dn (filled by _search)
        self._cache = {}
        # the option is named 'cache-life-time' (see `options`); it used to be
        # read under a misspelt key, hence the configured value was ignored
        ttlm = int(source_config.get('cache-life-time', 2*60))
        self._query_cache = TimedCache(ttlm)
        self._interval = int(source_config.get('synchronization-interval',
                                               24*60*60))

    def reset_caches(self):
        """method called during test to reset potential source caches"""
        # empty the cache in place: init() registers the `clear_expired`
        # method of this very object as a looping task, so substituting a new
        # cache object would leave the new one without periodic expiry
        self._query_cache.clear()

    def init(self):
        """method called by the repository once ready to handle request"""
        self.repo.looping_task(self._interval, self.synchronize)
        # look for expired query cache entries ten times per time to live
        self.repo.looping_task(self._query_cache.ttl.seconds/10,
                               self._query_cache.clear_expired)

    def synchronize(self):
        """synchronize content known by this repository with content in the
        external repository
        """
        self.info('synchronizing ldap source %s', self.uri)
        try:
            ldap_emailattr = self.user_rev_attrs['email']
        except KeyError:
            return # no email in ldap, we're done
        session = self.repo.internal_session()
        try:
            self._synchronize_emails(session, ldap_emailattr)
        finally:
            # commit whatever has been done, and make sure the session is
            # closed even when the commit itself fails
            try:
                session.commit()
            finally:
                session.close()

    def _synchronize_emails(self, session, ldap_emailattr):
        """update the email address of each entity coming from this source"""
        cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
                                    "source='%s'" % self.uri)
        for eid, extid in cursor.fetchall():
            # if no result found, _search automatically delete entity information
            res = self._search(session, extid, BASE)
            if not res:
                continue
            ldapemailaddr = res[0].get(ldap_emailattr)
            if not ldapemailaddr:
                continue
            # NOTE(review): this query selects A without relating it to X
            # (there is no `X address A`), while the loop below compares the
            # second column with an address -- confirm it returns what is
            # expected
            rset = session.execute('EmailAddress X,A WHERE '
                                   'U use_email X, U eid %(u)s',
                                   {'u': eid})
            ldapemailaddr = unicode(ldapemailaddr)
            for emaileid, emailaddr in rset:
                if emailaddr == ldapemailaddr:
                    break # the ldap address is already known
            else:
                self.info('updating email address of user %s to %s',
                          extid, ldapemailaddr)
                if rset:
                    session.execute('SET X address %(addr)s WHERE '
                                    'U primary_email X, U eid %(u)s',
                                    {'addr': ldapemailaddr, 'u': eid})
                else:
                    # no email found, create it
                    _insert_email(session, ldapemailaddr, eid)

    def get_connection(self):
        """open and return a connection to the source"""
        if self._conn is None:
            self._connect()
        return ConnectionWrapper(self._conn)

    def authenticate(self, session, login, password):
        """return EUser eid for the given login/password if this account is
        defined in this source, else raise `AuthenticationError`

        The user's entry is searched first, then the password is checked by
        binding a new connection with the dn found and the given password.
        """
        assert login, 'no login!'
        if not password:
            # a simple bind with an empty password is an unauthenticated bind
            # (RFC 4513, section 5.1.2), which a server may accept whatever
            # the dn: never take it as a proof of identity
            raise AuthenticationError()
        searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
        searchfilter.extend(self.base_filters)
        searchstr = '(&%s)' % ''.join(searchfilter)
        # first search the user
        try:
            user = self._search(session, self.user_base_dn,
                                self.user_base_scope, searchstr)[0]
        except IndexError:
            # no such user
            raise AuthenticationError()
        # check password by establishing a (unused) connection
        try:
            self._connect(user['dn'], password)
        except Exception:
            # Something went wrong, most likely bad credentials
            raise AuthenticationError()
        return self.extid2eid(user['dn'], 'EUser', session)

    def ldap_name(self, var):
        """return the ldap attribute mapped to the first relation in which
        `var` is used, or None when there is no relation or no mapping
        """
        for relation in var.stinfo['relations']:
            return self.user_rev_attrs.get(relation.r_type)
        return None

    def prepare_columns(self, mainvars, rqlst):
        """return two lists describing how to build the final results
        from the result of an ldap search (ie a list of dictionaries)

        The first list holds one `(main variable name, column spec)` pair per
        selected term, the spec being either an ldap attribute name, 'dn', a
        `TrFunc`, a `Constant` node (its variable name is then None) or None
        when no mapping is available. The second list holds the
        `GlobTrFunc` to apply on the whole result.
        """
        columns = []
        global_transforms = []
        for i, term in enumerate(rqlst.selection):
            if isinstance(term, Constant):
                # a constant is bound to no variable; still use the pair
                # shape that rqlst_search unpacks
                columns.append((None, term))
                continue
            if isinstance(term, Function): # LOWER, UPPER, COUNT...
                var = term.get_nodes(VariableRef)[0]
                var = var.variable
                try:
                    mainvar = var.stinfo['attrvar'].name
                except AttributeError: # no attrvar set
                    mainvar = var.name
                assert mainvar in mainvars
                trname = term.name
                ldapname = self.ldap_name(var)
                if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
                    global_transforms.append(GlobTrFunc(trname, i, ldapname))
                    columns.append((mainvar, ldapname))
                    continue
                if trname in ('LOWER', 'UPPER'):
                    columns.append((mainvar, TrFunc(trname, i, ldapname)))
                    continue
                raise NotImplementedError('no support for %s function' % trname)
            if term.name in mainvars:
                columns.append((term.name, 'dn'))
                continue
            var = term.variable
            mainvar = var.stinfo['attrvar'].name
            columns.append((mainvar, self.ldap_name(var)))
            #else:
            #    # probably a bug in rql splitting if we arrive here
            #    raise NotImplementedError
        return columns, global_transforms

    def syntax_tree_search(self, session, union,
                           args=None, cachekey=None, varmap=None, debug=0):
        """return result from this source for a rql query (actually from a rql
        syntax tree and a solution dictionary mapping each used variable to a
        possible type).

        Results are kept in the timed query cache, keyed by the query's
        string representation. `cachekey`, `varmap` and `debug` are not used
        by this source.
        """
        # XXX not handled : transform/aggregat function, join on multiple users...
        assert len(union.children) == 1, 'union not supported'
        rqlst = union.children[0]
        assert not rqlst.with_, 'subquery not supported'
        rqlkey = rqlst.as_string(kwargs=args)
        try:
            results = self._query_cache[rqlkey]
        except KeyError:
            results = self.rqlst_search(session, rqlst, args)
            self._query_cache[rqlkey] = results
        return results

    def rqlst_search(self, session, rqlst, args):
        """execute the select node `rqlst` against the directory and return
        the result as a list of lines (each line being a list of values)
        """
        import sys
        # "main" variables are those which may be of type EUser
        mainvars = []
        for varname in rqlst.defined_vars:
            for sol in rqlst.solutions:
                if sol[varname] == 'EUser':
                    mainvars.append(varname)
                    break
        assert mainvars
        columns, globtransforms = self.prepare_columns(mainvars, rqlst)
        eidfilters = []
        allresults = []
        generator = RQL2LDAPFilter(self, session, args, mainvars)
        for mainvar in mainvars:
            # handle restriction
            try:
                eidfilters_, ldapfilter = generator.generate(rqlst, mainvar)
            except GotDN:
                # the exception is fetched through sys.exc_info() so that
                # this clause parses whatever the interpreter version
                dn = sys.exc_info()[1].dn
                assert dn, 'no dn!'
                try:
                    res = [self._cache[dn]]
                except KeyError:
                    res = self._search(session, dn, BASE)
            except UnknownEid:
                # raised when we are looking for the dn of an eid which is not
                # coming from this source
                res = []
            else:
                eidfilters += eidfilters_
                res = self._search(session, self.user_base_dn,
                                   self.user_base_scope, ldapfilter)
            allresults.append(res)
        # 1. get eid for each dn and filter according to that eid if necessary
        for i, res in enumerate(allresults):
            filteredres = []
            for resdict in res:
                # get sure the entity exists in the system table
                eid = self.extid2eid(resdict['dn'], 'EUser', session)
                for eidfilter in eidfilters:
                    if not eidfilter(eid):
                        break
                else:
                    resdict['eid'] = eid
                    filteredres.append(resdict)
            allresults[i] = filteredres
        # 2. merge result for each "mainvar": cartesian product
        allresults = cartesian_product(allresults)
        # 3. build final result according to column definition
        result = []
        for rawline in allresults:
            rawline = dict(zip(mainvars, rawline))
            line = []
            for varname, ldapname in columns:
                # node types are tested before the comparison with 'dn', so
                # that nodes are never compared with a string
                if ldapname is None:
                    value = None # no mapping available
                elif isinstance(ldapname, Constant):
                    if ldapname.type == 'Substitute':
                        value = args[ldapname.value]
                    else:
                        value = ldapname.value
                elif isinstance(ldapname, TrFunc):
                    value = ldapname.apply(rawline[varname])
                elif ldapname == 'dn':
                    value = rawline[varname]['eid']
                else:
                    value = rawline[varname].get(ldapname)
                line.append(value)
            result.append(line)
        for trfunc in globtransforms:
            result = trfunc.apply(result)
        #print '--> ldap result', result
        return result

    def _connect(self, userdn=None, userpwd=None):
        """open a connection to the server, bind it and return it

        Without `userdn`, the connection is bound using `cnx_dn` / `cnx_pwd`
        and becomes the source's own connection (`self._conn`). Exceptions
        raised while binding are propagated.
        """
        port, protocol = MODES[self.cnx_mode]
        if protocol == 'ldapi':
            hostport = self.host
        else:
            hostport = '%s:%s' % (self.host, self.port or port)
        self.info('connecting %s://%s as %s', protocol, hostport,
                  userdn or 'anonymous')
        url = LDAPUrl(urlscheme=protocol, hostport=hostport)
        conn = ReconnectLDAPObject(url.initializeUrl())
        # Set the protocol version - version 3 is preferred
        try:
            conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
        except ldap.LDAPError: # Invalid protocol version, fall back safely
            conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2)
        # Deny auto-chasing of referrals to be safe, we handle them instead
        #try:
        #    connection.set_option(ldap.OPT_REFERRALS, 0)
        #except ldap.LDAPError: # Cannot set referrals, so do nothing
        #    pass
        #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout)
        #conn.timeout = op_timeout
        # Now bind with the credentials given. Let exceptions propagate out.
        if userdn is None:
            assert self._conn is None
            conn.simple_bind_s(self.cnx_dn, self.cnx_pwd)
            # only remember the connection once it is bound: after a failed
            # bind, the next get_connection() call will try again
            self._conn = conn
        else:
            conn.simple_bind_s(userdn, userpwd)
        return conn

    def _search(self, session, base, scope,
                searchstr='(objectClass=*)', attrs=()):
        """make an ldap query and return a list of record dictionaries

        Each record maps attribute names to values, plus the 'dn' key, and is
        also stored in `self._cache` under its dn. When `base` does not exist
        in the directory, the information known about it is deleted from the
        repository and an empty list is returned.
        """
        cnx = session.pool.connection(self.uri).cnx
        try:
            res = cnx.search_s(base, scope, searchstr, attrs)
        except ldap.PARTIAL_RESULTS:
            res = cnx.result(all=0)[1]
        except ldap.NO_SUCH_OBJECT:
            eid = self.extid2eid(base, 'EUser', session, insert=False)
            if eid:
                self.warning('deleting ldap user with eid %s and dn %s',
                             eid, base)
                self.repo.delete_info(session, eid)
                self._cache.pop(base, None)
            return []
##         except ldap.REFERRAL, e:
##             cnx = self.handle_referral(e)
##             try:
##                 res = cnx.search_s(base, scope, searchstr, attrs)
##             except ldap.PARTIAL_RESULTS:
##                 res_type, res = cnx.result(all=0)
        result = []
        for rec_dn, rec_dict in res:
            # When used against Active Directory, "rec_dict" may not be
            # be a dictionary in some cases (instead, it can be a list)
            # An example of a useless "res" entry that can be ignored
            # from AD is
            # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL'])
            # This appears to be some sort of internal referral, but
            # we can't handle it, so we need to skip over it.
            try:
                items = rec_dict.items()
            except AttributeError:
                # 'items' not found on rec_dict, skip
                continue
            for key, value in items: # XXX syt: huuum ?
                if not isinstance(value, str):
                    # best effort decoding, done in place: values which can't
                    # be decoded from utf8 are left as they are
                    try:
                        for i, item in enumerate(value):
                            value[i] = unicode(item, 'utf8')
                    except (TypeError, UnicodeError):
                        pass
                # unwrap single-valued attributes
                if isinstance(value, list) and len(value) == 1:
                    rec_dict[key] = value = value[0]
            rec_dict['dn'] = rec_dn
            self._cache[rec_dn] = rec_dict
            result.append(rec_dict)
        #print '--->', result
        return result

    def before_entity_insertion(self, session, lid, etype, eid):
        """called by the repository when an eid has been attributed for an
        entity stored here but the entity has not been inserted in the system
        table yet.

        This method must return the an Entity instance representation of this
        entity.
        """
        entity = super(LDAPUserSource, self).before_entity_insertion(
            session, lid, etype, eid)
        res = self._search(session, lid, BASE)[0]
        # fill the indexable attributes from the ldap record
        for attr in entity.e_schema.indexable_attributes():
            entity[attr] = res[self.user_rev_attrs[attr]]
        return entity

    def after_entity_insertion(self, session, dn, entity):
        """called by the repository after an entity stored here has been
        inserted in the system table.

        The new user is put in the default groups, then linked to its email
        address when the ldap record provides one.
        """
        super(LDAPUserSource, self).after_entity_insertion(session, dn, entity)
        for group in self.user_default_groups:
            session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s',
                            {'x': entity.eid, 'group': group}, 'x')
        # search for existant email first
        try:
            emailaddr = self._cache[dn][self.user_rev_attrs['email']]
        except KeyError:
            # record not cached, email not mapped or not set in the record
            return
        rset = session.execute('EmailAddress X WHERE X address %(addr)s',
                               {'addr': emailaddr})
        if rset:
            session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s',
                            {'x': rset[0][0], 'u': entity.eid}, 'u')
        else:
            # not found, create it
            _insert_email(session, emailaddr, entity.eid)

    def update_entity(self, session, entity):
        """replace an entity in the source"""
        raise RepositoryError('this source is read only')

    def delete_entity(self, session, etype, eid):
        """delete an entity from the source"""
        raise RepositoryError('this source is read only')
def _insert_email(session, emailaddr, ueid):
session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x')
class GotDN(Exception):
    """raised to stop the filter generation once the dn locating the searched
    user is known
    """

    def __init__(self, dn):
        # dn of the ldap entry to fetch
        self.dn = dn
class RQL2LDAPFilter(object):
    """generate an LDAP filter for a rql query

    The generator is a visitor for rql syntax tree nodes. Filters are
    accumulated in a list which `generate` finally joins into a single AND
    filter. Restrictions on eid can't be expressed as an ldap filter: they
    are turned into python callables returned along with the filter.

    NOTE(review): `visit_and`, `visit_or` and `visit_not` append to that
    shared list instead of returning their filter, so a boolean operator
    nested below an OR looks like it ends up ANDed at the top level --
    confirm whether such queries can reach this source.
    """

    def __init__(self, source, session, args=None, mainvars=()):
        self.source = source
        # cubicweb attribute name -> ldap attribute name
        self._ldap_attrs = source.user_rev_attrs
        # filters (object classes) to include in every generated filter
        self._base_filters = source.base_filters
        self._session = session
        if args is None:
            args = {}
        # values of the query's substitutions
        self._args = args
        self.mainvars = mainvars

    def generate(self, selection, mainvarname):
        """return a `(eid filters, ldap filter)` pair for the restriction of
        `selection`, regarding the variable named `mainvarname`

        The eid filters are callables taking an eid and returning a true
        value for the eids to keep. `GotDN` is raised when the restriction
        designates a single entity by its eid, and `UnknownEid` when it uses
        an attribute having no ldap counterpart.
        """
        self._filters = res = self._base_filters[:]
        self._mainvarname = mainvarname
        self._eidfilters = []
        self._done_not = set()
        restriction = selection.where
        if isinstance(restriction, Relation):
            # only a single relation, need to append result here (no AND/OR)
            part = restriction.accept(self)
            if part is not None:
                res.append(part)
        elif restriction:
            restriction.accept(self)
        if len(res) > 1:
            return self._eidfilters, '(&%s)' % ''.join(res)
        return self._eidfilters, res[0]

    def visit_and(self, et):
        """generate filter for a AND subtree"""
        for c in et.children:
            part = c.accept(self)
            if part:
                self._filters.append(part)

    def visit_or(self, ou):
        """generate filter for a OR subtree"""
        res = []
        for c in ou.children:
            part = c.accept(self)
            if part:
                res.append(part)
        if res:
            if len(res) > 1:
                part = '(|%s)' % ''.join(res)
            else:
                part = res[0]
            self._filters.append(part)

    def visit_not(self, node):
        """generate filter for a NOT subtree"""
        part = node.children[0].accept(self)
        if part:
            # `part` is a complete filter, already enclosed in parentheses:
            # RFC 4515 defines the negation as "(!" filter ")", without
            # additional parentheses around the negated filter
            self._filters.append('(!%s)' % part)

    def visit_relation(self, relation):
        """generate filter for a relation

        Return the filter, or a false value when the relation doesn't
        generate any. Only attribute relations on the main variable are
        supported.
        """
        rtype = relation.r_type
        # don't care of type constraint statement (i.e. relation_type = 'is')
        if rtype == 'is':
            return ''
        lhs, rhs = relation.get_parts()
        # attribute relation
        if self.source.schema.rschema(rtype).is_final():
            # dunno what to do here, don't pretend anything else
            if lhs.name != self._mainvarname:
                if lhs.name in self.mainvars:
                    # XXX check we don't have variable as rhs
                    return
                raise NotImplementedError
            rhs_vars = rhs.get_nodes(VariableRef)
            if rhs_vars:
                if len(rhs_vars) > 1:
                    raise NotImplementedError
                # selected variable, nothing to do here
                return
            # no variables in the RHS
            if isinstance(rhs.children[0], Function):
                res = rhs.children[0].accept(self)
            elif rtype != 'has_text':
                res = self._visit_attribute_relation(relation)
            else:
                raise NotImplementedError(relation)
        # regular relation XXX todo: in_group
        else:
            raise NotImplementedError(relation)
        return res

    def _visit_attribute_relation(self, relation):
        """generate filter for an attribute relation

        Relations on eid don't return any filter: they either register an
        eid filter or raise `GotDN` with the dn of the designated entity.
        """
        lhs, rhs = relation.get_parts()
        if relation.r_type == 'eid':
            # XXX hack
            # skip comparison sign
            eid = int(rhs.children[0].accept(self))
            if relation.neged(strict=True):
                self._done_not.add(relation.parent)
                self._eidfilters.append(lambda x: not x == eid)
                return
            if rhs.operator != '=':
                eidfilter = {'>': lambda x: x > eid,
                             '>=': lambda x: x >= eid,
                             '<': lambda x: x < eid,
                             '<=': lambda x: x <= eid,
                             }[rhs.operator]
                self._eidfilters.append(eidfilter)
                return
            dn = self.source.eid2extid(eid, self._session)
            raise GotDN(dn)
        try:
            ldapfilter = '(%s%s)' % (self._ldap_attrs[relation.r_type],
                                     rhs.accept(self))
        except KeyError:
            # unsupported attribute
            self.source.warning('%s source can\'t handle relation %s, no '
                                'results will be returned from this source',
                                self.source.uri, relation)
            raise UnknownEid # trick to return no result
        return ldapfilter

    def visit_comparison(self, cmp):
        """generate filter for a comparaison"""
        return '%s%s' % (cmp.operator, cmp.children[0].accept(self))

    def visit_mathexpression(self, mexpr):
        """generate filter for a mathematic expression"""
        raise NotImplementedError

    def visit_function(self, function):
        """generate filter name for a function"""
        if function.name == 'IN':
            return self.visit_in(function)
        raise NotImplementedError

    def visit_in(self, function):
        """generate filter for an IN function: an OR of equality tests on the
        ldap attribute mapped to the relation holding the function

        Return None when none of the function's children gives a value.
        """
        grandpapa = function.parent.parent
        ldapattr = self._ldap_attrs[grandpapa.r_type]
        values = []
        for c in function.children:
            value = c.accept(self)
            if value:
                values.append(value)
        if not values:
            return None
        if len(values) == 1:
            return '(%s=%s)' % (ldapattr, values[0])
        return '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in values)

    def visit_constant(self, constant):
        """generate filter name for a constant"""
        if constant.type is None:
            raise NotImplementedError
        if constant.type == 'Date':
            raise NotImplementedError
            #value = self.keyword_map[value]()
        if constant.type == 'Substitute':
            value = self._args[constant.value]
        else:
            value = constant.value
        # filters are built as byte strings
        if isinstance(value, unicode):
            value = value.encode('utf8')
        else:
            value = str(value)
        return escape_filter_chars(value)

    def visit_variableref(self, variableref):
        """variable references don't generate any filter"""
        pass