devtools/repotest.py
author sylvain.thenault@logilab.fr
Thu, 29 Jan 2009 16:26:33 +0100
changeset 519 06390418cd9a
parent 0 b97547f5f1fa
child 599 9ef680acd92a
permissions -rw-r--r--
pyrorql source now ignores external eids that themselves come from another external source already in use by the repository (they should have the same uri)

"""some utilities to ease repository testing

This module contains functions to initialize a new repository.

:organization: Logilab
:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
__docformat__ = "restructuredtext en"

import sys
from pprint import pprint

def tuplify(list):
    """Convert, in place, every item of *list* that is not exactly a tuple.

    Items whose type is exactly ``tuple`` are kept as is; anything else
    (including tuple subclasses) is replaced by ``tuple(item)``.  The very
    same list object is returned.
    """
    for index, item in enumerate(list):
        if type(item) is not tuple:
            list[index] = tuple(item)
    return list

def snippet_cmp(a, b):
    """Three-way comparison function for sorting RQL rewriting snippets.

    Snippets are compared on their first item, then on the list of the
    ``expression`` attributes of the objects in their second item.

    Return -1, 0 or 1 like the Python 2 ``cmp`` builtin.  The result is
    computed with rich comparisons instead of ``cmp`` (removed in Python 3),
    so the function also works with ``functools.cmp_to_key``.
    """
    akey = (a[0], [e.expression for e in a[1]])
    bkey = (b[0], [e.expression for e in b[1]])
    return (akey > bkey) - (akey < bkey)

def test_plan(self, rql, expected, kwargs=None):
    """Build the execution plan of *rql* and check it against *expected*.

    Intended to be used as a test-case method: *self* must provide
    ``_prepare_plan``, ``planner`` and ``assertEquals``.  *expected* is the
    ordered list of expected step representations.  When a check fails,
    the actual steps are pretty-printed and the error is re-raised.
    """
    plan = self._prepare_plan(rql, kwargs)
    self.planner.build_plan(plan)
    try:
        nexpected, nactual = len(expected), len(plan.steps)
        self.assertEquals(nactual, nexpected,
                          'expected %s steps, got %s' % (nexpected, nactual))
        # steps are compared pairwise, so their order matters
        for step, expected_step in zip(plan.steps, expected):
            compare_steps(self, step.test_repr(), expected_step)
    except AssertionError:
        pprint([step.test_repr() for step in plan.steps])
        raise

def compare_steps(self, step, expected):
    """Recursively assert that the step representation *step* matches *expected*.

    Both arguments are sequences as produced by ``test_repr()``: the step
    type comes first and the last item is the list of child steps.  When
    *step* has more than two items and the second item of both is a list,
    that item is compared as a list of ``(rql, solutions)`` queries.  The
    remaining items are compared as a whole, then the number of children.
    Children of ``UnionFetchStep`` / ``UnionStep`` are sorted before being
    compared; other children are compared in order.

    :raise AssertionError: on the first mismatch, after the offending step
      (without its children) has been printed.
    """
    try:
        self.assertEquals(step[0], expected[0], 'expected step type %s, got %s' % (expected[0], step[0]))
        if len(step) > 2 and isinstance(step[1], list) and isinstance(expected[1], list):
            queries, equeries = step[1], expected[1]
            self.assertEquals(len(queries), len(equeries),
                              'expected %s queries, got %s' % (len(equeries), len(queries)))
            for i, (rql, sol) in enumerate(queries):
                self.assertEquals(rql, equeries[i][0])
                self.assertEquals(sol, equeries[i][1])
            idx = 2
        else:
            idx = 1
        # report exactly the slices that are compared (queries are excluded
        # when they were already checked above)
        self.assertEquals(step[idx:-1], expected[idx:-1],
                          'expected step characteristic \n%s\n, got\n%s'
                          % (expected[idx:-1], step[idx:-1]))
        self.assertEquals(len(step[-1]), len(expected[-1]),
                          'got %s child steps, expected %s' % (len(step[-1]), len(expected[-1])))
    except AssertionError:
        # sys.stdout.write instead of a print statement: valid on Python 2 and 3
        sys.stdout.write('error on step ')
        pprint(step[:-1])
        raise
    children = step[-1]
    if step[0] in ('UnionFetchStep', 'UnionStep'):
        # sort children
        children = sorted(children)
        expectedchildren = sorted(expected[-1])
    else:
        expectedchildren = expected[-1]
    for i, substep in enumerate(children):
        compare_steps(self, substep, expectedchildren[i])


class DumbOrderedDict(list):
    """Minimal ordered mapping stored as a list of ``(key, value)`` pairs.

    Iterating and ``in`` work on keys; ``d[key]`` returns the value of the
    first pair whose key equals *key* and raises ``KeyError`` otherwise.
    ``iterkeys`` / ``iteritems`` mimic the Python 2 dict methods.  Lookups
    are linear scans and duplicate keys are allowed.  Since ``__getitem__``
    is overridden, integer subscripts are key lookups, not list indexes.
    """

    def iterkeys(self):
        for key, _value in list.__iter__(self):
            yield key

    def iteritems(self):
        for pair in list.__iter__(self):
            yield pair

    def __iter__(self):
        return self.iterkeys()

    def __contains__(self, key):
        return key in self.iterkeys()

    def __getitem__(self, key):
        for candidate, value in self.iteritems():
            if key == candidate:
                return value
        raise KeyError(key)


from logilab.common.testlib import TestCase
from rql import RQLHelper
from cubicweb.devtools.fake import FakeRepo, FakeSession
from cubicweb.server import set_debug
from cubicweb.server.querier import QuerierHelper
from cubicweb.server.session import Session
from cubicweb.server.sources.rql2sql import remove_unused_solutions

class RQLGeneratorTC(TestCase):
    """Base test case preparing RQL syntax trees against a given schema.

    Subclasses must set the ``schema`` class attribute.  ``setUp`` installs
    ``_dummy_check_permissions`` on ``ExecutionPlan`` and this module's
    ``_select_principal`` on ``rqlannotation``; ``tearDown`` puts the saved
    originals back.
    """
    schema = None # set this in concrete test

    def setUp(self):
        specials = {'eid': 'uid', 'has_text': 'fti'}
        self.rqlhelper = RQLHelper(self.schema, special_relations=specials)
        self.qhelper = QuerierHelper(FakeRepo(self.schema), self.schema)
        # permission check replaced by a stub returning every solution,
        # principal selection made independent of relation order
        ExecutionPlan._check_permissions = _dummy_check_permissions
        rqlannotation._select_principal = _select_principal

    def tearDown(self):
        ExecutionPlan._check_permissions = _orig_check_permissions
        rqlannotation._select_principal = _orig_select_principal

    def _prepare(self, rql):
        """Parse *rql*, compute its solutions, simplify and preprocess it.

        The solutions of each select are sorted before the union is returned.
        """
        helper = self.rqlhelper
        union = helper.parse(rql)
        helper.compute_solutions(union)
        helper.simplify(union)
        plan = self.qhelper.plan_factory(union, {}, FakeSession())
        plan.preprocess(union)
        for select in union.children:
            select.solutions.sort()
        return union


class BaseQuerierTC(TestCase):
    """Base test case for querier tests run against an existing repository.

    Subclasses must set the ``repo`` class attribute.  ``setUp`` picks a
    session from ``repo._sessions``, sets a pool on it, records the current
    maximum eid and installs this module's monkey patches.  ``tearDown``
    undoes the patches, deletes entities created since ``setUp`` and gives
    the pool back.
    """
    repo = None # set this in concrete test

    def setUp(self):
        self.o = self.repo.querier
        self.session = self.repo._sessions.values()[0]
        self.ueid = self.session.user.eid
        assert self.ueid != -1
        self.repo._type_source_cache = {} # clear cache
        self.pool = self.session.set_pool()
        self.maxeid = self.get_max_eid()
        do_monkey_patch()

    def get_max_eid(self):
        """Return the result of ``Any MAX(X)``, i.e. the greatest eid."""
        return self.session.unsafe_execute('Any MAX(X)')[0][0]

    def cleanup(self):
        """Delete every entity whose eid is greater than ``self.maxeid``."""
        self.session.unsafe_execute('DELETE Any X WHERE X eid > %s' % self.maxeid)

    def tearDown(self):
        undo_monkey_patch()
        try:
            self.session.rollback()
            self.cleanup()
            self.commit()
        finally:
            # release the pool taken in setUp even if rollback, cleanup or
            # commit raised, so a failing test does not keep it checked out
            self.repo._free_pool(self.pool)
        assert self.session.user.eid != -1

    def set_debug(self, debug):
        """Forward *debug* to ``cubicweb.server.set_debug``."""
        set_debug(debug)

    def _rqlhelper(self):
        """Return the querier's RQL helper with eid-based typing disabled."""
        rqlhelper = self.o._rqlhelper
        # reset uid_func so it doesn't try to get type from eids
        rqlhelper._analyser.uid_func = None
        rqlhelper._analyser.uid_func_mapping = {}
        return rqlhelper

    def _prepare_plan(self, rql, kwargs=None):
        """Parse *rql*, compute its solutions, simplify it, sort the solutions
        of each select and return an execution plan for it."""
        rqlhelper = self._rqlhelper()
        rqlst = rqlhelper.parse(rql)
        rqlhelper.compute_solutions(rqlst, kwargs=kwargs)
        rqlhelper.simplify(rqlst)
        for select in rqlst.children:
            select.solutions.sort()
        return self.o.plan_factory(rqlst, kwargs, self.session)

    def _prepare(self, rql, kwargs=None):
        """Return the first select of the preprocessed plan tree for *rql*,
        its solutions filtered through ``remove_unused_solutions``."""
        plan = self._prepare_plan(rql, kwargs)
        plan.preprocess(plan.rqlst)
        rqlst = plan.rqlst.children[0]
        rqlst.solutions = remove_unused_solutions(rqlst, rqlst.solutions, {}, self.repo.schema)[0]
        return rqlst

    def _user_session(self, groups=('guests',), ueid=None):
        """Return ``(user, session)`` for a user in *groups*; the new session
        uses this test's pool."""
        # use self.session.user.eid to get correct owned_by relation, unless explicit eid
        if ueid is None:
            ueid = self.session.user.eid
        u = self.repo._build_user(self.session, ueid)
        u._groups = set(groups)
        s = Session(u, self.repo)
        s._threaddata.pool = self.pool
        return u, s

    def execute(self, rql, args=None, eid_key=None, build_descr=True):
        """Execute *rql* through the querier in the test session."""
        return self.o.execute(self.session, rql, args, eid_key, build_descr)

    def commit(self):
        """Commit the test session, then set a pool on it again."""
        self.session.commit()
        self.session.set_pool()


