[sources] refactor source creation and options handling
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 07 Feb 2011 18:19:39 +0100
changeset 6945 28bf94d062a9
parent 6944 0cf10429ad39
child 6946 e350771c23a3
[sources] refactor source creation and options handling

* options validation
* ease proper update of source's config on configuration change
entities/sources.py
hooks/syncsources.py
misc/migration/3.10.0_Any.py
server/repository.py
server/sources/__init__.py
server/sources/ldapuser.py
server/sources/native.py
server/sources/pyrorql.py
server/test/unittest_ldapuser.py
--- a/entities/sources.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/entities/sources.py	Mon Feb 07 18:19:39 2011 +0100
@@ -97,11 +97,22 @@
                                        cw_schema=schemaentity,
                                        options=options)
 
+    @property
+    def repo_source(self):
+        """repository only property, not available from the web side (eg
+        self._cw is expected to be a server session)
+        """
+        return self._cw.repo.sources_by_eid[self.eid]
+
 
 class CWSourceHostConfig(_CWSourceCfgMixIn, AnyEntity):
     __regid__ = 'CWSourceHostConfig'
     fetch_attrs, fetch_order = fetch_config(['match_host', 'config'])
 
+    @property
+    def cwsource(self):
+        return self.cw_host_config_of[0]
+
     def match(self, hostname):
         return re.match(self.match_host, hostname)
 
--- a/hooks/syncsources.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/hooks/syncsources.py	Mon Feb 07 18:19:39 2011 +0100
@@ -17,10 +17,14 @@
     __select__ = SourceHook.__select__ & is_instance('CWSource')
     events = ('after_add_entity',)
     def __call__(self):
-        if not self.entity.type in SOURCE_TYPES:
+        try:
+            sourcecls = SOURCE_TYPES[self.entity.type]
+        except KeyError:
             msg = self._cw._('unknown source type')
             raise ValidationError(self.entity.eid,
                                   {role_name('type', 'subject'): msg})
+        sourcecls.check_conf_dict(self.entity.eid, self.entity.host_config,
+                                  fail_if_unknown=not self._cw.vreg.config.repairing)
         SourceAddedOp(self._cw, entity=self.entity)
 
 
@@ -37,6 +41,39 @@
             raise ValidationError(self.entity.eid, {None: 'cant remove system source'})
         SourceRemovedOp(self._cw, uri=self.entity.name)
 
+
+class SourceUpdatedOp(hook.DataOperationMixIn, hook.Operation):
+
+    def precommit_event(self):
+        self.__processed = []
+        for source in self.get_data():
+            conf = source.repo_source.check_config(source)
+            self.__processed.append( (source, conf) )
+
+    def postcommit_event(self):
+        for source, conf in self.__processed:
+            source.repo_source.update_config(source, conf)
+
+class SourceUpdatedHook(SourceHook):
+    __regid__ = 'cw.sources.configupdate'
+    __select__ = SourceHook.__select__ & is_instance('CWSource')
+    events = ('after_update_entity',)
+    def __call__(self):
+        if 'config' in self.entity.cw_edited:
+            SourceUpdatedOp.get_instance(self._cw).add_data(self.entity)
+
+class SourceHostConfigUpdatedHook(SourceHook):
+    __regid__ = 'cw.sources.hostconfigupdate'
+    __select__ = SourceHook.__select__ & is_instance('CWSourceHostConfig')
+    events = ('after_add_entity', 'after_update_entity', 'before_delete_entity',)
+    def __call__(self):
+        try:
+            SourceUpdatedOp.get_instance(self._cw).add_data(self.entity.cwsource)
+        except IndexError:
+            # XXX no source linked to the host config yet
+            pass
+
+
 # source mapping synchronization. Expect cw_for_source/cw_schema are immutable
 # relations (i.e. can't change from a source or schema to another).
 
--- a/misc/migration/3.10.0_Any.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/misc/migration/3.10.0_Any.py	Mon Feb 07 18:19:39 2011 +0100
@@ -5,7 +5,7 @@
 for uri, cfg in config.sources().items():
     if uri in ('system', 'admin'):
         continue
-    repo.sources_by_uri[uri] = repo.get_source(cfg['adapter'], uri, cfg)
+    repo.sources_by_uri[uri] = repo.get_source(cfg['adapter'], uri, cfg.copy())
 
 add_entity_type('CWSource')
 add_relation_definition('CWSource', 'cw_source', 'CWSource')
