devtools/repotest.py
author Aurelien Campeas <aurelien.campeas@logilab.fr>
Tue, 07 Dec 2010 12:18:20 +0100
branch oldstable
changeset 7078 bad26a22fe29
parent 6818 5fa425574548
child 6957 ffda12be2e9f
permissions -rw-r--r--
[test] New Handling of database for test. This patch adds a new TestDataBaseHandler class. TestDataBaseHandler are in charge of Setup, backup, restore, connection, repository caching and cleanup for database used during the test. TestDataBaseHandler reuse code and logic previously found in cubicweb.devtools functions and devtools.testlib.CubicwebTC. TestDataBaseHandler is an abstract class and must be subclassed to implement functionalities specific to each driver. TestDataBaseHandler can store and restore various database setups. devtools.testlib.CubicwebTC gains a test_db_id class attribute to specify that its TestCase uses a specific database that should be cached. The pre_setup_database class method is used to setup the database that will be cached. The setup_database method is kept uncached. The same TestDataBaseHandler are reused for every test using the same config object. TestDataBaseHandler try to reuse Repository objects as much as possible. All cubicweb test have been updated.

# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""some utilities to ease repository testing

This module contains functions to initialize a new repository.
"""

__docformat__ = "restructuredtext en"

from copy import deepcopy
from pprint import pprint

from logilab.common.decorators import clear_cache
from logilab.common.testlib import SkipTest

def tuplify(list):
    """Turn, in place, every item of `list` whose type is not exactly `tuple`
    into a tuple, then return that same list object.
    """
    for index, item in enumerate(list):
        if type(item) is not tuple:
            list[index] = tuple(item)
    return list

def snippet_cmp(a, b):
    """Three-way comparison of two `(key, snippets)` pairs.

    Each snippet is reduced to its `expression` attribute before comparing,
    so the ordering depends only on the first item of the pair and on the
    list of snippet expressions.

    Return -1, 0 or 1 when `a` is respectively lower than, equal to or
    greater than `b`.
    """
    a = (a[0], [e.expression for e in a[1]])
    b = (b[0], [e.expression for e in b[1]])
    # same result as the `cmp(a, b)` builtin, without depending on a builtin
    # that only exists in Python 2
    return (a > b) - (a < b)

def test_plan(self, rql, expected, kwargs=None):
    """Build the execution plan for `rql` and check it against `expected`.

    The plan is obtained from `self._prepare_plan` and built by
    `self.planner`; its steps are then compared, in order, to the expected
    step representations. When a comparison fails, the representation of
    every step of the plan is pretty-printed before the assertion error is
    propagated.
    """
    plan = self._prepare_plan(rql, kwargs)
    self.planner.build_plan(plan)
    try:
        nbexpected, nbsteps = len(expected), len(plan.steps)
        self.assertEqual(nbsteps, nbexpected,
                         'expected %s steps, got %s' % (nbexpected, nbsteps))
        # step order is important: compare steps pairwise, in plan order
        for step, expectedstep in zip(plan.steps, expected):
            compare_steps(self, step.test_repr(), expectedstep)
    except AssertionError:
        pprint([step.test_repr() for step in plan.steps])
        raise

def compare_steps(self, step, expected):
    """Recursively compare a step representation to the expected one.

    `step` and `expected` are sequences laid out as follows:

    * item 0 is the step type,
    * item 1 may be a list of `(rql, solutions)` pairs; it is only treated
      that way when the step has more than two items and item 1 is a list on
      both sides,
    * the last item is the sequence of child step representations,
    * whatever stands in between is compared as is.

    `self` is the running test case, used for its `assertEqual` method. On
    failure the step (without its children) is printed and the
    AssertionError is re-raised.
    """
    try:
        self.assertEqual(step[0], expected[0], 'expected step type %s, got %s' % (expected[0], step[0]))
        if len(step) > 2 and isinstance(step[1], list) and isinstance(expected[1], list):
            queries, equeries = step[1], expected[1]
            self.assertEqual(len(queries), len(equeries),
                              'expected %s queries, got %s' % (len(equeries), len(queries)))
            for i, (rql, sol) in enumerate(queries):
                self.assertEqual(rql, equeries[i][0])
                # solutions are compared regardless of their order
                self.assertEqual(sorted(sol), sorted(equeries[i][1]))
            # queries have been checked, generic comparison starts after them
            idx = 2
        else:
            idx = 1
        self.assertEqual(step[idx:-1], expected[idx:-1],
                          'expected step characteristic \n%s\n, got\n%s' % (expected[1:-1], step[1:-1]))
        self.assertEqual(len(step[-1]), len(expected[-1]),
                          'got %s child steps, expected %s' % (len(step[-1]), len(expected[-1])))
    except AssertionError:
        # show which step failed before propagating the error
        print 'error on step ',
        pprint(step[:-1])
        raise
    children = step[-1]
    if step[0] in ('UnionFetchStep', 'UnionStep'):
        # children of union steps are sorted on both sides, so their order in
        # the plan does not matter
        children = sorted(children)
        expectedchildren = sorted(expected[-1])
    else:
        expectedchildren = expected[-1]
    # both sequences have the same length at this point (asserted above)
    for i, substep in enumerate(children):
        compare_steps(self, substep, expectedchildren[i])


class DumbOrderedDict(list):
    """List of `(key, value)` pairs exposing a small read-only, dict-like
    interface which preserves the order of the pairs.

    Lookups scan the pairs sequentially and return the value of the first
    pair whose key matches.
    """

    def __iter__(self):
        """Iterate on keys, as a dictionary does."""
        return self.iterkeys()

    def __contains__(self, key):
        """Tell whether `key` is the key of one of the pairs."""
        return key in self.iterkeys()

    def __getitem__(self, key):
        """Return the value of the first pair whose key equals `key`, raise
        KeyError when there is none.
        """
        for candidate, value in list.__iter__(self):
            if key == candidate:
                return value
        raise KeyError(key)

    def iterkeys(self):
        """Return an iterator on keys, in pair order."""
        return (pair_key for pair_key, _ in list.__iter__(self))

    def iteritems(self):
        """Return an iterator on the `(key, value)` pairs."""
        return list.__iter__(self)

    def items(self):
        """Return a new list holding the `(key, value)` pairs."""
        return list(list.__iter__(self))

class DumbOrderedDict2(object):
    """Proxy on a dictionary whose iteration yields keys sorted using a key
    function; any other attribute access is forwarded to the wrapped
    dictionary.
    """

    def __init__(self, origdict, sortkey):
        """`origdict` is the wrapped dictionary, `sortkey` the function given
        as `key` argument to `sorted` when iterating.
        """
        self.origdict = origdict
        self.sortkey = sortkey

    def __getattr__(self, attr):
        # only reached for attributes which are not found on the proxy itself
        return getattr(self.origdict, attr)

    def __iter__(self):
        orderedkeys = sorted(self.origdict, key=self.sortkey)
        return iter(orderedkeys)

def schema_eids_idx(schema):
    """Return a dictionary mapping schema types to their eids, so that eids
    can be reread from the fs instead of the db (too costly) between tests.

    Keys are the entity and relation schema objects themselves, plus one
    `(subject, rtype, object)` tuple per relation definition.
    """
    index = dict((eschema, eschema.eid) for eschema in schema.entities())
    for rschema in schema.relations():
        index[rschema] = rschema.eid
        index.update(((rdef.subject, rdef.rtype, rdef.object), rdef.eid)
                     for rdef in rschema.rdefs.itervalues())
    return index

def restore_schema_eids_idx(schema, schema_eids):
    """Rebuild the schema eid index.

    `schema_eids` is a mapping with the layout built by `schema_eids_idx`:
    every entity schema, relation schema and relation definition gets its
    `eid` attribute set from it and is registered in `schema._eid_index`.
    """
    def register(item, key):
        item.eid = schema_eids[key]
        schema._eid_index[item.eid] = item

    for eschema in schema.entities():
        register(eschema, eschema)
    for rschema in schema.relations():
        register(rschema, rschema)
        for rdef in rschema.rdefs.itervalues():
            register(rdef, (rdef.subject, rdef.rtype, rdef.object))


from logilab.common.testlib import TestCase, mock_object
from logilab.database import get_db_helper

from rql import RQLHelper

from cubicweb.devtools.fake import FakeRepo, FakeSession
from cubicweb.server import set_debug, debugged
from cubicweb.server.querier import QuerierHelper
from cubicweb.server.session import Session
from cubicweb.server.sources.rql2sql import SQLGenerator, remove_unused_solutions

class RQLGeneratorTC(TestCase):
    schema = backend = None # set this in concret test


    @classmethod
    def setUpClass(cls):
        if cls.backend is not None:
            try:
                cls.dbhelper = get_db_helper(cls.backend)
            except ImportError, ex:
                raise SkipTest(str(ex))

    def setUp(self):
        self.repo = FakeRepo(self.schema)
        self.repo.system_source = mock_object(dbdriver=self.backend)
        self.rqlhelper = RQLHelper(self.schema, special_relations={'eid': 'uid',
                                                                   'has_text': 'fti'},
                                   backend=self.backend)
        self.qhelper = QuerierHelper(self.repo, self.schema)
        ExecutionPlan._check_permissions = _dummy_check_permissions
        rqlannotation._select_principal = _select_principal
        if self.backend is not None:
            self.o = SQLGenerator(self.schema, self.dbhelper)

    def tearDown(self):
        ExecutionPlan._check_permissions = _orig_check_permissions
        rqlannotation._select_principal = _orig_select_principal

    def set_debug(self, debug):
        set_debug(debug)
    def debugged(self, debug):
        return debugged(debug)

    def _prepare(self, rql):
        #print '******************** prepare', rql
        union = self.rqlhelper.parse(rql)
        #print '********* parsed', union.as_string()
        self.rqlhelper.compute_solutions(union)
        #print '********* solutions', solutions
        self.rqlhelper.simplify(union)
        #print '********* simplified', union.as_string()
        plan = self.qhelper.plan_factory(union, {}, FakeSession(self.repo))
        plan.preprocess(union)
        for select in union.children:
            select.solutions.sort()
        #print '********* ppsolutions', solutions
        return union


class BaseQuerierTC(TestCase):
    """Base class for tests running queries through the querier of an
    existing repository.

    Concrete classes must set the `repo` class attribute.
    """
    repo = None # set this in concrete test

    def setUp(self):
        """Bind querier, session and pool taken from `self.repo`, record the
        current max eid and install the monkey patches.
        """
        self.o = self.repo.querier
        # work with the first session known by the repository
        self.session = self.repo._sessions.values()[0]
        self.ueid = self.session.user.eid
        assert self.ueid != -1
        self.repo._type_source_cache = {} # clear cache
        self.pool = self.session.set_pool()
        # remembered so `cleanup` can delete whatever the test created
        self.maxeid = self.get_max_eid()
        do_monkey_patch()
        # sessions created by `user_groups_session`, closed in `tearDown`
        self._dumb_sessions = []

    def get_max_eid(self):
        """Return the first cell of the result of the 'Any MAX(X)' query."""
        return self.session.execute('Any MAX(X)')[0][0]
    def cleanup(self):
        """Delete every entity whose eid is greater than `self.maxeid`."""
        self.session.set_pool()
        self.session.execute('DELETE Any X WHERE X eid > %s' % self.maxeid)

    def tearDown(self):
        """Remove the monkey patches, roll back the session, delete entities
        created since `setUp`, close the sessions created by
        `user_groups_session` and free the pool.
        """
        undo_monkey_patch()
        self.session.rollback()
        self.cleanup()
        self.commit()
        # properly close dumb sessions
        for session in self._dumb_sessions:
            session.rollback()
            session.close()
        self.repo._free_pool(self.pool)
        assert self.session.user.eid != -1

    def set_debug(self, debug):
        """Shortcut to `cubicweb.server.set_debug`."""
        set_debug(debug)
    def debugged(self, debug):
        """Shortcut to `cubicweb.server.debugged`."""
        return debugged(debug)

    def _rqlhelper(self):
        """Return the registry's rql helper with its analyser uid function
        and uid function mapping reset.
        """
        rqlhelper = self.repo.vreg.rqlhelper
        # reset uid_func so it doesn't try to get type from eids
        rqlhelper._analyser.uid_func = None
        rqlhelper._analyser.uid_func_mapping = {}
        return rqlhelper

    def _prepare_plan(self, rql, kwargs=None, simplify=True):
        """Return the plan created by the querier's plan factory for `rql`.

        The query is parsed and its solutions computed with `kwargs`; the
        syntax tree is simplified unless `simplify` is false, and the
        solutions of each select are sorted.
        """
        rqlhelper = self._rqlhelper()
        rqlst = rqlhelper.parse(rql)
        rqlhelper.compute_solutions(rqlst, kwargs=kwargs)
        if simplify:
            rqlhelper.simplify(rqlst)
        for select in rqlst.children:
            select.solutions.sort()
        return self.o.plan_factory(rqlst, kwargs, self.session)

    def _prepare(self, rql, kwargs=None):
        """Return the first select of `rql`'s preprocessed (not simplified)
        syntax tree, with its solutions filtered by `remove_unused_solutions`.
        """
        plan = self._prepare_plan(rql, kwargs, simplify=False)
        plan.preprocess(plan.rqlst)
        rqlst = plan.rqlst.children[0]
        # only the first item returned by remove_unused_solutions is kept
        rqlst.solutions = remove_unused_solutions(rqlst, rqlst.solutions, {}, self.repo.schema)[0]
        return rqlst

    def user_groups_session(self, *groups):
        """lightweight session using the current user with hi-jacked groups

        The returned session shares the pool of the test and is recorded in
        `self._dumb_sessions`.
        """
        # use self.session.user.eid to get correct owned_by relation, unless explicit eid
        u = self.repo._build_user(self.session, self.session.user.eid)
        u._groups = set(groups)
        s = Session(u, self.repo)
        s._threaddata.pool = self.pool
        # register session to ensure it gets closed
        self._dumb_sessions.append(s)
        return s

    def execute(self, rql, args=None, build_descr=True):
        """Execute `rql` through the querier, using the test's session."""
        return self.o.execute(self.session, rql, args, build_descr)

    def commit(self):
        """Commit the test's session, then call its `set_pool` again."""
        self.session.commit()
        self.session.set_pool()