class BasePlannerTC(BaseQuerierTC):
    """Querier test case building plans from annotated RQL trees."""

    def _prepare_plan(self, rql, kwargs=None):
        """Parse *rql* with annotation, compute its solutions and return a plan.

        For select queries the tree is annotated again and the solutions of
        each select are sorted; otherwise the query's own solutions are sorted.
        """
        querier = self.o
        rqlst = querier.parse(rql, annotate=True)
        querier.solutions(self.session, rqlst, kwargs)
        if rqlst.TYPE != 'select':
            rqlst.solutions.sort()
        else:
            querier._rqlhelper.annotate(rqlst)
            for select in rqlst.children:
                select.solutions.sort()
        return querier.plan_factory(rqlst, kwargs, self.session)


# monkey patch some methods to get predictable results #######################

from cubicweb.server.rqlrewrite import RQLRewriter
# originals, used by the wrappers below and restored by undo_monkey_patch
_orig_insert_snippets, _orig_build_variantes = (RQLRewriter.insert_snippets,
                                                RQLRewriter.build_variantes)

def _insert_snippets(self, snippets, varexistsmap=None):
    """Call the original ``insert_snippets`` with snippets sorted by snippet_cmp."""
    ordered = list(snippets)
    ordered.sort(snippet_cmp)
    _orig_insert_snippets(self, ordered, varexistsmap)

