server/sources/ldapuser.py
branchtls-sprint
changeset 1802 d628defebc17
parent 1398 5fe84a5f7035
child 1952 8e19c813750d
equal deleted inserted replaced
1801:672acc730ce5 1802:d628defebc17
    50     }
    50     }
    51 
    51 
    52 
    52 
    53 class LDAPUserSource(AbstractSource):
    53 class LDAPUserSource(AbstractSource):
    54     """LDAP read-only CWUser source"""
    54     """LDAP read-only CWUser source"""
    55     support_entities = {'CWUser': False} 
    55     support_entities = {'CWUser': False}
    56 
    56 
    57     port = None
    57     port = None
    58     
    58 
    59     cnx_mode = 0
    59     cnx_mode = 0
    60     cnx_dn = ''
    60     cnx_dn = ''
    61     cnx_pwd = ''
    61     cnx_pwd = ''
    62     
    62 
    63     options = (
    63     options = (
    64         ('host',
    64         ('host',
    65          {'type' : 'string',
    65          {'type' : 'string',
    66           'default': 'ldap',
    66           'default': 'ldap',
    67           'help': 'ldap host',
    67           'help': 'ldap host',
   117          {'type' : 'int',
   117          {'type' : 'int',
   118           'default': 2*60,
   118           'default': 2*60,
   119           'help': 'life time of query cache in minutes (default to two hours).',
   119           'help': 'life time of query cache in minutes (default to two hours).',
   120           'group': 'ldap-source', 'inputlevel': 2,
   120           'group': 'ldap-source', 'inputlevel': 2,
   121           }),
   121           }),
   122         
   122 
   123     )
   123     )
   124             
   124 
   125     def __init__(self, repo, appschema, source_config, *args, **kwargs):
   125     def __init__(self, repo, appschema, source_config, *args, **kwargs):
   126         AbstractSource.__init__(self, repo, appschema, source_config,
   126         AbstractSource.__init__(self, repo, appschema, source_config,
   127                                 *args, **kwargs)
   127                                 *args, **kwargs)
   128         self.host = source_config['host']
   128         self.host = source_config['host']
   129         self.user_base_dn = source_config['user-base-dn']
   129         self.user_base_dn = source_config['user-base-dn']
   148         """method called during test to reset potential source caches"""
   148         """method called during test to reset potential source caches"""
   149         self._query_cache = TimedCache(2*60)
   149         self._query_cache = TimedCache(2*60)
   150 
   150 
   151     def init(self):
   151     def init(self):
   152         """method called by the repository once ready to handle request"""
   152         """method called by the repository once ready to handle request"""
   153         self.repo.looping_task(self._interval, self.synchronize) 
   153         self.repo.looping_task(self._interval, self.synchronize)
   154         self.repo.looping_task(self._query_cache.ttl.seconds/10, self._query_cache.clear_expired) 
   154         self.repo.looping_task(self._query_cache.ttl.seconds/10, self._query_cache.clear_expired)
   155 
   155 
   156     def synchronize(self):
   156     def synchronize(self):
   157         """synchronize content known by this repository with content in the
   157         """synchronize content known by this repository with content in the
   158         external repository
   158         external repository
   159         """
   159         """
   167             cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
   167             cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
   168                                         "source='%s'" % self.uri)
   168                                         "source='%s'" % self.uri)
   169             for eid, extid in cursor.fetchall():
   169             for eid, extid in cursor.fetchall():
   170                 # if no result found, _search automatically delete entity information
   170                 # if no result found, _search automatically delete entity information
   171                 res = self._search(session, extid, BASE)
   171                 res = self._search(session, extid, BASE)
   172                 if res: 
   172                 if res:
   173                     ldapemailaddr = res[0].get(ldap_emailattr)
   173                     ldapemailaddr = res[0].get(ldap_emailattr)
   174                     if ldapemailaddr:
   174                     if ldapemailaddr:
   175                         rset = session.execute('EmailAddress X,A WHERE '
   175                         rset = session.execute('EmailAddress X,A WHERE '
   176                                                'U use_email X, U eid %(u)s',
   176                                                'U use_email X, U eid %(u)s',
   177                                                {'u': eid})
   177                                                {'u': eid})
   190                                 # no email found, create it
   190                                 # no email found, create it
   191                                 _insert_email(session, ldapemailaddr, eid)
   191                                 _insert_email(session, ldapemailaddr, eid)
   192         finally:
   192         finally:
   193             session.commit()
   193             session.commit()
   194             session.close()
   194             session.close()
   195             
   195 
   196     def get_connection(self):
   196     def get_connection(self):
   197         """open and return a connection to the source"""
   197         """open and return a connection to the source"""
   198         if self._conn is None:
   198         if self._conn is None:
   199             self._connect()
   199             self._connect()
   200         return ConnectionWrapper(self._conn)
   200         return ConnectionWrapper(self._conn)
   201     
   201 
   202     def authenticate(self, session, login, password):
   202     def authenticate(self, session, login, password):
   203         """return CWUser eid for the given login/password if this account is
   203         """return CWUser eid for the given login/password if this account is
   204         defined in this source, else raise `AuthenticationError`
   204         defined in this source, else raise `AuthenticationError`
   205 
   205 
   206         two queries are needed since passwords are stored crypted, so we have
   206         two queries are needed since passwords are stored crypted, so we have
   229     def ldap_name(self, var):
   229     def ldap_name(self, var):
   230         if var.stinfo['relations']:
   230         if var.stinfo['relations']:
   231             relname = iter(var.stinfo['relations']).next().r_type
   231             relname = iter(var.stinfo['relations']).next().r_type
   232             return self.user_rev_attrs.get(relname)
   232             return self.user_rev_attrs.get(relname)
   233         return None
   233         return None
   234         
   234 
   235     def prepare_columns(self, mainvars, rqlst):
   235     def prepare_columns(self, mainvars, rqlst):
   236         """return two list describin how to build the final results
   236         """return two list describin how to build the final results
   237         from the result of an ldap search (ie a list of dictionnary)
   237         from the result of an ldap search (ie a list of dictionnary)
   238         """
   238         """
   239         columns = []
   239         columns = []
   268             columns.append((mainvar, self.ldap_name(var)))
   268             columns.append((mainvar, self.ldap_name(var)))
   269             #else:
   269             #else:
   270             #    # probably a bug in rql splitting if we arrive here
   270             #    # probably a bug in rql splitting if we arrive here
   271             #    raise NotImplementedError
   271             #    raise NotImplementedError
   272         return columns, global_transforms
   272         return columns, global_transforms
   273     
   273 
   274     def syntax_tree_search(self, session, union,
   274     def syntax_tree_search(self, session, union,
   275                            args=None, cachekey=None, varmap=None, debug=0):
   275                            args=None, cachekey=None, varmap=None, debug=0):
   276         """return result from this source for a rql query (actually from a rql 
   276         """return result from this source for a rql query (actually from a rql
   277         syntax tree and a solution dictionary mapping each used variable to a 
   277         syntax tree and a solution dictionary mapping each used variable to a
   278         possible type). If cachekey is given, the query necessary to fetch the
   278         possible type). If cachekey is given, the query necessary to fetch the
   279         results (but not the results themselves) may be cached using this key.
   279         results (but not the results themselves) may be cached using this key.
   280         """
   280         """
   281         # XXX not handled : transform/aggregat function, join on multiple users...
   281         # XXX not handled : transform/aggregat function, join on multiple users...
   282         assert len(union.children) == 1, 'union not supported'
   282         assert len(union.children) == 1, 'union not supported'
   359             result.append(line)
   359             result.append(line)
   360         for trfunc in globtransforms:
   360         for trfunc in globtransforms:
   361             result = trfunc.apply(result)
   361             result = trfunc.apply(result)
   362         #print '--> ldap result', result
   362         #print '--> ldap result', result
   363         return result
   363         return result
   364                 
   364 
   365     
   365 
   366     def _connect(self, userdn=None, userpwd=None):
   366     def _connect(self, userdn=None, userpwd=None):
   367         port, protocol = MODES[self.cnx_mode]
   367         port, protocol = MODES[self.cnx_mode]
   368         if protocol == 'ldapi':
   368         if protocol == 'ldapi':
   369             hostport = self.host
   369             hostport = self.host
   370         else:
   370         else:
   442             rec_dict['dn'] = rec_dn
   442             rec_dict['dn'] = rec_dn
   443             self._cache[rec_dn] = rec_dict
   443             self._cache[rec_dn] = rec_dict
   444             result.append(rec_dict)
   444             result.append(rec_dict)
   445         #print '--->', result
   445         #print '--->', result
   446         return result
   446         return result
   447     
   447 
   448     def before_entity_insertion(self, session, lid, etype, eid):
   448     def before_entity_insertion(self, session, lid, etype, eid):
   449         """called by the repository when an eid has been attributed for an
   449         """called by the repository when an eid has been attributed for an
   450         entity stored here but the entity has not been inserted in the system
   450         entity stored here but the entity has not been inserted in the system
   451         table yet.
   451         table yet.
   452         
   452 
   453         This method must return the an Entity instance representation of this
   453         This method must return the an Entity instance representation of this
   454         entity.
   454         entity.
   455         """
   455         """
   456         entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid)
   456         entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid)
   457         res = self._search(session, lid, BASE)[0]
   457         res = self._search(session, lid, BASE)[0]
   458         for attr in entity.e_schema.indexable_attributes():
   458         for attr in entity.e_schema.indexable_attributes():
   459             entity[attr] = res[self.user_rev_attrs[attr]]
   459             entity[attr] = res[self.user_rev_attrs[attr]]
   460         return entity
   460         return entity
   461     
   461 
   462     def after_entity_insertion(self, session, dn, entity):
   462     def after_entity_insertion(self, session, dn, entity):
   463         """called by the repository after an entity stored here has been
   463         """called by the repository after an entity stored here has been
   464         inserted in the system table.
   464         inserted in the system table.
   465         """
   465         """
   466         super(LDAPUserSource, self).after_entity_insertion(session, dn, entity)
   466         super(LDAPUserSource, self).after_entity_insertion(session, dn, entity)
   490         raise RepositoryError('this source is read only')
   490         raise RepositoryError('this source is read only')
   491 
   491 
   492 def _insert_email(session, emailaddr, ueid):
   492 def _insert_email(session, emailaddr, ueid):
   493     session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
   493     session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
   494                     'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x')
   494                     'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x')
   495     
   495 
   496 class GotDN(Exception):
   496 class GotDN(Exception):
   497     """exception used when a dn localizing the searched user has been found"""
   497     """exception used when a dn localizing the searched user has been found"""
   498     def __init__(self, dn):
   498     def __init__(self, dn):
   499         self.dn = dn
   499         self.dn = dn
   500 
   500 
   501         
   501 
   502 class RQL2LDAPFilter(object):
   502 class RQL2LDAPFilter(object):
   503     """generate an LDAP filter for a rql query"""
   503     """generate an LDAP filter for a rql query"""
   504     def __init__(self, source, session, args=None, mainvars=()):
   504     def __init__(self, source, session, args=None, mainvars=()):
   505         self.source = source
   505         self.source = source
   506         self._ldap_attrs = source.user_rev_attrs
   506         self._ldap_attrs = source.user_rev_attrs
   508         self._session = session
   508         self._session = session
   509         if args is None:
   509         if args is None:
   510             args = {}
   510             args = {}
   511         self._args = args
   511         self._args = args
   512         self.mainvars = mainvars
   512         self.mainvars = mainvars
   513         
   513 
   514     def generate(self, selection, mainvarname):
   514     def generate(self, selection, mainvarname):
   515         self._filters = res = self._base_filters[:]
   515         self._filters = res = self._base_filters[:]
   516         self._mainvarname = mainvarname
   516         self._mainvarname = mainvarname
   517         self._eidfilters = []
   517         self._eidfilters = []
   518         self._done_not = set()
   518         self._done_not = set()
   525         elif restriction:
   525         elif restriction:
   526             restriction.accept(self)
   526             restriction.accept(self)
   527         if len(res) > 1:
   527         if len(res) > 1:
   528             return self._eidfilters, '(&%s)' % ''.join(res)
   528             return self._eidfilters, '(&%s)' % ''.join(res)
   529         return self._eidfilters, res[0]
   529         return self._eidfilters, res[0]
   530     
   530 
   531     def visit_and(self, et):
   531     def visit_and(self, et):
   532         """generate filter for a AND subtree"""
   532         """generate filter for a AND subtree"""
   533         for c in et.children:
   533         for c in et.children:
   534             part = c.accept(self)
   534             part = c.accept(self)
   535             if part:
   535             if part:
   585                 raise NotImplementedError(relation)
   585                 raise NotImplementedError(relation)
   586         # regular relation XXX todo: in_group
   586         # regular relation XXX todo: in_group
   587         else:
   587         else:
   588             raise NotImplementedError(relation)
   588             raise NotImplementedError(relation)
   589         return res
   589         return res
   590         
   590 
   591     def _visit_attribute_relation(self, relation):
   591     def _visit_attribute_relation(self, relation):
   592         """generate filter for an attribute relation"""
   592         """generate filter for an attribute relation"""
   593         lhs, rhs = relation.get_parts()
   593         lhs, rhs = relation.get_parts()
   594         lhsvar = lhs.variable
   594         lhsvar = lhs.variable
   595         if relation.r_type == 'eid':
   595         if relation.r_type == 'eid':
   621             raise UnknownEid # trick to return no result
   621             raise UnknownEid # trick to return no result
   622         return filter
   622         return filter
   623 
   623 
   624     def visit_comparison(self, cmp):
   624     def visit_comparison(self, cmp):
   625         """generate filter for a comparaison"""
   625         """generate filter for a comparaison"""
   626         return '%s%s'% (cmp.operator, cmp.children[0].accept(self))            
   626         return '%s%s'% (cmp.operator, cmp.children[0].accept(self))
   627 
   627 
   628     def visit_mathexpression(self, mexpr):
   628     def visit_mathexpression(self, mexpr):
   629         """generate filter for a mathematic expression"""
   629         """generate filter for a mathematic expression"""
   630         raise NotImplementedError
   630         raise NotImplementedError
   631         
   631 
   632     def visit_function(self, function):
   632     def visit_function(self, function):
   633         """generate filter name for a function"""
   633         """generate filter name for a function"""
   634         if function.name == 'IN':
   634         if function.name == 'IN':
   635             return self.visit_in(function)
   635             return self.visit_in(function)
   636         raise NotImplementedError
   636         raise NotImplementedError
   637         
   637 
   638     def visit_in(self, function):
   638     def visit_in(self, function):
   639         grandpapa = function.parent.parent
   639         grandpapa = function.parent.parent
   640         ldapattr = self._ldap_attrs[grandpapa.r_type]
   640         ldapattr = self._ldap_attrs[grandpapa.r_type]
   641         res = []
   641         res = []
   642         for c in function.children:
   642         for c in function.children:
   647             if len(res) > 1:
   647             if len(res) > 1:
   648                 part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
   648                 part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
   649             else:
   649             else:
   650                 part = '(%s=%s)' % (ldapattr, res[0])
   650                 part = '(%s=%s)' % (ldapattr, res[0])
   651         return part
   651         return part
   652         
   652 
   653     def visit_constant(self, constant):
   653     def visit_constant(self, constant):
   654         """generate filter name for a constant"""
   654         """generate filter name for a constant"""
   655         value = constant.value
   655         value = constant.value
   656         if constant.type is None:
   656         if constant.type is None:
   657             raise NotImplementedError
   657             raise NotImplementedError
   665         if isinstance(value, unicode):
   665         if isinstance(value, unicode):
   666             value = value.encode('utf8')
   666             value = value.encode('utf8')
   667         else:
   667         else:
   668             value = str(value)
   668             value = str(value)
   669         return escape_filter_chars(value)
   669         return escape_filter_chars(value)
   670         
   670 
   671     def visit_variableref(self, variableref):
   671     def visit_variableref(self, variableref):
   672         """get the sql name for a variable reference"""
   672         """get the sql name for a variable reference"""
   673         pass
   673         pass
   674 
   674