--- a/devtools/testlib.py Fri Dec 10 12:17:18 2010 +0100
+++ b/devtools/testlib.py Fri Mar 11 09:46:45 2011 +0100
@@ -1,4 +1,4 @@
-# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
@@ -25,7 +25,7 @@
import sys
import re
import urlparse
-from os.path import dirname, join
+from os.path import dirname, join, abspath
from urllib import unquote
from math import log
from contextlib import contextmanager
@@ -38,7 +38,7 @@
from logilab.common.debugger import Debugger
from logilab.common.umessage import message_from_string
from logilab.common.decorators import cached, classproperty, clear_cache
-from logilab.common.deprecation import deprecated
+from logilab.common.deprecation import deprecated, class_deprecated
from logilab.common.shellutils import getlogin
from cubicweb import ValidationError, NoSelectableObject, AuthenticationError
@@ -185,6 +185,7 @@
* `repo`, the repository object
* `admlogin`, login of the admin user
* `admpassword`, password of the admin user
+ * `shell`, create and use shell environment
"""
appid = 'data'
configcls = devtools.ApptestConfiguration
@@ -200,7 +201,7 @@
try:
return cls.__dict__['_config']
except KeyError:
- home = join(dirname(sys.modules[cls.__module__].__file__), cls.appid)
+ home = abspath(join(dirname(sys.modules[cls.__module__].__file__), cls.appid))
config = cls._config = cls.configcls(cls.appid, apphome=home)
config.mode = 'test'
return config
@@ -286,18 +287,29 @@
"""return current server side session (using default manager account)"""
return self.repo._sessions[self._orig_cnx[0].sessionid]
+ def shell(self):
+ """return a shell session object"""
+ from cubicweb.server.migractions import ServerMigrationHelper
+ return ServerMigrationHelper(None, repo=self.repo, cnx=self.cnx,
+ interactive=False,
+ # hack so it don't try to load fs schema
+ schema=1)
+
def set_option(self, optname, value):
self.config.global_set_option(optname, value)
def set_debug(self, debugmode):
server.set_debug(debugmode)
+ def debugged(self, debugmode):
+ return server.debugged(debugmode)
+
# default test setup and teardown #########################################
def setUp(self):
# monkey patch send mail operation so emails are sent synchronously
- self._old_mail_commit_event = SendMailOp.commit_event
- SendMailOp.commit_event = SendMailOp.sendmails
+ self._old_mail_postcommit_event = SendMailOp.postcommit_event
+ SendMailOp.postcommit_event = SendMailOp.sendmails
pause_tracing()
previous_failure = self.__class__.__dict__.get('_repo_init_failed')
if previous_failure is not None:
@@ -319,7 +331,7 @@
for cnx in self._cnxs:
if not cnx._closed:
cnx.close()
- SendMailOp.commit_event = self._old_mail_commit_event
+ SendMailOp.postcommit_event = self._old_mail_postcommit_event
def setup_database(self):
"""add your database setup code by overriding this method"""
@@ -344,7 +356,7 @@
user = req.create_entity('CWUser', login=unicode(login),
upassword=password, **kwargs)
req.execute('SET X in_group G WHERE X eid %%(x)s, G name IN(%s)'
- % ','.join(repr(g) for g in groups),
+ % ','.join(repr(str(g)) for g in groups),
{'x': user.eid})
user.cw_clear_relation_cache('in_group', 'subject')
if commit:
@@ -423,6 +435,21 @@
# other utilities #########################################################
+ def grant_permission(self, entity, group, pname, plabel=None):
+ """insert a permission on an entity. Will have to commit the main
+ connection to be considered
+ """
+ pname = unicode(pname)
+ plabel = plabel and unicode(plabel) or unicode(group)
+ e = entity.eid
+ with security_enabled(self.session, False, False):
+ peid = self.execute(
+ 'INSERT CWPermission X: X name %(pname)s, X label %(plabel)s,'
+ 'X require_group G, E require_permission X '
+ 'WHERE G name %(group)s, E eid %(e)s',
+ locals())[0][0]
+ return peid
+
@contextmanager
def temporary_appobjects(self, *appobjects):
self.vreg._loadedmods.setdefault(self.__module__, {})
@@ -434,7 +461,20 @@
for obj in appobjects:
self.vreg.unregister(obj)
- # vregistry inspection utilities ###########################################
+ def assertModificationDateGreater(self, entity, olddate):
+ entity.cw_attr_cache.pop('modification_date', None)
+ self.failUnless(entity.modification_date > olddate)
+
+
+ # workflow utilities #######################################################
+
+ def assertPossibleTransitions(self, entity, expected):
+ transitions = entity.cw_adapt_to('IWorkflowable').possible_transitions()
+ self.assertListEqual(sorted(tr.name for tr in transitions),
+ sorted(expected))
+
+
+ # views and actions registries inspection ##################################
def pviews(self, req, rset):
return sorted((a.__regid__, a.__class__)
@@ -468,9 +508,7 @@
def items(self):
return self
class fake_box(object):
- def mk_action(self, label, url, **kwargs):
- return (label, url)
- def box_action(self, action, **kwargs):
+ def action_link(self, action, **kwargs):
return (action.title, action.url())
submenu = fake_menu()
action.fill_menu(fake_box(), submenu)
@@ -489,7 +527,8 @@
continue
views = [view for view in views
if view.category != 'startupview'
- and not issubclass(view, notification.NotificationView)]
+ and not issubclass(view, notification.NotificationView)
+ and not isinstance(view, class_deprecated)]
if views:
try:
view = viewsvreg._select_best(views, req, rset=rset)
@@ -511,7 +550,7 @@
def list_boxes_for(self, rset):
"""returns the list of boxes that can be applied on `rset`"""
req = rset.req
- for box in self.vreg['boxes'].possible_objects(req, rset=rset):
+ for box in self.vreg['ctxcomponents'].possible_objects(req, rset=rset):
yield box
def list_startup_views(self):
@@ -620,6 +659,10 @@
def init_authentication(self, authmode, anonuser=None):
self.set_option('auth-mode', authmode)
self.set_option('anonymous-user', anonuser)
+ if anonuser is None:
+ self.config.anonymous_credential = None
+ else:
+ self.config.anonymous_credential = (anonuser, anonuser)
req = self.request()
origsession = req.session
req.session = req.cnx = None
@@ -721,10 +764,8 @@
:returns: an instance of `cubicweb.devtools.htmlparser.PageInfo`
encapsulation the generated HTML
"""
- output = None
try:
output = viewfunc(**kwargs)
- return self._check_html(output, view, template)
except (SystemExit, KeyboardInterrupt):
raise
except:
@@ -735,44 +776,107 @@
msg = '[%s in %s] %s' % (klass, view.__regid__, exc)
except:
msg = '[%s in %s] undisplayable exception' % (klass, view.__regid__)
- if output is not None:
- position = getattr(exc, "position", (0,))[0]
- if position:
- # define filter
- output = output.splitlines()
- width = int(log(len(output), 10)) + 1
- line_template = " %" + ("%i" % width) + "i: %s"
- # XXX no need to iterate the whole file except to get
- # the line number
- output = '\n'.join(line_template % (idx + 1, line)
- for idx, line in enumerate(output)
- if line_context_filter(idx+1, position))
- msg += '\nfor output:\n%s' % output
raise AssertionError, msg, tcbk
+ return self._check_html(output, view, template)
+ def get_validator(self, view=None, content_type=None, output=None):
+ if view is not None:
+ try:
+ return self.vid_validators[view.__regid__]()
+ except KeyError:
+ if content_type is None:
+ content_type = view.content_type
+ if content_type is None:
+ content_type = 'text/html'
+ if content_type in ('text/html', 'application/xhtml+xml'):
+ if output and output.startswith('<?xml'):
+ default_validator = htmlparser.DTDValidator
+ else:
+ default_validator = htmlparser.HTMLValidator
+ else:
+ default_validator = None
+ validatorclass = self.content_type_validators.get(content_type,
+ default_validator)
+ if validatorclass is None:
+ return
+ return validatorclass()
@nocoverage
def _check_html(self, output, view, template='main-template'):
"""raises an exception if the HTML is invalid"""
- try:
- validatorclass = self.vid_validators[view.__regid__]
- except KeyError:
- if view.content_type in ('text/html', 'application/xhtml+xml'):
- if template is None:
- default_validator = htmlparser.HTMLValidator
- else:
- default_validator = htmlparser.DTDValidator
- else:
- default_validator = None
- validatorclass = self.content_type_validators.get(view.content_type,
- default_validator)
- if validatorclass is None:
- return output.strip()
- validator = validatorclass()
+ output = output.strip()
+ validator = self.get_validator(view, output=output)
+ if validator is None:
+ return
if isinstance(validator, htmlparser.DTDValidator):
# XXX remove <canvas> used in progress widget, unknown in html dtd
output = re.sub('<canvas.*?></canvas>', '', output)
- return validator.parse_string(output.strip())
+ return self.assertWellFormed(validator, output.strip(), context= view.__regid__)
+
+ def assertWellFormed(self, validator, content, context=None):
+ try:
+ return validator.parse_string(content)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except:
+ # hijack exception: generative tests stop when the exception
+ # is not an AssertionError
+ klass, exc, tcbk = sys.exc_info()
+ if context is None:
+ msg = u'[%s]' % (klass,)
+ else:
+ msg = u'[%s in %s]' % (klass, context)
+ msg = msg.encode(sys.getdefaultencoding(), 'replace')
+
+ try:
+ str_exc = str(exc)
+ except:
+ str_exc = 'undisplayable exception'
+ msg += str_exc
+ if content is not None:
+ position = getattr(exc, "position", (0,))[0]
+ if position:
+ # define filter
+ if isinstance(content, str):
+ content = unicode(content, sys.getdefaultencoding(), 'replace')
+ content = content.splitlines()
+ width = int(log(len(content), 10)) + 1
+ line_template = " %" + ("%i" % width) + "i: %s"
+ # XXX no need to iterate the whole file except to get
+ # the line number
+ content = u'\n'.join(line_template % (idx + 1, line)
+ for idx, line in enumerate(content)
+ if line_context_filter(idx+1, position))
+ msg += u'\nfor content:\n%s' % content
+ raise AssertionError, msg, tcbk
+
+ def assertDocTestFile(self, testfile):
+ # doctest returns tuple (failure_count, test_count)
+ result = self.shell().process_script(testfile)
+ if result[0] and result[1]:
+ raise self.failureException("doctest file '%s' failed"
+ % testfile)
+
+ # notifications ############################################################
+
+ def assertSentEmail(self, subject, recipients=None, nb_msgs=None):
+ """test recipients in system mailbox for given email subject
+
+ :param subject: email subject to find in mailbox
+ :param recipients: list of email recipients
+ :param nb_msgs: expected number of entries
+ :returns: list of matched emails
+ """
+ messages = [email for email in MAILBOX
+ if email.message.get('Subject') == subject]
+ if recipients is not None:
+ sent_to = set()
+ for msg in messages:
+ sent_to.update(msg.recipients)
+ self.assertSetEqual(set(recipients), sent_to)
+ if nb_msgs is not None:
+ self.assertEqual(len(MAILBOX), nb_msgs)
+ return messages
# deprecated ###############################################################
@@ -966,7 +1070,8 @@
for action in self.list_actions_for(rset):
yield InnerTest(self._testname(rset, action.__regid__, 'action'), self._test_action, action)
for box in self.list_boxes_for(rset):
- yield InnerTest(self._testname(rset, box.__regid__, 'box'), box.render)
+ w = [].append
+ yield InnerTest(self._testname(rset, box.__regid__, 'box'), box.render, w)
@staticmethod
def _testname(rset, objid, objtype):