def _build_variantes(self, newsolutions):
    """Return the original ``build_variantes`` result in a fixed order.

    Each variante's items are sorted on ``(key[1], key[2], value)`` and
    wrapped in a ``DumbOrderedDict``.  The variantes are then sorted on the
    list of those sort tuples, ties being broken by the ordered items.
    """
    def itemkey(item):
        key, value = item
        return (key[1], key[2], value)
    decorated = []
    for variante in _orig_build_variantes(self, newsolutions):
        items = sorted(variante.iteritems(), key=itemkey)
        decorated.append(([itemkey(item) for item in items],
                          DumbOrderedDict(items)))
    decorated.sort()
    return [variante for _sortkeys, variante in decorated]

from cubicweb.server.querier import ExecutionPlan
# originals, used by the wrappers below and restored by undo_monkey_patch
_orig_check_permissions, _orig_init_temp_table = (ExecutionPlan._check_permissions,
                                                  ExecutionPlan.init_temp_table)

def _check_permissions(*args, **kwargs):
    """Call the original ``_check_permissions`` and order its first result.

    The returned mapping becomes a ``DumbOrderedDict`` whose items are
    sorted on their value; the second returned item is passed through.
    """
    res, restricted = _orig_check_permissions(*args, **kwargs)
    ordered = sorted(res.iteritems(), key=lambda item: item[1])
    return DumbOrderedDict(ordered), restricted

