devtools/repotest.py
changeset 11057 0b59724cb3f2
parent 11052 058bb3dc685f
child 11058 23eb30449fe5
equal deleted inserted replaced
11052:058bb3dc685f 11057:0b59724cb3f2
     1 # copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
"""some utilities to ease repository testing

This module provides helpers for repository-level tests: execution plan
comparison (test_plan, compare_steps), base test case classes for the RQL to
SQL generator (RQLGeneratorTC), the querier (BaseQuerierTC) and the planner
(BasePlannerTC), and monkey patches making RQL rewriting and planning results
deterministic (do_monkey_patch / undo_monkey_patch).
"""
from __future__ import print_function

__docformat__ = "restructuredtext en"
       
    25 
       
    26 from pprint import pprint
       
    27 
       
    28 from logilab.common.testlib import SkipTest
       
    29 
       
    30 def tuplify(mylist):
       
    31     return [tuple(item) for item in mylist]
       
    32 
       
    33 def snippet_key(a):
       
    34     # a[0] may be a dict or a key/value tuple
       
    35     return (sorted(dict(a[0]).items()), [e.expression for e in a[1]])
       
    36 
       
    37 def test_plan(self, rql, expected, kwargs=None):
       
    38     with self.session.new_cnx() as cnx:
       
    39         plan = self._prepare_plan(cnx, rql, kwargs)
       
    40         self.planner.build_plan(plan)
       
    41         try:
       
    42             self.assertEqual(len(plan.steps), len(expected),
       
    43                               'expected %s steps, got %s' % (len(expected), len(plan.steps)))
       
    44             # step order is important
       
    45             for i, step in enumerate(plan.steps):
       
    46                 compare_steps(self, step.test_repr(), expected[i])
       
    47         except AssertionError:
       
    48             pprint([step.test_repr() for step in plan.steps])
       
    49             raise
       
    50 
       
    51 def compare_steps(self, step, expected):
       
    52     try:
       
    53         self.assertEqual(step[0], expected[0], 'expected step type %s, got %s' % (expected[0], step[0]))
       
    54         if len(step) > 2 and isinstance(step[1], list) and isinstance(expected[1], list):
       
    55             queries, equeries = step[1], expected[1]
       
    56             self.assertEqual(len(queries), len(equeries),
       
    57                               'expected %s queries, got %s' % (len(equeries), len(queries)))
       
    58             for i, (rql, sol) in enumerate(queries):
       
    59                 self.assertEqual(rql, equeries[i][0])
       
    60                 self.assertEqual(sorted(sorted(x.items()) for x in sol), sorted(sorted(x.items()) for x in equeries[i][1]))
       
    61             idx = 2
       
    62         else:
       
    63             idx = 1
       
    64         self.assertEqual(step[idx:-1], expected[idx:-1],
       
    65                           'expected step characteristic \n%s\n, got\n%s' % (expected[1:-1], step[1:-1]))
       
    66         self.assertEqual(len(step[-1]), len(expected[-1]),
       
    67                           'got %s child steps, expected %s' % (len(step[-1]), len(expected[-1])))
       
    68     except AssertionError:
       
    69         print('error on step ', end=' ')
       
    70         pprint(step[:-1])
       
    71         raise
       
    72     children = step[-1]
       
    73     if step[0] in ('UnionFetchStep', 'UnionStep'):
       
    74         # sort children
       
    75         children = sorted(children)
       
    76         expectedchildren = sorted(expected[-1])
       
    77     else:
       
    78         expectedchildren = expected[-1]
       
    79     for i, substep in enumerate(children):
       
    80         compare_steps(self, substep, expectedchildren[i])
       
    81 
       
    82 
       
    83 class DumbOrderedDict(list):
       
    84     def __iter__(self):
       
    85         return self.iterkeys()
       
    86     def __contains__(self, key):
       
    87         return key in self.iterkeys()
       
    88     def __getitem__(self, key):
       
    89         for key_, value in list.__iter__(self):
       
    90             if key == key_:
       
    91                 return value
       
    92         raise KeyError(key)
       
    93     def iterkeys(self):
       
    94         return (x for x, y in list.__iter__(self))
       
    95     def iteritems(self):
       
    96         return (x for x in list.__iter__(self))
       
    97     def items(self):
       
    98         return [x for x in list.__iter__(self)]
       
    99 
       
   100 class DumbOrderedDict2(object):
       
   101     def __init__(self, origdict, sortkey):
       
   102         self.origdict = origdict
       
   103         self.sortkey = sortkey
       
   104     def __getattr__(self, attr):
       
   105         return getattr(self.origdict, attr)
       
   106     def __iter__(self):
       
   107         return iter(sorted(self.origdict, key=self.sortkey))
       
   108 
       
   109 def schema_eids_idx(schema):
       
   110     """return a dictionary mapping schema types to their eids so we can reread
       
   111     it from the fs instead of the db (too costly) between tests
       
   112     """
       
   113     schema_eids = {}
       
   114     for x in schema.entities():
       
   115         schema_eids[x] = x.eid
       
   116     for x in schema.relations():
       
   117         schema_eids[x] = x.eid
       
   118         for rdef in x.rdefs.values():
       
   119             schema_eids[(rdef.subject, rdef.rtype, rdef.object)] = rdef.eid
       
   120     return schema_eids
       
   121 
       
   122 def restore_schema_eids_idx(schema, schema_eids):
       
   123     """rebuild schema eid index"""
       
   124     for x in schema.entities():
       
   125         x.eid = schema_eids[x]
       
   126         schema._eid_index[x.eid] = x
       
   127     for x in schema.relations():
       
   128         x.eid = schema_eids[x]
       
   129         schema._eid_index[x.eid] = x
       
   130         for rdef in x.rdefs.values():
       
   131             rdef.eid = schema_eids[(rdef.subject, rdef.rtype, rdef.object)]
       
   132             schema._eid_index[rdef.eid] = rdef
       
   133 
       
   134 
       
   135 from logilab.common.testlib import TestCase, mock_object
       
   136 from logilab.database import get_db_helper
       
   137 
       
   138 from rql import RQLHelper
       
   139 
       
   140 from cubicweb.devtools.fake import FakeRepo, FakeConfig, FakeSession
       
   141 from cubicweb.server import set_debug, debugged
       
   142 from cubicweb.server.querier import QuerierHelper
       
   143 from cubicweb.server.session import Session
       
   144 from cubicweb.server.sources.rql2sql import SQLGenerator, remove_unused_solutions
       
   145 
       
   146 class RQLGeneratorTC(TestCase):
       
   147     schema = backend = None # set this in concrete class
       
   148 
       
   149     @classmethod
       
   150     def setUpClass(cls):
       
   151         if cls.backend is not None:
       
   152             try:
       
   153                 cls.dbhelper = get_db_helper(cls.backend)
       
   154             except ImportError as ex:
       
   155                 raise SkipTest(str(ex))
       
   156 
       
   157     def setUp(self):
       
   158         self.repo = FakeRepo(self.schema, config=FakeConfig(apphome=self.datadir))
       
   159         self.repo.system_source = mock_object(dbdriver=self.backend)
       
   160         self.rqlhelper = RQLHelper(self.schema,
       
   161                                    special_relations={'eid': 'uid',
       
   162                                                       'has_text': 'fti'},
       
   163                                    backend=self.backend)
       
   164         self.qhelper = QuerierHelper(self.repo, self.schema)
       
   165         ExecutionPlan._check_permissions = _dummy_check_permissions
       
   166         rqlannotation._select_principal = _select_principal
       
   167         if self.backend is not None:
       
   168             self.o = SQLGenerator(self.schema, self.dbhelper)
       
   169 
       
   170     def tearDown(self):
       
   171         ExecutionPlan._check_permissions = _orig_check_permissions
       
   172         rqlannotation._select_principal = _orig_select_principal
       
   173 
       
   174     def set_debug(self, debug):
       
   175         set_debug(debug)
       
   176     def debugged(self, debug):
       
   177         return debugged(debug)
       
   178 
       
   179     def _prepare(self, rql):
       
   180         #print '******************** prepare', rql
       
   181         union = self.rqlhelper.parse(rql)
       
   182         #print '********* parsed', union.as_string()
       
   183         self.rqlhelper.compute_solutions(union)
       
   184         #print '********* solutions', solutions
       
   185         self.rqlhelper.simplify(union)
       
   186         #print '********* simplified', union.as_string()
       
   187         plan = self.qhelper.plan_factory(union, {}, FakeSession(self.repo))
       
   188         plan.preprocess(union)
       
   189         for select in union.children:
       
   190             select.solutions.sort(key=lambda x: list(x.items()))
       
   191         #print '********* ppsolutions', solutions
       
   192         return union
       
   193 
       
   194 
       
   195 class BaseQuerierTC(TestCase):
       
   196     repo = None # set this in concrete class
       
   197 
       
   198     def setUp(self):
       
   199         self.o = self.repo.querier
       
   200         self.session = next(iter(self.repo._sessions.values()))
       
   201         self.ueid = self.session.user.eid
       
   202         assert self.ueid != -1
       
   203         self.repo._type_source_cache = {} # clear cache
       
   204         self.maxeid = self.get_max_eid()
       
   205         do_monkey_patch()
       
   206         self._dumb_sessions = []
       
   207 
       
   208     def get_max_eid(self):
       
   209         with self.session.new_cnx() as cnx:
       
   210             return cnx.execute('Any MAX(X)')[0][0]
       
   211 
       
   212     def cleanup(self):
       
   213         with self.session.new_cnx() as cnx:
       
   214             cnx.execute('DELETE Any X WHERE X eid > %s' % self.maxeid)
       
   215             cnx.commit()
       
   216 
       
   217     def tearDown(self):
       
   218         undo_monkey_patch()
       
   219         self.cleanup()
       
   220         assert self.session.user.eid != -1
       
   221 
       
   222     def set_debug(self, debug):
       
   223         set_debug(debug)
       
   224     def debugged(self, debug):
       
   225         return debugged(debug)
       
   226 
       
   227     def _rqlhelper(self):
       
   228         rqlhelper = self.repo.vreg.rqlhelper
       
   229         # reset uid_func so it don't try to get type from eids
       
   230         rqlhelper._analyser.uid_func = None
       
   231         rqlhelper._analyser.uid_func_mapping = {}
       
   232         return rqlhelper
       
   233 
       
   234     def _prepare_plan(self, cnx, rql, kwargs=None, simplify=True):
       
   235         rqlhelper = self._rqlhelper()
       
   236         rqlst = rqlhelper.parse(rql)
       
   237         rqlhelper.compute_solutions(rqlst, kwargs=kwargs)
       
   238         if simplify:
       
   239             rqlhelper.simplify(rqlst)
       
   240         for select in rqlst.children:
       
   241             select.solutions.sort(key=lambda x: list(x.items()))
       
   242         return self.o.plan_factory(rqlst, kwargs, cnx)
       
   243 
       
   244     def _prepare(self, cnx, rql, kwargs=None):
       
   245         plan = self._prepare_plan(cnx, rql, kwargs, simplify=False)
       
   246         plan.preprocess(plan.rqlst)
       
   247         rqlst = plan.rqlst.children[0]
       
   248         rqlst.solutions = remove_unused_solutions(rqlst, rqlst.solutions, {}, self.repo.schema)[0]
       
   249         return rqlst
       
   250 
       
   251     def user_groups_session(self, *groups):
       
   252         """lightweight session using the current user with hi-jacked groups"""
       
   253         # use self.session.user.eid to get correct owned_by relation, unless explicit eid
       
   254         with self.session.new_cnx() as cnx:
       
   255             u = self.repo._build_user(cnx, self.session.user.eid)
       
   256             u._groups = set(groups)
       
   257             s = Session(u, self.repo)
       
   258             return s
       
   259 
       
   260     def qexecute(self, rql, args=None, build_descr=True):
       
   261         with self.session.new_cnx() as cnx:
       
   262             try:
       
   263                 return self.o.execute(cnx, rql, args, build_descr)
       
   264             finally:
       
   265                 if rql.startswith(('INSERT', 'DELETE', 'SET')):
       
   266                     cnx.commit()
       
   267 
       
   268 
       
   269 class BasePlannerTC(BaseQuerierTC):
       
   270 
       
   271     def setup(self):
       
   272         # XXX source_defs
       
   273         self.o = self.repo.querier
       
   274         self.session = self.repo._sessions.values()[0]
       
   275         self.schema = self.o.schema
       
   276         self.system = self.repo.system_source
       
   277         do_monkey_patch()
       
   278         self.repo.vreg.rqlhelper.backend = 'postgres' # so FTIRANK is considered
       
   279 
       
   280     def tearDown(self):
       
   281         undo_monkey_patch()
       
   282 
       
   283     def _prepare_plan(self, cnx, rql, kwargs=None):
       
   284         rqlst = self.o.parse(rql, annotate=True)
       
   285         self.o.solutions(cnx, rqlst, kwargs)
       
   286         if rqlst.TYPE == 'select':
       
   287             self.repo.vreg.rqlhelper.annotate(rqlst)
       
   288             for select in rqlst.children:
       
   289                 select.solutions.sort(key=lambda x: list(x.items()))
       
   290         else:
       
   291             rqlst.solutions.sort(key=lambda x: list(x.items()))
       
   292         return self.o.plan_factory(rqlst, kwargs, cnx)
       
   293 
       
   294 
       
   295 # monkey patch some methods to get predictable results #######################
       
   296 
       
