oops, missing part of 6125:628cf5213154 (mapping file checking) stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Fri, 20 Aug 2010 08:36:58 +0200
branchstable
changeset 6130 15fa8425b6e7
parent 6129 fea746b60093
child 6131 087c5a168010
oops, missing part of 6125:628cf5213154 (mapping file checking)
server/sources/pyrorql.py
--- a/server/sources/pyrorql.py	Fri Aug 20 08:35:10 2010 +0200
+++ b/server/sources/pyrorql.py	Fri Aug 20 08:36:58 2010 +0200
@@ -43,6 +43,34 @@
     select, col = union.locate_subquery(col, etype, args)
     return getattr(select.selection[col], 'uidtype', None)
 
+def load_mapping_file(mappingfile):
+    mapping = {}
+    execfile(mappingfile, mapping)
+    for junk in ('__builtins__', '__doc__'):
+        mapping.pop(junk, None)
+    mapping.setdefault('support_relations', {})
+    mapping.setdefault('dont_cross_relations', set())
+    mapping.setdefault('cross_relations', set())
+
+    # do some basic checks of the mapping content
+    assert 'support_entities' in mapping, \
+           'mapping file should at least define support_entities'
+    assert isinstance(mapping['support_entities'], dict)
+    assert isinstance(mapping['support_relations'], dict)
+    assert isinstance(mapping['dont_cross_relations'], set)
+    assert isinstance(mapping['cross_relations'], set)
+    unknown = set(mapping) - set( ('support_entities', 'support_relations',
+                                   'dont_cross_relations', 'cross_relations') )
+    assert not unknown, 'unknown mapping attribute(s): %s' % unknown
+    # relations that are necessarily not crossed
+    mapping['dont_cross_relations'] |= set(('owned_by', 'created_by'))
+    for rtype in ('is', 'is_instance_of'):
+        assert rtype not in mapping['dont_cross_relations'], \
+               '%s relation should not be in dont_cross_relations' % rtype
+        assert rtype not in mapping['support_relations'], \
+               '%s relation should not be in support_relations' % rtype
+    return mapping
+
 
 class ReplaceByInOperator(Exception):
     def __init__(self, eids):
@@ -124,14 +152,11 @@
         mappingfile = source_config['mapping-file']
         if not mappingfile[0] == '/':
             mappingfile = join(repo.config.apphome, mappingfile)
-        mapping = {}
-        execfile(mappingfile, mapping)
+        mapping = load_mapping_file(mappingfile)
         self.support_entities = mapping['support_entities']
-        self.support_relations = mapping.get('support_relations', {})
-        self.dont_cross_relations = set(mapping.get('dont_cross_relations', ()))
-        self.cross_relations = set(mapping.get('cross_relations', ()))
-        self.dont_cross_relations.add('owned_by')
-        self.dont_cross_relations.add('created_by')
+        self.support_relations = mapping['support_relations']
+        self.dont_cross_relations = mapping['dont_cross_relations']
+        self.cross_relations = mapping['cross_relations']
         baseurl = source_config.get('base-url')
         if baseurl and not baseurl.endswith('/'):
             source_config['base-url'] += '/'