class BasePlannerTC(BaseQuerierTC):
    """Base class for planner tests, able to register extra sources on the
    repository for the duration of a test.

    Initialization is done by `setup` (lower case), which is not the
    unittest `setUp` hook: nothing in this class calls it.
    """
    # number of sources registered through `add_source`, to be removed by
    # `tearDown`
    newsources = 0

    def setup(self):
        """Clear the repository's source related caches, bind querier,
        session, pool, schema and sources, and install the monkey patches.
        """
        clear_cache(self.repo, 'rel_type_sources')
        clear_cache(self.repo, 'can_cross_relation')
        clear_cache(self.repo, 'is_multi_sources_relation')
        # XXX source_defs
        self.o = self.repo.querier
        self.session = self.repo._sessions.values()[0]
        self.pool = self.session.set_pool()
        self.schema = self.o.schema
        self.sources = self.o._repo.sources
        # the last source of the list is kept as `self.system`
        self.system = self.sources[-1]
        do_monkey_patch()
        self._dumb_sessions = [] # by hi-jacked parent setup
        self.repo.vreg.rqlhelper.backend = 'postgres' # so FTIRANK is considered

    def add_source(self, sourcecls, uri):
        """Instantiate `sourcecls` with the given `uri`, register it on the
        repository and make it available as the attribute named `uri` on the
        test case.
        """
        self.sources.append(sourcecls(self.repo, {'uri': uri}))
        self.repo.sources_by_uri[uri] = self.sources[-1]
        setattr(self, uri, self.sources[-1])
        self.newsources += 1

    def tearDown(self):
        """Unregister the sources added by `add_source`, remove the monkey
        patches and close the sessions created by `user_groups_session`.
        """
        # added sources were appended, so they are the last `newsources` ones
        while self.newsources:
            source = self.sources.pop(-1)
            del self.repo.sources_by_uri[source.uri]
            self.newsources -= 1
        undo_monkey_patch()
        for session in self._dumb_sessions:
            # detach the pool shared with the test's session before closing
            session._threaddata.pool = None
            session.close()

    def _prepare_plan(self, rql, kwargs=None, simplify=True):
        """Return the plan created by the querier's plan factory for `rql`.

        The query is parsed and annotated by the querier, its solutions are
        computed with `kwargs` and sorted.

        `simplify` is only accepted so the signature matches the overridden
        `BaseQuerierTC._prepare_plan` (the inherited `_prepare` passes it as
        keyword argument); this method does not call the rql helper's
        `simplify` itself, hence the flag has no effect here.
        """
        rqlst = self.o.parse(rql, annotate=True)
        self.o.solutions(self.session, rqlst, kwargs)
        if rqlst.TYPE == 'select':
            self.repo.vreg.rqlhelper.annotate(rqlst)
            for select in rqlst.children:
                select.solutions.sort()
        else:
            rqlst.solutions.sort()
        return self.o.plan_factory(rqlst, kwargs, self.session)