from cubicweb import rqlrewrite
# original implementations: the deterministic wrappers below delegate to
# them, and undo_monkey_patch() puts them back in place
_orig_iter_relations = rqlrewrite.iter_relations
_orig_insert_snippets = rqlrewrite.RQLRewriter.insert_snippets
_orig_build_variantes = rqlrewrite.RQLRewriter.build_variantes
       
   301 
       
   302 def _insert_snippets(self, snippets, varexistsmap=None):
       
   303     _orig_insert_snippets(self, sorted(snippets, key=snippet_key), varexistsmap)
       
   304 
       
   305 def _build_variantes(self, newsolutions):
       
   306     variantes = _orig_build_variantes(self, newsolutions)
       
   307     sortedvariantes = []
       
   308     for variante in variantes:
       
   309         orderedkeys = sorted((k[1], k[2], v) for k, v in variante.items())
       
   310         variante = DumbOrderedDict(sorted(variante.items(),
       
   311                                           key=lambda a: (a[0][1], a[0][2], a[1])))
       
   312         sortedvariantes.append( (orderedkeys, variante) )
       
   313     return [v for ok, v in sorted(sortedvariantes)]
       
   314 
       
from cubicweb.server.querier import ExecutionPlan
# original permission check: wrapped by _check_permissions below and put back
# by undo_monkey_patch() and RQLGeneratorTC.tearDown
_orig_check_permissions = ExecutionPlan._check_permissions
       
   317 
       
   318 def _check_permissions(*args, **kwargs):
       
   319     res, restricted = _orig_check_permissions(*args, **kwargs)
       
   320     res = DumbOrderedDict(sorted(res.items(), key=lambda x: [y.items() for y in x[1]]))
       
   321     return res, restricted
       
   322 
       
   323 def _dummy_check_permissions(self, rqlst):
       
   324     return {(): rqlst.solutions}, set()
       
   325 
       