--- a/server/repository.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/repository.py	Mon Feb 07 18:19:39 2011 +0100
@@ -258,8 +258,7 @@
 
     def add_source(self, sourceent, add_to_pools=True):
         source = self.get_source(sourceent.type, sourceent.name,
-                                 sourceent.host_config)
-        source.eid = sourceent.eid
+                                 sourceent.host_config, sourceent.eid)
         self.sources_by_eid[sourceent.eid] = source
         self.sources_by_uri[sourceent.name] = source
         if self.config.source_enabled(source):
@@ -283,12 +282,12 @@
                 pool.remove_source(source)
         self._clear_planning_caches()
 
-    def get_source(self, type, uri, source_config):
+    def get_source(self, type, uri, source_config, eid=None):
         # set uri and type in source config so it's available through
         # source_defs()
         source_config['uri'] = uri
         source_config['type'] = type
-        return sources.get_source(type, source_config, self)
+        return sources.get_source(type, source_config, self, eid)
 
     def set_schema(self, schema, resetvreg=True, rebuildinfered=True):
         if rebuildinfered:
--- a/server/sources/__init__.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/sources/__init__.py	Mon Feb 07 18:19:39 2011 +0100
@@ -24,7 +24,11 @@
 from datetime import datetime, timedelta
 from logging import getLogger
 
-from cubicweb import set_log_methods, server
+from logilab.common import configuration
+
+from yams.schema import role_name
+
+from cubicweb import ValidationError, set_log_methods, server
 from cubicweb.schema import VIRTUAL_RTYPES
 from cubicweb.server.sqlutils import SQL_PREFIX
 from cubicweb.server.ssplanner import EditedEntity
@@ -103,15 +107,19 @@
     # force deactivation (configuration error for instance)
     disabled = False
 
-    def __init__(self, repo, source_config, *args, **kwargs):
+    # source configuration options
+    options = ()
+
+    def __init__(self, repo, source_config, eid=None):
         self.repo = repo
-        self.uri = source_config['uri']
-        set_log_methods(self, getLogger('cubicweb.sources.'+self.uri))
         self.set_schema(repo.schema)
         self.support_relations['identity'] = False
-        self.eid = None
+        self.eid = eid
         self.public_config = source_config.copy()
         self.remove_sensitive_information(self.public_config)
+        self.uri = source_config.pop('uri')
+        set_log_methods(self, getLogger('cubicweb.sources.'+self.uri))
+        source_config.pop('type')
 
     def __repr__(self):
         return '<%s source %s @%#x>' % (self.uri, self.eid, id(self))
@@ -136,6 +144,56 @@
         """method called to restore a backup of source's data"""
         pass
 
+    @classmethod
+    def check_conf_dict(cls, eid, confdict, _=unicode, fail_if_unknown=True):
+        """check configuration of source entity. Return config dict properly
+        typed with defaults set.
+        """
+        processed = {}
+        for optname, optdict in cls.options:
+            value = confdict.pop(optname, optdict.get('default'))
+            if value is configuration.REQUIRED:
+                if not fail_if_unknown:
+                    continue
+                msg = _('specifying %s is mandatory' % optname)
+                raise ValidationError(eid, {role_name('config', 'subject'): msg})
+            elif value is not None:
+                # type check
+                try:
+                    value = configuration.convert(value, optdict, optname)
+                except Exception, ex:
+                    msg = unicode(ex) # XXX internationalization
+                    raise ValidationError(eid, {role_name('config', 'subject'): msg})
+            processed[optname] = value
+        # cw < 3.10 bw compat
+        try:
+            processed['adapter'] = confdict['adapter']
+        except:
+            pass
+        # check for unknown options
+        if confdict and not confdict.keys() == ['adapter']:
+            if fail_if_unknown:
+                msg = _('unknown options %s') % ', '.join(confdict)
+                raise ValidationError(eid, {role_name('config', 'subject'): msg})
+            else:
+                logger = getLogger('cubicweb.sources')
+                logger.warning('unknown options %s', ', '.join(confdict))
+                # add options to processed, they may be necessary during migration
+                processed.update(confdict)
+        return processed
+
+    @classmethod
+    def check_config(cls, source_entity):
+        """check configuration of source entity"""
+        return cls.check_conf_dict(source_entity.eid, source_entity.host_config,
+                                    _=source_entity._cw._)
+
+    def update_config(self, source_entity, typedconfig):
+        """update configuration from source entity. `typedconfig` is config
+        properly typed with defaults set
+        """
+        pass
+
     # source initialization / finalization #####################################
 
     def set_schema(self, schema):
