--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/devtools/apptest.py Wed Nov 05 15:52:50 2008 +0100
@@ -0,0 +1,504 @@
+"""This module provides misc utilities to test applications
+
+:organization: Logilab
+:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+from copy import deepcopy
+
+import simplejson
+
+from logilab.common.testlib import TestCase
+from logilab.common.pytest import nocoverage
+from logilab.common.umessage import message_from_string
+
+from cubicweb.devtools import init_test_database, TestServerConfiguration, ApptestConfiguration
+from cubicweb.devtools._apptest import TestEnvironment
+from cubicweb.devtools.fake import FakeRequest
+
+from cubicweb.dbapi import repo_connect, ConnectionProperties, ProgrammingError
+
+
+MAILBOX = []
+class Email:
+ def __init__(self, recipients, msg):
+ self.recipients = recipients
+ self.msg = msg
+
+ @property
+ def message(self):
+ return message_from_string(self.msg)
+
+ def __repr__(self):
+ return '<Email to %s with subject %s>' % (','.join(self.recipients),
+ self.message.get('Subject'))
+
+class MockSMTP:
+ def __init__(self, server, port):
+ pass
+ def close(self):
+ pass
+ def sendmail(self, helo_addr, recipients, msg):
+ MAILBOX.append(Email(recipients, msg))
+
+from cubicweb.server import hookhelper
+hookhelper.SMTP = MockSMTP
+
+
+def get_versions(self, checkversions=False):
+ """return the a dictionary containing cubes used by this application
+ as key with their version as value, including cubicweb version. This is a
+ public method, not requiring a session id.
+
+ replace Repository.get_versions by this method if you don't want versions
+ checking
+ """
+ vcconf = {'cubicweb': self.config.cubicweb_version()}
+ self.config.bootstrap_cubes()
+ for pk in self.config.cubes():
+ version = self.config.template_version(pk)
+ vcconf[pk] = version
+ self.config._cubes = None
+ return vcconf
+
+
+@property
+def late_binding_env(self):
+ """builds TestEnvironment as late as possible"""
+ if not hasattr(self, '_env'):
+ self.__class__._env = TestEnvironment('data', configcls=self.configcls,
+ requestcls=self.requestcls)
+ return self._env
+
+
+class autoenv(type):
+ """automatically set environment on EnvBasedTC subclasses if necessary
+ """
+ def __new__(mcs, name, bases, classdict):
+ env = classdict.get('env')
+ # try to find env in one of the base classes
+ if env is None:
+ for base in bases:
+ env = getattr(base, 'env', None)
+ if env is not None:
+ classdict['env'] = env
+ break
+ if not classdict.get('__abstract__') and not classdict.get('env'):
+ classdict['env'] = late_binding_env
+ return super(autoenv, mcs).__new__(mcs, name, bases, classdict)
+
+
+class EnvBasedTC(TestCase):
+ """abstract class for test using an apptest environment
+ """
+ __metaclass__ = autoenv
+ __abstract__ = True
+ env = None
+ configcls = ApptestConfiguration
+ requestcls = FakeRequest
+
+ # user / session management ###############################################
+
+ def user(self, req=None):
+ if req is None:
+ req = self.env.create_request()
+ return self.env.cnx.user(req)
+ else:
+ return req.user
+
+ def create_user(self, *args, **kwargs):
+ return self.env.create_user(*args, **kwargs)
+
+ def login(self, login):
+ return self.env.login(login)
+
+ def restore_connection(self):
+ self.env.restore_connection()
+
+ # db api ##################################################################
+
+ @nocoverage
+ def cursor(self, req=None):
+ return self.env.cnx.cursor(req or self.request())
+
+ @nocoverage
+ def execute(self, *args, **kwargs):
+ return self.env.execute(*args, **kwargs)
+
+ @nocoverage
+ def commit(self):
+ self.env.cnx.commit()
+
+ @nocoverage
+ def rollback(self):
+ try:
+ self.env.cnx.rollback()
+ except ProgrammingError:
+ pass
+
+ # other utilities #########################################################
+ def set_debug(self, debugmode):
+ from cubicweb.server import set_debug
+ set_debug(debugmode)
+
+ @property
+ def config(self):
+ return self.vreg.config
+
+ def session(self):
+ """return current server side session (using default manager account)"""
+ return self.env.repo._sessions[self.env.cnx.sessionid]
+
+ def request(self, *args, **kwargs):
+ """return a web interface request"""
+ return self.env.create_request(*args, **kwargs)
+
+ @nocoverage
+ def rset_and_req(self, *args, **kwargs):
+ return self.env.get_rset_and_req(*args, **kwargs)
+
+ def entity(self, rql, args=None, eidkey=None, req=None):
+ return self.execute(rql, args, eidkey, req=req).get_entity(0, 0)
+
+ def etype_instance(self, etype, req=None):
+ req = req or self.request()
+ e = self.env.vreg.etype_class(etype)(req, None, None)
+ e.eid = None
+ return e
+
+ def add_entity(self, etype, **kwargs):
+ rql = ['INSERT %s X' % etype]
+
+ # dict for replacement in RQL Request
+ rql_args = {}
+
+ if kwargs: #
+ rql.append(':')
+ # dict to define new entities variables
+ entities = {}
+
+ # assignement part of the request
+ sub_rql = []
+ for key, value in kwargs.iteritems():
+ # entities
+ if hasattr(value, 'eid'):
+ new_value = "%s__" % key.upper()
+
+ entities[new_value] = value.eid
+ rql_args[new_value] = value.eid
+
+ sub_rql.append("X %s %s" % (key, new_value))
+ # final attributes
+ else:
+ sub_rql.append('X %s %%(%s)s' % (key, key))
+ rql_args[key] = value
+ rql.append(', '.join(sub_rql))
+
+
+ if entities:
+ rql.append('WHERE')
+ # WHERE part of the request (to link entity to they eid)
+ sub_rql = []
+ for key, value in entities.iteritems():
+ sub_rql.append("%s eid %%(%s)s" % (key, key))
+ rql.append(', '.join(sub_rql))
+
+ rql = ' '.join(rql)
+ rset = self.execute(rql, rql_args)
+ return rset.get_entity(0, 0)
+
+ def set_option(self, optname, value):
+ self.vreg.config.global_set_option(optname, value)
+
+ def pviews(self, req, rset):
+ return sorted((a.id, a.__class__) for a in self.vreg.possible_views(req, rset))
+
+ def pactions(self, req, rset, skipcategories=('addrelated', 'siteactions', 'useractions')):
+ return [(a.id, a.__class__) for a in self.vreg.possible_vobjects('actions', req, rset)
+ if a.category not in skipcategories]
+ def pactionsdict(self, req, rset, skipcategories=('addrelated', 'siteactions', 'useractions')):
+ res = {}
+ for a in self.vreg.possible_vobjects('actions', req, rset):
+ if a.category not in skipcategories:
+ res.setdefault(a.category, []).append(a.__class__)
+ return res
+
+ def paddrelactions(self, req, rset):
+ return [(a.id, a.__class__) for a in self.vreg.possible_vobjects('actions', req, rset)
+ if a.category == 'addrelated']
+
+ def remote_call(self, fname, *args):
+ """remote call simulation"""
+ dump = simplejson.dumps
+ args = [dump(arg) for arg in args]
+ req = self.request(mode='remote', fname=fname, pageid='123', arg=args)
+ ctrl = self.env.app.select_controller('json', req)
+ return ctrl.publish(), req
+
+ # default test setup and teardown #########################################
+
+ def setup_database(self):
+ pass
+
+ def setUp(self):
+ self.restore_connection()
+ session = self.session()
+ #self.maxeid = self.execute('Any MAX(X)')
+ session.set_pool()
+ self.maxeid = session.system_sql('SELECT MAX(eid) FROM entities').fetchone()[0]
+ self.app = self.env.app
+ self.vreg = self.env.app.vreg
+ self.schema = self.vreg.schema
+ self.vreg.config.mode = 'test'
+ # set default-dest-addrs to a dumb email address to avoid mailbox or
+ # mail queue pollution
+ self.set_option('default-dest-addrs', ['whatever'])
+ self.setup_database()
+ self.commit()
+ MAILBOX[:] = [] # reset mailbox
+
+ @nocoverage
+ def tearDown(self):
+ self.rollback()
+ # self.env.restore_database()
+ self.env.restore_connection()
+ self.session().unsafe_execute('DELETE Any X WHERE X eid > %s' % self.maxeid)
+ self.commit()
+
+
+# XXX
+try:
+ from cubicweb.web import Redirect
+ from urllib import unquote
+except ImportError:
+ pass # cubicweb-web not installed
+else:
+ class ControllerTC(EnvBasedTC):
+ def setUp(self):
+ super(ControllerTC, self).setUp()
+ self.req = self.request()
+ self.ctrl = self.env.app.select_controller('edit', self.req)
+
+ def publish(self, req):
+ assert req is self.ctrl.req
+ try:
+ result = self.ctrl.publish()
+ req.cnx.commit()
+ except Redirect:
+ req.cnx.commit()
+ raise
+ return result
+
+ def expect_redirect_publish(self, req=None):
+ if req is not None:
+ self.ctrl = self.env.app.select_controller('edit', req)
+ else:
+ req = self.req
+ try:
+ res = self.publish(req)
+ except Redirect, ex:
+ try:
+ path, params = ex.location.split('?', 1)
+ except:
+ path, params = ex.location, ""
+ req._url = path
+ cleanup = lambda p: (p[0], unquote(p[1]))
+ params = dict(cleanup(p.split('=', 1)) for p in params.split('&') if p)
+ return req.relative_path(False), params # path.rsplit('/', 1)[-1], params
+ else:
+ self.fail('expected a Redirect exception')
+
+
+def make_late_binding_repo_property(attrname):
+ @property
+ def late_binding(self):
+ """builds cnx as late as possible"""
+ if not hasattr(self, attrname):
+ # sets explicit test mode here to avoid autoreload
+ from cubicweb.cwconfig import CubicWebConfiguration
+ CubicWebConfiguration.mode = 'test'
+ cls = self.__class__
+ config = self.repo_config or TestServerConfiguration('data')
+ cls._repo, cls._cnx = init_test_database('sqlite', config=config)
+ return getattr(self, attrname)
+ return late_binding
+
+
+class autorepo(type):
+ """automatically set repository on RepositoryBasedTC subclasses if necessary
+ """
+ def __new__(mcs, name, bases, classdict):
+ repo = classdict.get('repo')
+ # try to find repo in one of the base classes
+ if repo is None:
+ for base in bases:
+ repo = getattr(base, 'repo', None)
+ if repo is not None:
+ classdict['repo'] = repo
+ break
+ if name != 'RepositoryBasedTC' and not classdict.get('repo'):
+ classdict['repo'] = make_late_binding_repo_property('_repo')
+ classdict['cnx'] = make_late_binding_repo_property('_cnx')
+ return super(autorepo, mcs).__new__(mcs, name, bases, classdict)
+
+
+class RepositoryBasedTC(TestCase):
+ """abstract class for test using direct repository connections
+ """
+ __metaclass__ = autorepo
+ repo_config = None # set a particular config instance if necessary
+
+ # user / session management ###############################################
+
+ def create_user(self, user, groups=('users',), password=None, commit=True):
+ if password is None:
+ password = user
+ eid = self.execute('INSERT EUser X: X login %(x)s, X upassword %(p)s,'
+ 'X in_state S WHERE S name "activated"',
+ {'x': unicode(user), 'p': password})[0][0]
+ groups = ','.join(repr(group) for group in groups)
+ self.execute('SET X in_group Y WHERE X eid %%(x)s, Y name IN (%s)' % groups,
+ {'x': eid})
+ if commit:
+ self.commit()
+ self.session.reset_pool()
+ return eid
+
+ def login(self, login, password=None):
+ cnx = repo_connect(self.repo, unicode(login), password or login,
+ ConnectionProperties('inmemory'))
+ self.cnxs.append(cnx)
+ return cnx
+
+ def current_session(self):
+ return self.repo._sessions[self.cnxs[-1].sessionid]
+
+ def restore_connection(self):
+ assert len(self.cnxs) == 1, self.cnxs
+ cnx = self.cnxs.pop()
+ try:
+ cnx.close()
+ except Exception, ex:
+ print "exception occured while closing connection", ex
+
+ # db api ##################################################################
+
+ def execute(self, rql, args=None, eid_key=None):
+ assert self.session.id == self.cnxid
+ rset = self.__execute(self.cnxid, rql, args, eid_key)
+ rset.vreg = self.vreg
+ rset.req = self.session
+ # call to set_pool is necessary to avoid pb when using
+ # application entities for convenience
+ self.session.set_pool()
+ return rset
+
+ def commit(self):
+ self.__commit(self.cnxid)
+ self.session.set_pool()
+
+ def rollback(self):
+ self.__rollback(self.cnxid)
+ self.session.set_pool()
+
+ def close(self):
+ self.__close(self.cnxid)
+
+ # other utilities #########################################################
+ def set_debug(self, debugmode):
+ from cubicweb.server import set_debug
+ set_debug(debugmode)
+
+ def set_option(self, optname, value):
+ self.vreg.config.global_set_option(optname, value)
+
+ def add_entity(self, etype, **kwargs):
+ restrictions = ', '.join('X %s %%(%s)s' % (key, key) for key in kwargs)
+ rql = 'INSERT %s X' % etype
+ if kwargs:
+ rql += ': %s' % ', '.join('X %s %%(%s)s' % (key, key) for key in kwargs)
+ rset = self.execute(rql, kwargs)
+ return rset.get_entity(0, 0)
+
+ def default_user_password(self):
+ config = self.repo.config #TestConfiguration('data')
+ user = unicode(config.sources()['system']['db-user'])
+ passwd = config.sources()['system']['db-password']
+ return user, passwd
+
+ def close_connections(self):
+ for cnx in self.cnxs:
+ try:
+ cnx.rollback()
+ cnx.close()
+ except:
+ continue
+ self.cnxs = []
+
+ pactions = EnvBasedTC.pactions.im_func
+ pactionsdict = EnvBasedTC.pactionsdict.im_func
+
+ # default test setup and teardown #########################################
+ copy_schema = False
+
+ def _prepare(self):
+ MAILBOX[:] = [] # reset mailbox
+ if hasattr(self, 'cnxid'):
+ return
+ repo = self.repo
+ self.__execute = repo.execute
+ self.__commit = repo.commit
+ self.__rollback = repo.rollback
+ self.__close = repo.close
+ self.cnxid = repo.connect(*self.default_user_password())
+ self.session = repo._sessions[self.cnxid]
+ # XXX copy schema since hooks may alter it and it may be not fully
+ # cleaned (missing some schema synchronization support)
+ try:
+ origschema = repo.__schema
+ except AttributeError:
+ origschema = repo.schema
+ repo.__schema = origschema
+ if self.copy_schema:
+ repo.schema = deepcopy(origschema)
+ repo.set_schema(repo.schema) # reset hooks
+ repo.vreg.update_schema(repo.schema)
+ self.cnxs = []
+ # reset caches, they may introduce bugs among tests
+ repo._type_source_cache = {}
+ repo._extid_cache = {}
+ repo.querier._rql_cache = {}
+ for source in repo.sources:
+ source.reset_caches()
+ for s in repo.sources:
+ if hasattr(s, '_cache'):
+ s._cache = {}
+
+ @property
+ def config(self):
+ return self.repo.config
+
+ @property
+ def vreg(self):
+ return self.repo.vreg
+
+ @property
+ def schema(self):
+ return self.repo.schema
+
+ def setUp(self):
+ self._prepare()
+ self.session.set_pool()
+ self.maxeid = self.session.system_sql('SELECT MAX(eid) FROM entities').fetchone()[0]
+ #self.maxeid = self.execute('Any MAX(X)')
+
+ def tearDown(self, close=True):
+ self.close_connections()
+ self.rollback()
+ self.session.unsafe_execute('DELETE Any X WHERE X eid > %(x)s', {'x': self.maxeid})
+ self.commit()
+ if close:
+ self.close()
+