fix vobjects registration to deal with objects inter-dependancy tls-sprint
authorsylvain.thenault@logilab.fr
Thu, 09 Apr 2009 11:30:13 +0200
branchtls-sprint
changeset 1310 99dfced5673e
parent 1309 a4eb20f86cb0
child 1311 4cc6e2723dc7
fix vobjects registration to deal with objects inter-dependancy
vregistry.py
--- a/vregistry.py	Wed Apr 08 20:38:34 2009 +0200
+++ b/vregistry.py	Thu Apr 09 11:30:13 2009 +0200
@@ -27,7 +27,7 @@
 
 import sys
 from os import listdir, stat
-from os.path import dirname, join, realpath, split, isdir
+from os.path import dirname, join, realpath, split, isdir, exists
 from logging import getLogger
 import types
 
@@ -35,6 +35,23 @@
 from cubicweb import RegistryNotFound, ObjectNotFound, NoSelectableObject
 
 
+def _toload_info(path, _toload=None):
+    """return a dictionary of <modname>: <modpath> and an ordered list of
+    (file, module name) to load
+    """
+    from logilab.common.modutils import modpath_from_file
+    if _toload is None:
+        _toload = {}, []
+    for fileordir in path:
+        if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
+            subfiles = [join(fileordir, fname) for fname in listdir(fileordir)]
+            _toload_info(subfiles, _toload)
+        elif fileordir[-3:] == '.py':
+            modname = '.'.join(modpath_from_file(fileordir))
+            _toload[0][modname] = fileordir
+            _toload[1].append((fileordir, modname))
+    return _toload
+
 
 class registerer(object):
     """do whatever is needed at registration time for the wrapped
@@ -239,7 +256,7 @@
         # registered() is technically a classmethod but is not declared
         # as such because we need to compose registered in some cases
         vobject = obj.registered.im_func(obj, self)
-        assert not vobject in vobjects
+        assert not vobject in vobjects, vobject
         assert callable(vobject.__select__), vobject
         vobjects.append(vobject)
         try:
@@ -249,7 +266,7 @@
         self.debug('registered vobject %s in registry %s with id %s',
                    vname, registryname, oid)
         # automatic reloading management
-        self._registered['%s.%s' % (obj.__module__, oid)] = obj
+        self._loadedmods[obj.__module__]['%s.%s' % (obj.__module__, oid)] = obj
 
     def unregister(self, obj, registryname=None):
         registryname = registryname or obj.__registry__
@@ -352,87 +369,70 @@
             if webdir in sys.path:
                 sys.path.remove(webdir)
         if CW_SOFTWARE_ROOT in sys.path:
-            sys.path.remove(CW_SOFTWARE_ROOT)        
+            sys.path.remove(CW_SOFTWARE_ROOT)
+        # compute list of all modules that have to be loaded
+        self._toloadmods, filemods = _toload_info(path)
+        self._loadedmods = {}
         # load views from each directory in the application's path
         change = False
-        for fileordirectory in path:
-            if isdir(fileordirectory):
-                if self.read_directory(fileordirectory, force_reload):
-                    change = True
-            else:
-                directory, filename = split(fileordirectory)
-                if self.load_file(directory, filename, force_reload):
-                    change = True
+        for filepath, modname in filemods:
+            if self.load_file(filepath, modname, force_reload):
+                change = True
         return change
-    
-    def read_directory(self, directory, force_reload=False):
-        """read a directory and register available views"""
-        modified_on = stat(realpath(directory))[-2]
-        # only read directory if it was modified
-        _lastmodifs = self._lastmodifs
-        if directory in _lastmodifs and modified_on <= _lastmodifs[directory]:
+
+    def load_file(self, filepath, modname, force_reload=False):
+        """load visual objects from a python file"""
+        from logilab.common.modutils import load_module_from_name
+        if modname in self._loadedmods:
+            return
+        self._loadedmods[modname] = {}
+        try:
+            modified_on = stat(filepath)[-2]
+        except OSError:
+            # this typically happens on emacs backup files (.#foo.py)
+            self.warning('Unable to load %s. It is likely to be a backup file',
+                         filepath)
             return False
-        self.info('loading directory %s', directory)
-        for filename in listdir(directory):
-            if filename[-3:] == '.py':
-                try:
-                    self.load_file(directory, filename, force_reload)
-                except OSError:
-                    # this typically happens on emacs backup files (.#foo.py)
-                    self.warning('Unable to load file %s. It is likely to be a backup file',
-                                 filename)
-                except Exception, ex:
-                    if self.config.mode in ('dev', 'test'):
-                        raise
-                    self.exception('%r while loading file %s', ex, filename)
-        _lastmodifs[directory] = modified_on
-        return True
-
-    def load_file(self, directory, filename, force_reload=False):
-        """load visual objects from a python file"""
-        from logilab.common.modutils import load_module_from_modpath, modpath_from_file
-        filepath = join(directory, filename)
-        modified_on = stat(filepath)[-2]
-        modpath = modpath_from_file(join(directory, filename))
-        modname = '.'.join(modpath)
-        unregistered = {}
-        _lastmodifs = self._lastmodifs
-        if filepath in _lastmodifs:
+        if filepath in self._lastmodifs:
             # only load file if it was modified
-            if modified_on <= _lastmodifs[filepath]:
+            if modified_on <= self._lastmodifs[filepath]:
                 return
-            else:
-                # if it was modified, unregister all exisiting objects
-                # from this module, and keep track of what was unregistered
-                unregistered = self.unregister_module_vobjects(modname)
+            # if it was modified, unregister all exisiting objects
+            # from this module, and keep track of what was unregistered
+            unregistered = self.unregister_module_vobjects(modname)
+        else:
+            unregistered = None
         # load the module
-        module = load_module_from_modpath(modpath, use_sys=not force_reload)
-        registered = self.load_module(module)
+        module = load_module_from_name(modname, use_sys=not force_reload)
+        self.load_module(module)
         # if something was unregistered, we need to update places where it was
         # referenced 
         if unregistered:
             # oldnew_mapping = {}
+            registered = self._loadedmods[modname]
             oldnew_mapping = dict((unregistered[name], registered[name])
                                   for name in unregistered if name in registered)
             self.update_registered_subclasses(oldnew_mapping)
-        _lastmodifs[filepath] = modified_on
+        self._lastmodifs[filepath] = modified_on
         return True
 
     def load_module(self, module):
-        self._registered = {}
+        self.info('loading %s', module)
         if hasattr(module, 'registration_callback'):
             module.registration_callback(self)
         else:
-            self.info('loading %s', module)
             for objname, obj in vars(module).items():
                 if objname.startswith('_'):
                     continue
-                self.load_ancestors_then_object(module.__name__, obj)
-        return self._registered
+                self._load_ancestors_then_object(module.__name__, obj)
+        self.debug('loaded %s', module)
     
-    def load_ancestors_then_object(self, modname, obj):
-        # skip imported classes
-        if getattr(obj, '__module__', None) != modname:
+    def _load_ancestors_then_object(self, modname, obj):
+        # imported classes
+        objmodname = getattr(obj, '__module__', None)
+        if objmodname != modname:
+            if objmodname in self._toloadmods:
+                self.load_file(self._toloadmods[objmodname], objmodname)
             return
         # skip non registerable object
         try:
@@ -441,11 +441,11 @@
         except TypeError:
             return
         objname = '%s.%s' % (modname, obj.__name__)
-        if objname in self._registered:
+        if objname in self._loadedmods[modname]:
             return
-        self._registered[objname] = obj
+        self._loadedmods[modname][objname] = obj
         for parent in obj.__bases__:
-            self.load_ancestors_then_object(modname, parent)
+            self._load_ancestors_then_object(modname, parent)
         self.load_object(obj)
             
     def load_object(self, obj):