devtools/testlib.py
brancholdstable
changeset 7074 e4580e5f0703
parent 7071 db7608cb32bc
child 7075 4751d77394b1
child 7078 bad26a22fe29
--- a/devtools/testlib.py	Fri Dec 10 12:17:18 2010 +0100
+++ b/devtools/testlib.py	Fri Mar 11 09:46:45 2011 +0100
@@ -1,4 +1,4 @@
-# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -25,7 +25,7 @@
 import sys
 import re
 import urlparse
-from os.path import dirname, join
+from os.path import dirname, join, abspath
 from urllib import unquote
 from math import log
 from contextlib import contextmanager
@@ -38,7 +38,7 @@
 from logilab.common.debugger import Debugger
 from logilab.common.umessage import message_from_string
 from logilab.common.decorators import cached, classproperty, clear_cache
-from logilab.common.deprecation import deprecated
+from logilab.common.deprecation import deprecated, class_deprecated
 from logilab.common.shellutils import getlogin
 
 from cubicweb import ValidationError, NoSelectableObject, AuthenticationError
@@ -185,6 +185,7 @@
     * `repo`, the repository object
     * `admlogin`, login of the admin user
     * `admpassword`, password of the admin user
+    * `shell`, create and use shell environment
     """
     appid = 'data'
     configcls = devtools.ApptestConfiguration
@@ -200,7 +201,7 @@
         try:
             return cls.__dict__['_config']
         except KeyError:
-            home = join(dirname(sys.modules[cls.__module__].__file__), cls.appid)
+            home = abspath(join(dirname(sys.modules[cls.__module__].__file__), cls.appid))
             config = cls._config = cls.configcls(cls.appid, apphome=home)
             config.mode = 'test'
             return config
@@ -286,18 +287,29 @@
         """return current server side session (using default manager account)"""
         return self.repo._sessions[self._orig_cnx[0].sessionid]
 
+    def shell(self):
+        """return a shell session object"""
+        from cubicweb.server.migractions import ServerMigrationHelper
+        return ServerMigrationHelper(None, repo=self.repo, cnx=self.cnx,
+                                     interactive=False,
+                                     # hack so it don't try to load fs schema
+                                     schema=1)
+
     def set_option(self, optname, value):
         self.config.global_set_option(optname, value)
 
     def set_debug(self, debugmode):
         server.set_debug(debugmode)
 
+    def debugged(self, debugmode):
+        return server.debugged(debugmode)
+
     # default test setup and teardown #########################################
 
     def setUp(self):
         # monkey patch send mail operation so emails are sent synchronously
-        self._old_mail_commit_event = SendMailOp.commit_event
-        SendMailOp.commit_event = SendMailOp.sendmails
+        self._old_mail_postcommit_event = SendMailOp.postcommit_event
+        SendMailOp.postcommit_event = SendMailOp.sendmails
         pause_tracing()
         previous_failure = self.__class__.__dict__.get('_repo_init_failed')
         if previous_failure is not None:
@@ -319,7 +331,7 @@
         for cnx in self._cnxs:
             if not cnx._closed:
                 cnx.close()
-        SendMailOp.commit_event = self._old_mail_commit_event
+        SendMailOp.postcommit_event = self._old_mail_postcommit_event
 
     def setup_database(self):
         """add your database setup code by overriding this method"""
@@ -344,7 +356,7 @@
         user = req.create_entity('CWUser', login=unicode(login),
                                  upassword=password, **kwargs)
         req.execute('SET X in_group G WHERE X eid %%(x)s, G name IN(%s)'
-                    % ','.join(repr(g) for g in groups),
+                    % ','.join(repr(str(g)) for g in groups),
                     {'x': user.eid})
         user.cw_clear_relation_cache('in_group', 'subject')
         if commit:
@@ -423,6 +435,21 @@
 
     # other utilities #########################################################
 
+    def grant_permission(self, entity, group, pname, plabel=None):
+        """insert a permission on an entity. Will have to commit the main
+        connection to be considered
+        """
+        pname = unicode(pname)
+        plabel = plabel and unicode(plabel) or unicode(group)
+        e = entity.eid
+        with security_enabled(self.session, False, False):
+            peid = self.execute(
+            'INSERT CWPermission X: X name %(pname)s, X label %(plabel)s,'
+            'X require_group G, E require_permission X '
+            'WHERE G name %(group)s, E eid %(e)s',
+            locals())[0][0]
+        return peid
+
     @contextmanager
     def temporary_appobjects(self, *appobjects):
         self.vreg._loadedmods.setdefault(self.__module__, {})
@@ -434,7 +461,20 @@
             for obj in appobjects:
                 self.vreg.unregister(obj)
 
-    # vregistry inspection utilities ###########################################
+    def assertModificationDateGreater(self, entity, olddate):
+        entity.cw_attr_cache.pop('modification_date', None)
+        self.failUnless(entity.modification_date > olddate)
+
+
+    # workflow utilities #######################################################
+
+    def assertPossibleTransitions(self, entity, expected):
+        transitions = entity.cw_adapt_to('IWorkflowable').possible_transitions()
+        self.assertListEqual(sorted(tr.name for tr in transitions),
+                             sorted(expected))
+
+
+    # views and actions registries inspection ##################################
 
     def pviews(self, req, rset):
         return sorted((a.__regid__, a.__class__)
@@ -468,9 +508,7 @@
             def items(self):
                 return self
         class fake_box(object):
-            def mk_action(self, label, url, **kwargs):
-                return (label, url)
-            def box_action(self, action, **kwargs):
+            def action_link(self, action, **kwargs):
                 return (action.title, action.url())
         submenu = fake_menu()
         action.fill_menu(fake_box(), submenu)
@@ -489,7 +527,8 @@
                 continue
             views = [view for view in views
                      if view.category != 'startupview'
-                     and not issubclass(view, notification.NotificationView)]
+                     and not issubclass(view, notification.NotificationView)
+                     and not isinstance(view, class_deprecated)]
             if views:
                 try:
                     view = viewsvreg._select_best(views, req, rset=rset)
@@ -511,7 +550,7 @@
     def list_boxes_for(self, rset):
         """returns the list of boxes that can be applied on `rset`"""
         req = rset.req
-        for box in self.vreg['boxes'].possible_objects(req, rset=rset):
+        for box in self.vreg['ctxcomponents'].possible_objects(req, rset=rset):
             yield box
 
     def list_startup_views(self):
@@ -620,6 +659,10 @@
     def init_authentication(self, authmode, anonuser=None):
         self.set_option('auth-mode', authmode)
         self.set_option('anonymous-user', anonuser)
+        if anonuser is None:
+            self.config.anonymous_credential = None
+        else:
+            self.config.anonymous_credential = (anonuser, anonuser)
         req = self.request()
         origsession = req.session
         req.session = req.cnx = None
@@ -721,10 +764,8 @@
         :returns: an instance of `cubicweb.devtools.htmlparser.PageInfo`
                   encapsulation the generated HTML
         """