from cubicweb.server import rqlannotation
# original implementation: wrapped by _select_principal below and put back by
# RQLGeneratorTC.tearDown
_orig_select_principal = rqlannotation._select_principal
       
   328 
       
   329 def _select_principal(scope, relations):
       
   330     def sort_key(something):
       
   331         try:
       
   332             return something.r_type
       
   333         except AttributeError:
       
   334             return (something[0].r_type, something[1])
       
   335     return _orig_select_principal(scope, relations,
       
   336                                   _sort=lambda rels: sorted(rels, key=sort_key))
       
   337 
       
   338 
       
   339 def _ordered_iter_relations(stinfo):
       
   340     return sorted(_orig_iter_relations(stinfo), key=lambda x:x.r_type)
       
   341 
       
   342 def do_monkey_patch():
       
   343     rqlrewrite.iter_relations = _ordered_iter_relations
       
   344     rqlrewrite.RQLRewriter.insert_snippets = _insert_snippets
       
   345     rqlrewrite.RQLRewriter.build_variantes = _build_variantes
       
   346     ExecutionPlan._check_permissions = _check_permissions
       
   347     ExecutionPlan.tablesinorder = None
       
   348 
       
def undo_monkey_patch():
    """Restore the original implementations replaced by do_monkey_patch()."""
    rqlrewrite.iter_relations = _orig_iter_relations
    rqlrewrite.RQLRewriter.insert_snippets = _orig_insert_snippets
    rqlrewrite.RQLRewriter.build_variantes = _orig_build_variantes
    ExecutionPlan._check_permissions = _orig_check_permissions