server/sources/ldapfeed.py
changeset 10766 d730f91251af
parent 10666 7f6b5f023884
child 10768 99689a5862ea
--- a/server/sources/ldapfeed.py	Fri Oct 02 14:13:26 2015 +0200
+++ b/server/sources/ldapfeed.py	Mon Oct 05 14:51:24 2015 +0200
@@ -17,16 +17,13 @@
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
 """cubicweb ldap feed source"""
 
-from __future__ import division # XXX why?
+from __future__ import division  # XXX why?
 
 from datetime import datetime
 
 from six import string_types
 
-import ldap
-from ldap.ldapobject import ReconnectLDAPObject
-from ldap.filter import filter_format
-from ldapurl import LDAPUrl
+import ldap3
 
 from logilab.common.configuration import merge_options
 
@@ -37,12 +34,12 @@
 from cubicweb import _
 
 # search scopes
-BASE = ldap.SCOPE_BASE
-ONELEVEL = ldap.SCOPE_ONELEVEL
-SUBTREE = ldap.SCOPE_SUBTREE
-LDAP_SCOPES = {'BASE': ldap.SCOPE_BASE,
-               'ONELEVEL': ldap.SCOPE_ONELEVEL,
-               'SUBTREE': ldap.SCOPE_SUBTREE}
+BASE = ldap3.SEARCH_SCOPE_BASE_OBJECT
+ONELEVEL = ldap3.SEARCH_SCOPE_SINGLE_LEVEL
+SUBTREE = ldap3.SEARCH_SCOPE_WHOLE_SUBTREE
+LDAP_SCOPES = {'BASE': BASE,
+               'ONELEVEL': ONELEVEL,
+               'SUBTREE': SUBTREE}
 
 # map ldap protocol to their standard port
 PROTO_PORT = {'ldap': 389,
@@ -51,6 +48,15 @@
               }
 
 
+def replace_filter(s):
+    s = s.replace('*', '\\2A')
+    s = s.replace('(', '\\28')
+    s = s.replace(')', '\\29')
+    s = s.replace('\\', '\\5c')
+    s = s.replace('\0', '\\00')
+    return s
+
+
 class LDAPFeedSource(datafeed.DataFeedSource):
     """LDAP feed source: unlike ldapuser source, this source is copy based and
     will import ldap content (beside passwords for authentication) into the
@@ -63,7 +69,7 @@
         ('auth-mode',
          {'type' : 'choice',
           'default': 'simple',
-          'choices': ('simple', 'cram_md5', 'digest_md5', 'gssapi'),
+          'choices': ('simple', 'digest_md5', 'gssapi'),
           'help': 'authentication mode used to authenticate user to the ldap.',
           'group': 'ldap-source', 'level': 3,
           }),
@@ -186,7 +192,7 @@
         self.user_attrs = {'dn': 'eid', 'modifyTimestamp': 'modification_date'}
         self.user_attrs.update(typedconfig['user-attrs-map'])
         self.user_rev_attrs = dict((v, k) for k, v in self.user_attrs.items())
-        self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
+        self.base_filters = ['(objectclass=%s)' % replace_filter(o)
                              for o in typedconfig['user-classes']]
         if typedconfig['user-filter']:
             self.base_filters.append(typedconfig['user-filter'])
@@ -196,7 +202,7 @@
         self.group_attrs = {'dn': 'eid', 'modifyTimestamp': 'modification_date'}
         self.group_attrs.update(typedconfig['group-attrs-map'])
         self.group_rev_attrs = dict((v, k) for k, v in self.group_attrs.items())
-        self.group_base_filters = [filter_format('(%s=%s)', ('objectClass', o))
+        self.group_base_filters = ['(objectClass=%s)' % replace_filter(o)
                                    for o in typedconfig['group-classes']]
         if typedconfig['group-filter']:
             self.group_base_filters.append(typedconfig['group-filter'])
@@ -217,9 +223,11 @@
     def connection_info(self):
         assert len(self.urls) == 1, self.urls
         protocol, hostport = self.urls[0].split('://')
-        if protocol != 'ldapi' and not ':' in hostport:
-            hostport = '%s:%s' % (hostport, PROTO_PORT[protocol])
-        return protocol, hostport
+        if protocol != 'ldapi' and ':' in hostport:
+            host, port = hostport.rsplit(':', 1)
+        else:
+            host, port = hostport, PROTO_PORT[protocol]
+        return protocol, host, port
 
     def authenticate(self, cnx, login, password=None, **kwargs):
         """return CWUser eid for the given login/password if this account is
@@ -234,20 +242,20 @@
             # You get Authenticated as: 'NT AUTHORITY\ANONYMOUS LOGON'.
             # we really really don't want that
             raise AuthenticationError()
-        searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
+        searchfilter = ['(%s=%s)' % (replace_filter(self.user_login_attr), replace_filter(login))]
         searchfilter.extend(self.base_filters)
         searchstr = '(&%s)' % ''.join(searchfilter)
         # first search the user
         try:
             user = self._search(cnx, self.user_base_dn,
                                 self.user_base_scope, searchstr)[0]
-        except (IndexError, ldap.SERVER_DOWN):
+        except IndexError:
             # no such user
             raise AuthenticationError()
         # check password by establishing a (unused) connection
         try:
             self._connect(user, password)