-        output = None
         try:
             output = viewfunc(**kwargs)
-            return self._check_html(output, view, template)
         except (SystemExit, KeyboardInterrupt):
             raise
         except:
@@ -735,44 +776,107 @@
                 msg = '[%s in %s] %s' % (klass, view.__regid__, exc)
             except:
                 msg = '[%s in %s] undisplayable exception' % (klass, view.__regid__)
-            if output is not None:
-                position = getattr(exc, "position", (0,))[0]
-                if position:
-                    # define filter
-                    output = output.splitlines()
-                    width = int(log(len(output), 10)) + 1
-                    line_template = " %" + ("%i" % width) + "i: %s"
-                    # XXX no need to iterate the whole file except to get
-                    # the line number
-                    output = '\n'.join(line_template % (idx + 1, line)
-                                for idx, line in enumerate(output)
-                                if line_context_filter(idx+1, position))
-                    msg += '\nfor output:\n%s' % output
             raise AssertionError, msg, tcbk
+        return self._check_html(output, view, template)
 
+    def get_validator(self, view=None, content_type=None, output=None):
+        if view is not None:
+            try:
+                return self.vid_validators[view.__regid__]()
+            except KeyError:
+                if content_type is None:
+                    content_type = view.content_type
+        if content_type is None:
+            content_type = 'text/html'
+        if content_type in ('text/html', 'application/xhtml+xml'):
+            if output and output.startswith('<?xml'):
+                default_validator = htmlparser.DTDValidator
+            else:
+                default_validator = htmlparser.HTMLValidator
+        else:
+            default_validator = None
+        validatorclass = self.content_type_validators.get(content_type,
+                                                          default_validator)
+        if validatorclass is None:
+            return
+        return validatorclass()
 
     @nocoverage
     def _check_html(self, output, view, template='main-template'):
         """raises an exception if the HTML is invalid"""