@@ -503,8 +561,8 @@
     except KeyError:
         raise RuntimeError('Unknown source type %r' % source_type)
 
-def get_source(type, source_config, repo):
-    """return a source adapter according to the adapter field in the
-    source's configuration
+def get_source(type, source_config, repo, eid):
+    """return a source adapter according to the adapter field in the source's
+    configuration
     """
-    return source_adapter(type)(repo, source_config)
+    return source_adapter(type)(repo, source_config, eid)
--- a/server/sources/ldapuser.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/sources/ldapuser.py	Mon Feb 07 18:19:39 2011 +0100
@@ -34,15 +34,13 @@
 from __future__ import division
 from base64 import b64decode
 
-from logilab.common.textutils import splitstrip
-from rql.nodes import Relation, VariableRef, Constant, Function
-
 import ldap
 from ldap.ldapobject import ReconnectLDAPObject
 from ldap.filter import filter_format, escape_filter_chars
 from ldapurl import LDAPUrl
 
-from logilab.common.configuration import time_validator
+from rql.nodes import Relation, VariableRef, Constant, Function
+
 from cubicweb import AuthenticationError, UnknownEid, RepositoryError
 from cubicweb.server.utils import cartesian_product
 from cubicweb.server.sources import (AbstractSource, TrFunc, GlobTrFunc,
@@ -168,44 +166,36 @@
 
     )
 
-    def __init__(self, repo, source_config, *args, **kwargs):
-        AbstractSource.__init__(self, repo, source_config, *args, **kwargs)
-        self.host = source_config['host']
-        self.protocol = source_config.get('protocol', 'ldap')
-        self.authmode = source_config.get('auth-mode', 'simple')
+    def __init__(self, repo, source_config, eid=None):
+        AbstractSource.__init__(self, repo, source_config, eid)
+        self.update_config(None, self.check_conf_dict(eid, source_config))
+        self._conn = None
+
+    def update_config(self, source_entity, typedconfig):
+        """update configuration from source entity. `typedconfig` is config
+        properly typed with defaults set
+        """
+        self.host = typedconfig['host']
+        self.protocol = typedconfig['protocol']
+        self.authmode = typedconfig['auth-mode']
         self._authenticate = getattr(self, '_auth_%s' % self.authmode)
-        self.cnx_dn = source_config.get('data-cnx-dn') or ''
-        self.cnx_pwd = source_config.get('data-cnx-password') or ''
-        self.user_base_scope = globals()[source_config['user-scope']]
-        self.user_base_dn = str(source_config['user-base-dn'])
-        self.user_base_scope = globals()[source_config['user-scope']]
-        self.user_classes = splitstrip(source_config['user-classes'])
-        self.user_login_attr = source_config['user-login-attr']
-        self.user_default_groups = splitstrip(source_config['user-default-group'])
-        self.user_attrs = dict(v.split(':', 1) for v in splitstrip(source_config['user-attrs-map']))
-        self.user_filter = source_config.get('user-filter')
+        self.cnx_dn = typedconfig['data-cnx-dn']
+        self.cnx_pwd = typedconfig['data-cnx-password']
+        self.user_base_dn = str(typedconfig['user-base-dn'])
+        self.user_base_scope = globals()[typedconfig['user-scope']]
+        self.user_login_attr = typedconfig['user-login-attr']
+        self.user_default_groups = typedconfig['user-default-group']
+        self.user_attrs = typedconfig['user-attrs-map']
         self.user_rev_attrs = {'eid': 'dn'}
         for ldapattr, cwattr in self.user_attrs.items():
             self.user_rev_attrs[cwattr] = ldapattr
-        self.base_filters = self._make_base_filters()
-        self._conn = None
-        self._cache = {}
-        # ttlm is in minutes!
-        self._cache_ttl = time_validator(None, None,
-                              source_config.get('cache-life-time', 2*60*60))
-        self._cache_ttl = max(71, self._cache_ttl)
-        self._query_cache = TimedCache(self._cache_ttl)
-        # interval is in seconds !
-        self._interval = time_validator(None, None,
-                                    source_config.get('synchronization-interval',
-                                                      24*60*60))
-
-    def _make_base_filters(self):
-        filters =  [filter_format('(%s=%s)', ('objectClass', o))
-                              for o in self.user_classes] 
-        if self.user_filter:
-            filters += [self.user_filter]
-        return filters
+        self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
+                             for o in typedconfig['user-classes']]
+        if typedconfig['user-filter']:
+            self.base_filters.append(typedconfig['user-filter'])
+        self._interval = typedconfig['synchronization-interval']
+        self._cache_ttl = max(71, typedconfig['cache-life-time'])
+        self.reset_caches()
 
     def reset_caches(self):
         """method called during test to reset potential source caches"""
@@ -300,7 +290,7 @@
             # we really really don't want that
             raise AuthenticationError()
         searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
-        searchfilter.extend(self._make_base_filters())
+        searchfilter.extend(self.base_filters)
         searchstr = '(&%s)' % ''.join(searchfilter)
         # first search the user
         try:
--- a/server/sources/native.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/sources/native.py	Mon Feb 07 18:19:39 2011 +0100
@@ -45,8 +45,10 @@
 from logilab.database import get_db_helper
 
 from yams import schema2sql as y2sql
+from yams.schema import role_name
 
-from cubicweb import UnknownEid, AuthenticationError, ValidationError, Binary, UniqueTogetherError
+from cubicweb import (UnknownEid, AuthenticationError, ValidationError, Binary,
+                      UniqueTogetherError)
 from cubicweb import transaction as tx, server, neg_role
 from cubicweb.schema import VIRTUAL_RTYPES
 from cubicweb.cwconfig import CubicWebNoAppConfiguration
@@ -310,6 +312,13 @@
         #      consuming, find another way
         return SQLAdapterMixIn.get_connection(self)
 
+    def check_config(self, source_entity):
+        """check configuration of source entity"""
+        if source_entity.host_config:
+            msg = source_entity._cw._('the system source has its configuration '
+                                      'stored on the file-system')
+            raise ValidationError(source_entity.eid, {role_name('config', 'subject'): msg})
+
     def add_authentifier(self, authentifier):
         self.authentifiers.append(authentifier)
         authentifier.source = self
--- a/server/sources/pyrorql.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/sources/pyrorql.py	Mon Feb 07 18:19:39 2011 +0100
@@ -123,13 +123,10 @@
     PUBLIC_KEYS = AbstractSource.PUBLIC_KEYS + ('base-url',)
     _conn = None
 
-    def __init__(self, repo, source_config, *args, **kwargs):
-        AbstractSource.__init__(self, repo, source_config, *args, **kwargs)
-        # XXX get it through pyro if unset
-        baseurl = source_config.get('base-url')
-        if baseurl and not baseurl.endswith('/'):
-            source_config['base-url'] += '/'
-        self.config = source_config
+    def __init__(self, repo, source_config, eid=None):
+        AbstractSource.__init__(self, repo, source_config, eid)
+        self.update_config(None, self.check_conf_dict(eid, source_config,
+                                                      fail_if_unknown=False))
         myoptions = (('%s.latest-update-time' % self.uri,
                       {'type' : 'int', 'sitewide': True,
                        'default': 0,
@@ -138,8 +135,15 @@
                        }),)
         register_persistent_options(myoptions)
         self._query_cache = TimedCache(1800)
-        self._skip_externals = check_yn(None, 'skip-external-entities',
-                                        source_config.get('skip-external-entities', False))
+
+    def update_config(self, source_entity, processed_config):
+        """update configuration from source entity"""
+        # XXX get it through pyro if unset
+        baseurl = processed_config.get('base-url')
+        if baseurl and not baseurl.endswith('/'):
+            processed_config['base-url'] += '/'
+        self.config = processed_config
+        self._skip_externals = processed_config['skip-external-entities']
 
     def reset_caches(self):
         """method called during test to reset potential source caches"""
--- a/server/test/unittest_ldapuser.py	Mon Feb 07 18:19:36 2011 +0100
+++ b/server/test/unittest_ldapuser.py	Mon Feb 07 18:19:39 2011 +0100
@@ -49,8 +49,7 @@
     """
     assert login, 'no login!'
     searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
-    searchfilter.extend([filter_format('(%s=%s)', ('objectClass', o))
-                         for o in self.user_classes])
+    searchfilter.extend(self.base_filters)
     searchstr = '(&%s)' % ''.join(searchfilter)
     # first search the user
     try:
@@ -456,8 +455,7 @@
         self.pool = repo._get_pool()
         session = mock_object(pool=self.pool)
         self.o = RQL2LDAPFilter(ldapsource, session)
-        self.ldapclasses = ''.join('(objectClass=%s)' % ldapcls
-                                   for ldapcls in ldapsource.user_classes)
+        self.ldapclasses = ''.join(ldapsource.base_filters)
 
     def tearDown(self):
         repo._free_pool(self.pool)