-        except ldap.LDAPError as ex:
+        except ldap3.LDAPException as ex:
             # Something went wrong, most likely bad credentials
             self.info('while trying to authenticate %s: %s', user, ex)
             raise AuthenticationError()
@@ -261,32 +269,16 @@
         return eid
 
     def _connect(self, user=None, userpwd=None):
-        protocol, hostport = self.connection_info()
-        self.info('connecting %s://%s as %s', protocol, hostport,
+        protocol, host, port = self.connection_info()
+        self.info('connecting %s://%s:%s as %s', protocol, host, port,
                   user and user['dn'] or 'anonymous')
-        # don't require server certificate when using ldaps (will
-        # enable self signed certs)
-        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
-        url = LDAPUrl(urlscheme=protocol, hostport=hostport)
-        conn = ReconnectLDAPObject(url.initializeUrl())
-        # Set the protocol version - version 3 is preferred
-        try:
-            conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
-        except ldap.LDAPError: # Invalid protocol version, fall back safely
-            conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2)
-        # Deny auto-chasing of referrals to be safe, we handle them instead
-        # Required for AD
-        try:
-            conn.set_option(ldap.OPT_REFERRALS, 0)
-        except ldap.LDAPError: # Cannot set referrals, so do nothing
-            pass
-        #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout)
-        #conn.timeout = op_timeout
+        server = ldap3.Server(host, port=int(port))
+        conn = ldap3.Connection(server, user=user and user['dn'], client_strategy=ldap3.STRATEGY_SYNC_RESTARTABLE, auto_referrals=False)
         # Now bind with the credentials given. Let exceptions propagate out.
         if user is None:
             # XXX always use simple bind for data connection
             if not self.cnx_dn:
-                conn.simple_bind_s(self.cnx_dn, self.cnx_pwd)
+                conn.bind()
             else:
                 self._authenticate(conn, {'dn': self.cnx_dn}, self.cnx_pwd)
         else:
@@ -296,25 +288,22 @@
         return conn
 
     def _auth_simple(self, conn, user, userpwd):
-        conn.simple_bind_s(user['dn'], userpwd)
-
-    def _auth_cram_md5(self, conn, user, userpwd):
-        from ldap import sasl
-        auth_token = sasl.cram_md5(user['dn'], userpwd)
-        conn.sasl_interactive_bind_s('', auth_token)
+        conn.authentication = ldap3.AUTH_SIMPLE
+        conn.user = user['dn']
+        conn.password = userpwd
+        conn.bind()
 
     def _auth_digest_md5(self, conn, user, userpwd):
-        from ldap import sasl
-        auth_token = sasl.digest_md5(user['dn'], userpwd)
-        conn.sasl_interactive_bind_s('', auth_token)
+        conn.authentication = ldap3.AUTH_SASL
+        conn.sasl_mechanism = 'DIGEST-MD5'
+        # realm, user, password, authz-id
+        conn.sasl_credentials = (None, user['dn'], userpwd, None)
+        conn.bind()
 
     def _auth_gssapi(self, conn, user, userpwd):
-        # print XXX not proper sasl/gssapi
-        import kerberos
-        if not kerberos.checkPassword(user[self.user_login_attr], userpwd):
-            raise Exception('BAD login / mdp')
-        #from ldap import sasl
-        #conn.sasl_interactive_bind_s('', sasl.gssapi())
+        conn.authentication = ldap3.AUTH_SASL
+        conn.sasl_mechanism = 'GSSAPI'
+        conn.bind()
 
     def _search(self, cnx, base, scope,
                 searchstr='(objectClass=*)', attrs=()):
@@ -324,37 +313,15 @@
         if self._conn is None:
             self._conn = self._connect()
         ldapcnx = self._conn
-        try:
-            res = ldapcnx.search_s(base, scope, searchstr, attrs)
-        except ldap.PARTIAL_RESULTS:
-            res = ldapcnx.result(all=0)[1]
-        except ldap.NO_SUCH_OBJECT:
-            self.info('ldap NO SUCH OBJECT %s %s %s', base, scope, searchstr)
-            self._process_no_such_object(cnx, base)
+        if not ldapcnx.search(base, searchstr, search_scope=scope, attributes=attrs):
             return []
-        # except ldap.REFERRAL as e:
-        #     ldapcnx = self.handle_referral(e)
-        #     try:
-        #         res = ldapcnx.search_s(base, scope, searchstr, attrs)
-        #     except ldap.PARTIAL_RESULTS:
-        #         res_type, res = ldapcnx.result(all=0)
         result = []
-        for rec_dn, rec_dict in res:
-            # When used against Active Directory, "rec_dict" may not be
-            # be a dictionary in some cases (instead, it can be a list)
-            #
-            # An example of a useless "res" entry that can be ignored
-            # from AD is
-            # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL'])
-            # This appears to be some sort of internal referral, but
-            # we can't handle it, so we need to skip over it.
-            try:
-                items = rec_dict.items()
-            except AttributeError:
+        for rec in ldapcnx.response:
+            if rec['type'] != 'searchResEntry':
                 continue
-            else:
-                itemdict = self._process_ldap_item(rec_dn, items)
-                result.append(itemdict)
+            items = rec['attributes'].items()
+            itemdict = self._process_ldap_item(rec['dn'], items)
+            result.append(itemdict)
         self.debug('ldap built results %s', len(result))
         return result