# monkey patch some methods to get predictable results ########################

from cubicweb.rqlrewrite import RQLRewriter
# references to the genuine methods: they are called by the replacements
# below and put back by undo_monkey_patch()
_orig_insert_snippets = RQLRewriter.insert_snippets
_orig_build_variantes = RQLRewriter.build_variantes

def _insert_snippets(self, snippets, varexistsmap=None):
    """Call the genuine `insert_snippets` with snippets sorted according to
    `snippet_cmp`.
    """
    orderedsnippets = sorted(snippets, cmp=snippet_cmp)
    _orig_insert_snippets(self, orderedsnippets, varexistsmap)

def _build_variantes(self, newsolutions):
    """Return the variantes computed by the genuine `build_variantes`, each
    one turned into a `DumbOrderedDict` and the whole list sorted.

    The items of a variante are ordered on the `(key[1], key[2], value)`
    triple; variantes themselves are ordered on the sorted list of those
    triples.
    """
    def sortkey(item):
        key, value = item
        return (key[1], key[2], value)

    decorated = []
    for variante in _orig_build_variantes(self, newsolutions):
        ordereditems = sorted(variante.iteritems(), key=sortkey)
        orderedkeys = [sortkey(item) for item in ordereditems]
        decorated.append((orderedkeys, DumbOrderedDict(ordereditems)))
    decorated.sort()
    return [variante for _, variante in decorated]

