devtools/repotest.py
changeset 0 b97547f5f1fa
child 599 9ef680acd92a
equal deleted inserted replaced
-1:000000000000 0:b97547f5f1fa
       
     1 """some utilities to ease repository testing
       
     2 
       
     3 This module contains base test cases and helpers to check RQL execution plans and queries, plus monkey patches making planner results predictable.
       
     4 
       
     5 :organization: Logilab
       
     6 :copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     7 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     8 """
       
     9 __docformat__ = "restructuredtext en"
       
    10 
       
    11 from pprint import pprint
       
    12 
       
    13 def tuplify(list):
       
    14     for i in range(len(list)):
       
    15         if type(list[i]) is not type(()):
       
    16             list[i] = tuple(list[i])
       
    17     return list
       
    18 
       
    19 def snippet_cmp(a, b):
       
    20     a = (a[0], [e.expression for e in a[1]])
       
    21     b = (b[0], [e.expression for e in b[1]])
       
    22     return cmp(a, b)
       
    23 
       
    24 def test_plan(self, rql, expected, kwargs=None):
       
    25     plan = self._prepare_plan(rql, kwargs)
       
    26     self.planner.build_plan(plan)
       
    27     try:
       
    28         self.assertEquals(len(plan.steps), len(expected),
       
    29                           'expected %s steps, got %s' % (len(expected), len(plan.steps)))
       
    30         # step order is important
       
    31         for i, step in enumerate(plan.steps):
       
    32             compare_steps(self, step.test_repr(), expected[i])
       
    33     except AssertionError:
       
    34         pprint([step.test_repr() for step in plan.steps])
       
    35         raise
       
    36 
       
    37 def compare_steps(self, step, expected):
       
    38     try:
       
    39         self.assertEquals(step[0], expected[0], 'expected step type %s, got %s' % (expected[0], step[0]))
       
    40         if len(step) > 2 and isinstance(step[1], list) and isinstance(expected[1], list):
       
    41             queries, equeries = step[1], expected[1]
       
    42             self.assertEquals(len(queries), len(equeries),
       
    43                               'expected %s queries, got %s' % (len(equeries), len(queries)))
       
    44             for i, (rql, sol) in enumerate(queries):
       
    45                 self.assertEquals(rql, equeries[i][0])
       
    46                 self.assertEquals(sol, equeries[i][1])
       
    47             idx = 2
       
    48         else:
       
    49             idx = 1
       
    50         self.assertEquals(step[idx:-1], expected[idx:-1],
       
    51                           'expected step characteristic \n%s\n, got\n%s' % (expected[1:-1], step[1:-1]))
       
    52         self.assertEquals(len(step[-1]), len(expected[-1]),
       
    53                           'got %s child steps, expected %s' % (len(step[-1]), len(expected[-1])))
       
    54     except AssertionError:
       
    55         print 'error on step ',
       
    56         pprint(step[:-1])
       
    57         raise
       
    58     children = step[-1]
       
    59     if step[0] in ('UnionFetchStep', 'UnionStep'):
       
    60         # sort children
       
    61         children = sorted(children)
       
    62         expectedchildren = sorted(expected[-1])
       
    63     else:
       
    64         expectedchildren = expected[-1]
       
    65     for i, substep in enumerate(children):
       
    66         compare_steps(self, substep, expectedchildren[i])
       
    67 
       
    68 
       
    69 class DumbOrderedDict(list):
       
    70     def __iter__(self):
       
    71         return self.iterkeys()
       
    72     def __contains__(self, key):
       
    73         return key in self.iterkeys()
       
    74     def __getitem__(self, key):
       
    75         for key_, value in self.iteritems():
       
    76             if key == key_:
       
    77                 return value
       
    78         raise KeyError(key)
       
    79     def iterkeys(self):
       
    80         return (x for x, y in list.__iter__(self))
       
    81     def iteritems(self):
       
    82         return (x for x in list.__iter__(self))
       
    83 
       
    84 
       
    85 from logilab.common.testlib import TestCase
       
    86 from rql import RQLHelper
       
    87 from cubicweb.devtools.fake import FakeRepo, FakeSession
       
    88 from cubicweb.server import set_debug
       
    89 from cubicweb.server.querier import QuerierHelper
       
    90 from cubicweb.server.session import Session
       
    91 from cubicweb.server.sources.rql2sql import remove_unused_solutions
       
    92 
       
    93 class RQLGeneratorTC(TestCase):
       
    94     schema = None # set this in concret test
       
    95     
       
    96     def setUp(self):
       
    97         self.rqlhelper = RQLHelper(self.schema, special_relations={'eid': 'uid',
       
    98                                                                    'has_text': 'fti'})
       
    99         self.qhelper = QuerierHelper(FakeRepo(self.schema), self.schema)
       
   100         ExecutionPlan._check_permissions = _dummy_check_permissions
       
   101         rqlannotation._select_principal = _select_principal
       
   102 
       
   103     def tearDown(self):
       
   104         ExecutionPlan._check_permissions = _orig_check_permissions
       
   105         rqlannotation._select_principal = _orig_select_principal
       
   106         
       
   107     def _prepare(self, rql):
       
   108         #print '******************** prepare', rql
       
   109         union = self.rqlhelper.parse(rql)
       
   110         #print '********* parsed', union.as_string()
       
   111         self.rqlhelper.compute_solutions(union)
       
   112         #print '********* solutions', solutions
       
   113         self.rqlhelper.simplify(union)
       
   114         #print '********* simplified', union.as_string()
       
   115         plan = self.qhelper.plan_factory(union, {}, FakeSession())
       
   116         plan.preprocess(union)
       
   117         for select in union.children:
       
   118             select.solutions.sort()
       
   119         #print '********* ppsolutions', solutions
       
   120         return union
       
   121 
       
   122 
       
   123 class BaseQuerierTC(TestCase):
       
   124     repo = None # set this in concret test
       
   125     
       
   126     def setUp(self):
       
   127         self.o = self.repo.querier
       
   128         self.session = self.repo._sessions.values()[0]
       
   129         self.ueid = self.session.user.eid
       
   130         assert self.ueid != -1
       
   131         self.repo._type_source_cache = {} # clear cache
       
   132         self.pool = self.session.set_pool()
       
   133         self.maxeid = self.get_max_eid()
       
   134         do_monkey_patch()
       
   135 
       
   136     def get_max_eid(self):
       
   137         return self.session.unsafe_execute('Any MAX(X)')[0][0]
       
   138     def cleanup(self):
       
   139         self.session.unsafe_execute('DELETE Any X WHERE X eid > %s' % self.maxeid)
       
   140         
       
   141     def tearDown(self):
       
   142         undo_monkey_patch()
       
   143         self.session.rollback()
       
   144         self.cleanup()
       
   145         self.commit()
       
   146         self.repo._free_pool(self.pool)
       
   147         assert self.session.user.eid != -1
       
   148 
       
   149     def set_debug(self, debug):
       
   150         set_debug(debug)
       
   151         
       
   152     def _rqlhelper(self):
       
   153         rqlhelper = self.o._rqlhelper
       
   154         # reset uid_func so it don't try to get type from eids
       
   155         rqlhelper._analyser.uid_func = None
       
   156         rqlhelper._analyser.uid_func_mapping = {}
       
   157         return rqlhelper
       
   158 
       
   159     def _prepare_plan(self, rql, kwargs=None):
       
   160         rqlhelper = self._rqlhelper()
       
   161         rqlst = rqlhelper.parse(rql)
       
   162         rqlhelper.compute_solutions(rqlst, kwargs=kwargs)
       
   163         rqlhelper.simplify(rqlst)
       
   164         for select in rqlst.children:
       
   165             select.solutions.sort()
       
   166         return self.o.plan_factory(rqlst, kwargs, self.session)
       
   167         
       
   168     def _prepare(self, rql, kwargs=None):    
       
   169         plan = self._prepare_plan(rql, kwargs)
       
   170         plan.preprocess(plan.rqlst)
       
   171         rqlst = plan.rqlst.children[0]
       
   172         rqlst.solutions = remove_unused_solutions(rqlst, rqlst.solutions, {}, self.repo.schema)[0]
       
   173         return rqlst
       
   174 
       
   175     def _user_session(self, groups=('guests',), ueid=None):
       
   176         # use self.session.user.eid to get correct owned_by relation, unless explicit eid
       
   177         if ueid is None:
       
   178             ueid = self.session.user.eid
       
   179         u = self.repo._build_user(self.session, ueid)
       
   180         u._groups = set(groups)
       
   181         s = Session(u, self.repo)
       
   182         s._threaddata.pool = self.pool
       
   183         return u, s
       
   184 
       
   185     def execute(self, rql, args=None, eid_key=None, build_descr=True):
       
   186         return self.o.execute(self.session, rql, args, eid_key, build_descr)
       
   187     
       
   188     def commit(self):
       
   189         self.session.commit()
       
   190         self.session.set_pool()        
       
   191 
       
   192 
       
   193 class BasePlannerTC(BaseQuerierTC):
       
   194 
       
   195     def _prepare_plan(self, rql, kwargs=None):
       
   196         rqlst = self.o.parse(rql, annotate=True)
       
   197         self.o.solutions(self.session, rqlst, kwargs)
       
   198         if rqlst.TYPE == 'select':
       
   199             self.o._rqlhelper.annotate(rqlst)
       
   200             for select in rqlst.children:
       
   201                 select.solutions.sort()
       
   202         else:
       
   203             rqlst.solutions.sort()
       
   204         return self.o.plan_factory(rqlst, kwargs, self.session)
       
   205 
       
   206 
       
   207 # monkey patch some methods to get predicatable results #######################
       
   208 
       
   209 from cubicweb.server.rqlrewrite import RQLRewriter
       
   210 _orig_insert_snippets = RQLRewriter.insert_snippets
       
   211 _orig_build_variantes = RQLRewriter.build_variantes
       
   212 
       
   213 def _insert_snippets(self, snippets, varexistsmap=None):
       
   214     _orig_insert_snippets(self, sorted(snippets, snippet_cmp), varexistsmap)
       
   215 
       
   216 def _build_variantes(self, newsolutions):
       
   217     variantes = _orig_build_variantes(self, newsolutions)
       
   218     sortedvariantes = []
       
   219     for variante in variantes:
       
   220         orderedkeys = sorted((k[1], k[2], v) for k,v in variante.iteritems())
       
   221         variante = DumbOrderedDict(sorted(variante.iteritems(),
       
   222                                           lambda a,b: cmp((a[0][1],a[0][2],a[1]),
       
   223                                                           (b[0][1],b[0][2],b[1]))))
       
   224         sortedvariantes.append( (orderedkeys, variante) )
       
   225     return [v for ok, v in sorted(sortedvariantes)]
       
   226 
       
   227 from cubicweb.server.querier import ExecutionPlan
       
   228 _orig_check_permissions = ExecutionPlan._check_permissions
       
   229 _orig_init_temp_table = ExecutionPlan.init_temp_table
       
   230 
       
   231 def _check_permissions(*args, **kwargs):
       
   232     res, restricted = _orig_check_permissions(*args, **kwargs)
       
   233     res = DumbOrderedDict(sorted(res.iteritems(), lambda a,b: cmp(a[1], b[1])))
       
   234     return res, restricted
       
   235 
       
   236 def _dummy_check_permissions(self, rqlst):
       
   237     return {(): rqlst.solutions}, set()
       
   238 
       
   239 def _init_temp_table(self, table, selection, solution):
       
   240     if self.tablesinorder is None:
       
   241         tablesinorder = self.tablesinorder = {}
       
   242     else:
       
   243         tablesinorder = self.tablesinorder
       
   244     if not table in tablesinorder:
       
   245         tablesinorder[table] = 'table%s' % len(tablesinorder)
       
   246     return _orig_init_temp_table(self, table, selection, solution)
       
   247 
       
   248 from cubicweb.server import rqlannotation
       
   249 _orig_select_principal = rqlannotation._select_principal
       
   250 
       
   251 def _select_principal(scope, relations):
       
   252     return _orig_select_principal(scope, sorted(relations, key=lambda x: x.r_type))
       
   253 
       
   254 try:
       
   255     from cubicweb.server.msplanner import PartPlanInformation
       
   256 except ImportError:
       
   257     class PartPlanInformation(object):
       
   258         def merge_input_maps(*args):
       
   259             pass
       
   260         def _choose_var(self, sourcevars):
       
   261             pass    
       
   262 _orig_merge_input_maps = PartPlanInformation.merge_input_maps
       
   263 _orig_choose_var = PartPlanInformation._choose_var
       
   264 
       
   265 def _merge_input_maps(*args):
       
   266     return sorted(_orig_merge_input_maps(*args))
       
   267 
       
   268 def _choose_var(self, sourcevars):
       
   269     # predictable order for test purpose
       
   270     def get_key(x):
       
   271         try:
       
   272             # variable
       
   273             return x.name
       
   274         except AttributeError:
       
   275             try:
       
   276                 # relation
       
   277                 return x.r_type
       
   278             except AttributeError:
       
   279                 # const
       
   280                 return x.value
       
   281     varsinorder = sorted(sourcevars, key=get_key)
       
   282     if len(self._sourcesvars) > 1:
       
   283         for var in varsinorder:
       
   284             if not var.scope is self.rqlst:
       
   285                 return var, sourcevars.pop(var)
       
   286     else:
       
   287         for var in varsinorder:
       
   288             if var.scope is self.rqlst:
       
   289                 return var, sourcevars.pop(var)
       
   290     var = varsinorder[0]
       
   291     return var, sourcevars.pop(var)
       
   292 
       
   293 
       
   294 def do_monkey_patch():
       
   295     RQLRewriter.insert_snippets = _insert_snippets
       
   296     RQLRewriter.build_variantes = _build_variantes
       
   297     ExecutionPlan._check_permissions = _check_permissions
       
   298     ExecutionPlan.tablesinorder = None
       
   299     ExecutionPlan.init_temp_table = _init_temp_table
       
   300     PartPlanInformation.merge_input_maps = _merge_input_maps
       
   301     PartPlanInformation._choose_var = _choose_var
       
   302 
       
   303 def undo_monkey_patch():
       
   304     RQLRewriter.insert_snippets = _orig_insert_snippets
       
   305     RQLRewriter.build_variantes = _orig_build_variantes
       
   306     ExecutionPlan._check_permissions = _orig_check_permissions
       
   307     ExecutionPlan.init_temp_table = _orig_init_temp_table
       
   308     PartPlanInformation.merge_input_maps = _orig_merge_input_maps
       
   309     PartPlanInformation._choose_var = _orig_choose_var
       
   310