server/sources/pyrorql.py
changeset 6724 24bf6f181d0e
parent 6672 2008fd2f628c
child 6941 9ed02daa7dbb
--- a/server/sources/pyrorql.py	Wed Dec 01 17:09:19 2010 +0100
+++ b/server/sources/pyrorql.py	Wed Dec 01 17:11:35 2010 +0100
@@ -45,34 +45,6 @@
     select, col = union.locate_subquery(col, etype, args)
     return getattr(select.selection[col], 'uidtype', None)
 
-def load_mapping_file(mappingfile):
-    mapping = {}
-    execfile(mappingfile, mapping)
-    for junk in ('__builtins__', '__doc__'):
-        mapping.pop(junk, None)
-    mapping.setdefault('support_relations', {})
-    mapping.setdefault('dont_cross_relations', set())
-    mapping.setdefault('cross_relations', set())
-
-    # do some basic checks of the mapping content
-    assert 'support_entities' in mapping, \
-           'mapping file should at least define support_entities'
-    assert isinstance(mapping['support_entities'], dict)
-    assert isinstance(mapping['support_relations'], dict)
-    assert isinstance(mapping['dont_cross_relations'], set)
-    assert isinstance(mapping['cross_relations'], set)
-    unknown = set(mapping) - set( ('support_entities', 'support_relations',
-                                   'dont_cross_relations', 'cross_relations') )
-    assert not unknown, 'unknown mapping attribute(s): %s' % unknown
-    # relations that are necessarily not crossed
-    mapping['dont_cross_relations'] |= set(('owned_by', 'created_by'))
-    for rtype in ('is', 'is_instance_of', 'cw_source'):
-        assert rtype not in mapping['dont_cross_relations'], \
-               '%s relation should not be in dont_cross_relations' % rtype
-        assert rtype not in mapping['support_relations'], \
-               '%s relation should not be in support_relations' % rtype
-    return mapping
-
 
 class ReplaceByInOperator(Exception):
     def __init__(self, eids):
@@ -96,12 +68,6 @@
           'help': 'identifier of the repository in the pyro name server',
           'group': 'pyro-source', 'level': 0,
           }),
-        ('mapping-file',
-         {'type' : 'string',
-          'default': REQUIRED,
-          'help': 'path to a python file with the schema mapping definition',
-          'group': 'pyro-source', 'level': 1,
-          }),
         ('cubicweb-user',
          {'type' : 'string',
           'default': REQUIRED,
@@ -156,24 +122,7 @@
 
     def __init__(self, repo, source_config, *args, **kwargs):
         AbstractSource.__init__(self, repo, source_config, *args, **kwargs)
-        mappingfile = source_config['mapping-file']
-        if not mappingfile[0] == '/':
-            mappingfile = join(repo.config.apphome, mappingfile)
-        try:
-            mapping = load_mapping_file(mappingfile)
-        except IOError:
-            self.disabled = True
-            self.error('cant read mapping file %s, source disabled',
-                       mappingfile)
-            self.support_entities = {}
-            self.support_relations = {}
-            self.dont_cross_relations = set()
-            self.cross_relations = set()
-        else:
-            self.support_entities = mapping['support_entities']
-            self.support_relations = mapping['support_relations']
-            self.dont_cross_relations = mapping['dont_cross_relations']
-            self.cross_relations = mapping['cross_relations']
+        # XXX get it through pyro if unset
         baseurl = source_config.get('base-url')
         if baseurl and not baseurl.endswith('/'):
             source_config['base-url'] += '/'
@@ -212,12 +161,47 @@
         finally:
             session.close()
 
-    def init(self):
+    def init(self, activated, session=None):
         """method called by the repository once ready to handle request"""
-        interval = int(self.config.get('synchronization-interval', 5*60))
-        self.repo.looping_task(interval, self.synchronize)
-        self.repo.looping_task(self._query_cache.ttl.seconds/10,
-                               self._query_cache.clear_expired)
+        self.load_mapping(session)
+        if activated:
+            interval = int(self.config.get('synchronization-interval', 5*60))
+            self.repo.looping_task(interval, self.synchronize)
+            self.repo.looping_task(self._query_cache.ttl.seconds/10,
+                                   self._query_cache.clear_expired)
+
+    def load_mapping(self, session=None):
+        self.support_entities = {}
+        self.support_relations = {}
+        self.dont_cross_relations = set(('owned_by', 'created_by'))
+        self.cross_relations = set()
+        assert self.eid is not None
+        if session is None:
+            _session = self.repo.internal_session()
+        else:
+            _session = session
+        try:
+            for rql, struct in [('Any ETN WHERE S cw_support ET, ET name ETN, ET is CWEType, S eid %(s)s',
+                                 self.support_entities),
+                                ('Any RTN WHERE S cw_support RT, RT name RTN, RT is CWRType, S eid %(s)s',
+                                 self.support_relations)]:
+                for ertype, in _session.execute(rql, {'s': self.eid}):
+                    struct[ertype] = True # XXX write support
+            for rql, struct in [('Any RTN WHERE S cw_may_cross RT, RT name RTN, S eid %(s)s',
+                                 self.cross_relations),
+                                ('Any RTN WHERE S cw_dont_cross RT, RT name RTN, S eid %(s)s',
+                                 self.dont_cross_relations)]:
+                for rtype, in _session.execute(rql, {'s': self.eid}):
+                    struct.add(rtype)
+        finally:
+            if session is None:
+                _session.close()
+        # XXX move in hooks or schema constraints
+        for rtype in ('is', 'is_instance_of', 'cw_source'):
+            assert rtype not in self.dont_cross_relations, \
+                   '%s relation should not be in dont_cross_relations' % rtype
+            assert rtype not in self.support_relations, \
+                   '%s relation should not be in support_relations' % rtype
 
     def local_eid(self, cnx, extid, session):
         etype, dexturi, dextid = cnx.describe(extid)
@@ -246,8 +230,8 @@
         etypes = self.support_entities.keys()
         if mtime is None:
             mtime = self.last_update_time()
-        updatetime, modified, deleted = extrepo.entities_modified_since(etypes,
-                                                                        mtime)
+        updatetime, modified, deleted = extrepo.entities_modified_since(
+            etypes, mtime)
         self._query_cache.clear()
         repo = self.repo
         session = repo.internal_session()