from cubicweb.server.querier import ExecutionPlan
# references to the genuine methods: they are called by the replacements
# below and put back by undo_monkey_patch()
_orig_check_permissions = ExecutionPlan._check_permissions
_orig_init_temp_table = ExecutionPlan.init_temp_table

def _check_permissions(*args, **kwargs):
    """Call the genuine `_check_permissions` and return its result with the
    first item turned into a `DumbOrderedDict` whose pairs are sorted by
    value.
    """
    res, restricted = _orig_check_permissions(*args, **kwargs)
    ordereditems = sorted(res.iteritems(), key=lambda item: item[1])
    return DumbOrderedDict(ordereditems), restricted

def _dummy_check_permissions(self, rqlst):
    """Stand-in for `ExecutionPlan._check_permissions` performing no check:
    return the solutions of `rqlst` under the empty tuple key, along with an
    empty set.
    """
    solutions_by_key = {(): rqlst.solutions}
    return solutions_by_key, set()

def _init_temp_table(self, table, selection, solution):
    """Record a sequential 'table<N>' name for `table` in
    `self.tablesinorder` the first time it is seen, then call the genuine
    `init_temp_table`.
    """
    if self.tablesinorder is None:
        self.tablesinorder = {}
    tablesinorder = self.tablesinorder
    if table not in tablesinorder:
        tablesinorder[table] = 'table%s' % len(tablesinorder)
    return _orig_init_temp_table(self, table, selection, solution)