def _dummy_check_permissions(self, rqlst):
    """Stand-in for ``ExecutionPlan._check_permissions`` that checks nothing.

    Return a mapping from the empty tuple to all of *rqlst*'s solutions,
    together with an empty set.
    """
    allsolutions = {(): rqlst.solutions}
    return allsolutions, set()

def _init_temp_table(self, table, selection, solution):
    """Record *table* in ``self.tablesinorder``, then call the original.

    ``tablesinorder`` is created on first use and maps each table to
    ``'table<N>'``, N being the number of tables recorded before it.
    """
    tablesinorder = self.tablesinorder
    if tablesinorder is None:
        tablesinorder = self.tablesinorder = {}
    if table not in tablesinorder:
        tablesinorder[table] = 'table%s' % len(tablesinorder)
    return _orig_init_temp_table(self, table, selection, solution)

from cubicweb.server import rqlannotation
# original, used by the wrapper below and restored by RQLGeneratorTC.tearDown
_orig_select_principal = rqlannotation._select_principal

def _select_principal(scope, relations):
    """Call the original ``_select_principal`` with *relations* sorted by ``r_type``."""
    ordered = sorted(relations, key=lambda rel: rel.r_type)
    return _orig_select_principal(scope, ordered)

try:
    from cubicweb.server.msplanner import PartPlanInformation
except ImportError:
    # msplanner not importable: use a do-nothing stub so the originals can
    # still be saved here and the monkey patching below still applies
    class PartPlanInformation(object):
        def merge_input_maps(*args):
            """Stub: does nothing and returns None."""
        def _choose_var(self, sourcevars):
            """Stub: does nothing and returns None."""
_orig_merge_input_maps = PartPlanInformation.merge_input_maps
_orig_choose_var = PartPlanInformation._choose_var

def _merge_input_maps(*args):
    """Return the original ``merge_input_maps`` result as a sorted list."""
    merged = _orig_merge_input_maps(*args)
    return sorted(merged)

def _choose_var(self, sourcevars):
    """Deterministic replacement for ``PartPlanInformation._choose_var``.

    The keys of *sourcevars* are ordered by their ``name`` attribute, else
    ``r_type``, else ``value``.  If ``self._sourcesvars`` has more than one
    entry, the first key whose scope is not ``self.rqlst`` is chosen;
    otherwise the first key whose scope is ``self.rqlst``.  Failing that,
    the first key overall is chosen.  The chosen key is popped from
    *sourcevars* and returned together with its value.
    """
    # predictable order for test purpose
    def sortkey(term):
        # variable -> name, relation -> r_type, constant -> value
        for attr in ('name', 'r_type'):
            try:
                return getattr(term, attr)
            except AttributeError:
                pass
        return term.value
    candidates = sorted(sourcevars, key=sortkey)
    wantoutside = len(self._sourcesvars) > 1
    for var in candidates:
        if (var.scope is not self.rqlst) == wantoutside:
            break
    else:
        var = candidates[0]
    return var, sourcevars.pop(var)


def do_monkey_patch():
    """Install this module's deterministic replacements.

    Patches ``RQLRewriter``, ``ExecutionPlan`` and ``PartPlanInformation``,
    and resets the ``ExecutionPlan.tablesinorder`` class attribute to None.
    """
    patches = (
        (RQLRewriter, 'insert_snippets', _insert_snippets),
        (RQLRewriter, 'build_variantes', _build_variantes),
        (ExecutionPlan, '_check_permissions', _check_permissions),
        (ExecutionPlan, 'tablesinorder', None),
        (ExecutionPlan, 'init_temp_table', _init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _merge_input_maps),
        (PartPlanInformation, '_choose_var', _choose_var),
    )
    for owner, attrname, replacement in patches:
        setattr(owner, attrname, replacement)

def undo_monkey_patch():
    """Put back the original methods replaced by do_monkey_patch.

    NOTE: ``ExecutionPlan.tablesinorder`` is left as it is.
    """
    originals = (
        (RQLRewriter, 'insert_snippets', _orig_insert_snippets),
        (RQLRewriter, 'build_variantes', _orig_build_variantes),
        (ExecutionPlan, '_check_permissions', _orig_check_permissions),
        (ExecutionPlan, 'init_temp_table', _orig_init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _orig_merge_input_maps),
        (PartPlanInformation, '_choose_var', _orig_choose_var),
    )
    for owner, attrname, original in originals:
        setattr(owner, attrname, original)