-        try:
-            validatorclass = self.vid_validators[view.__regid__]
-        except KeyError:
-            if view.content_type in ('text/html', 'application/xhtml+xml'):
-                if template is None:
-                    default_validator = htmlparser.HTMLValidator
-                else:
-                    default_validator = htmlparser.DTDValidator
-            else:
-                default_validator = None
-            validatorclass = self.content_type_validators.get(view.content_type,
-                                                              default_validator)
-        if validatorclass is None:
-            return output.strip()
-        validator = validatorclass()
+        output = output.strip()
+        validator = self.get_validator(view, output=output)
+        if validator is None:
+            return
         if isinstance(validator, htmlparser.DTDValidator):
             # XXX remove <canvas> used in progress widget, unknown in html dtd
             output = re.sub('<canvas.*?></canvas>', '', output)
-        return validator.parse_string(output.strip())
+        return self.assertWellFormed(validator, output.strip(), context= view.__regid__)
+
+    def assertWellFormed(self, validator, content, context=None):
+        try:
+            return validator.parse_string(content)
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except:
+            # hijack exception: generative tests stop when the exception
+            # is not an AssertionError
+            klass, exc, tcbk = sys.exc_info()
+            if context is None:
+                msg = u'[%s]' % (klass,)
+            else:
+                msg = u'[%s in %s]' % (klass, context)
+            msg = msg.encode(sys.getdefaultencoding(), 'replace')
+
+            try:
+                str_exc = str(exc)
+            except:
+                str_exc = 'undisplayable exception'
+            msg += str_exc
+            if content is not None:
+                position = getattr(exc, "position", (0,))[0]
+                if position:
+                    # define filter
+                    if isinstance(content, str):
+                        content = unicode(content, sys.getdefaultencoding(), 'replace')
+                    content = content.splitlines()
+                    width = int(log(len(content), 10)) + 1
+                    line_template = " %" + ("%i" % width) + "i: %s"
+                    # XXX no need to iterate the whole file except to get
+                    # the line number
+                    content = u'\n'.join(line_template % (idx + 1, line)
+                                         for idx, line in enumerate(content)
+                                         if line_context_filter(idx+1, position))
+                    msg += u'\nfor content:\n%s' % content
+            raise AssertionError, msg, tcbk
+
+    def assertDocTestFile(self, testfile):
+        # doctest returns tuple (failure_count, test_count)
+        result = self.shell().process_script(testfile)
+        if result[0] and result[1]:
+            raise self.failureException("doctest file '%s' failed"
+                                        % testfile)
+
+    # notifications ############################################################
+
+    def assertSentEmail(self, subject, recipients=None, nb_msgs=None):
+        """test recipients in system mailbox for given email subject
+
+        :param subject: email subject to find in mailbox
+        :param recipients: list of email recipients
+        :param nb_msgs: expected number of entries
+        :returns: list of matched emails
+        """
+        messages = [email for email in MAILBOX
+                    if email.message.get('Subject') == subject]
+        if recipients is not None:
+            sent_to = set()
+            for msg in messages:
+                sent_to.update(msg.recipients)
+            self.assertSetEqual(set(recipients), sent_to)
+        if nb_msgs is not None:
+            self.assertEqual(len(MAILBOX), nb_msgs)
+        return messages
 
     # deprecated ###############################################################
 
@@ -966,7 +1070,8 @@
         for action in self.list_actions_for(rset):
             yield InnerTest(self._testname(rset, action.__regid__, 'action'), self._test_action, action)
         for box in self.list_boxes_for(rset):
-            yield InnerTest(self._testname(rset, box.__regid__, 'box'), box.render)
+            w = [].append
+            yield InnerTest(self._testname(rset, box.__regid__, 'box'), box.render, w)
 
     @staticmethod
     def _testname(rset, objid, objtype):