from cubicweb.server import rqlannotation
# reference to the genuine function, called by the replacement below
_orig_select_principal = rqlannotation._select_principal

def _select_principal(scope, relations):
    """Call the genuine `_select_principal` with a `_sort` function ordering
    relations by their `r_type`.
    """
    def sort_by_rtype(rels):
        return sorted(rels, key=lambda rel: rel.r_type)
    return _orig_select_principal(scope, relations, _sort=sort_by_rtype)

try:
    from cubicweb.server.msplanner import PartPlanInformation
except ImportError:
    class PartPlanInformation(object):
        """Do-nothing stand-in, defined when `cubicweb.server.msplanner`
        cannot be imported, so the monkey patching code of this module still
        has methods to save and replace.
        """
        def merge_input_maps(self, *args, **kwargs):
            pass
        def _choose_term(self, *args, **kwargs):
            # accept any arguments, like `merge_input_maps`: the replacement
            # installed by this module calls the original method with
            # `(source, sourceterms)`
            pass
# references to the genuine methods: they are called by the replacements
# below and put back by undo_monkey_patch()
_orig_merge_input_maps = PartPlanInformation.merge_input_maps
_orig_choose_term = PartPlanInformation._choose_term

def _merge_input_maps(*args, **kwargs):
    """Return what the genuine `merge_input_maps` returns, as a sorted
    list.
    """
    merged = _orig_merge_input_maps(*args, **kwargs)
    return sorted(merged)

def _choose_term(self, source, sourceterms):
    """Call the genuine `_choose_term` with `sourceterms` wrapped so that it
    is iterated in a predictable order, for test purpose.

    Terms are ordered on the first attribute found among `name` (variable),
    `r_type` (relation) and `value` (constant).
    """
    def get_key(term):
        # variable first, then relation
        for attrname in ('name', 'r_type'):
            try:
                return getattr(term, attrname)
            except AttributeError:
                continue
        # const
        return term.value
    orderedterms = DumbOrderedDict2(sourceterms, get_key)
    return _orig_choose_term(self, source, orderedterms)

from cubicweb.server.sources.pyrorql import PyroRQLSource
# reference to the genuine method: it is called by the replacement below and
# put back by undo_monkey_patch()
_orig_syntax_tree_search = PyroRQLSource.syntax_tree_search

def _syntax_tree_search(*args, **kwargs):
    """Return a deep copy of what the genuine `syntax_tree_search`
    returns.
    """
    results = _orig_syntax_tree_search(*args, **kwargs)
    return deepcopy(results)

def do_monkey_patch():
    """Install the replacements defined in this module on `RQLRewriter`,
    `ExecutionPlan`, `PartPlanInformation` and `PyroRQLSource`, and reset
    `ExecutionPlan.tablesinorder` to None.
    """
    patches = (
        (RQLRewriter, 'insert_snippets', _insert_snippets),
        (RQLRewriter, 'build_variantes', _build_variantes),
        (ExecutionPlan, '_check_permissions', _check_permissions),
        (ExecutionPlan, 'tablesinorder', None),
        (ExecutionPlan, 'init_temp_table', _init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _merge_input_maps),
        (PartPlanInformation, '_choose_term', _choose_term),
        (PyroRQLSource, 'syntax_tree_search', _syntax_tree_search),
    )
    for target, attrname, replacement in patches:
        setattr(target, attrname, replacement)

def undo_monkey_patch():
    """Put back the genuine methods replaced by `do_monkey_patch`."""
    originals = (
        (RQLRewriter, 'insert_snippets', _orig_insert_snippets),
        (RQLRewriter, 'build_variantes', _orig_build_variantes),
        (ExecutionPlan, '_check_permissions', _orig_check_permissions),
        (ExecutionPlan, 'init_temp_table', _orig_init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _orig_merge_input_maps),
        (PartPlanInformation, '_choose_term', _orig_choose_term),
        (PyroRQLSource, 'syntax_tree_search', _orig_syntax_tree_search),
    )
    for target, attrname, original in originals:
        setattr(target, attrname, original)