# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""some utilities to ease repository testing
This module contains functions to initialize a new repository.
"""
# markup language used by the docstrings of this module
__docformat__ = "restructuredtext en"
from copy import deepcopy
from pprint import pprint
from logilab.common.decorators import clear_cache
from logilab.common.testlib import SkipTest
def tuplify(list):
    """Turn in place every item of `list` which is not a plain tuple into a
    tuple.

    The given list is modified, and returned for convenience.
    """
    for index, item in enumerate(list):
        # exact type check: instances of tuple subclasses are converted too
        if type(item) is not tuple:
            list[index] = tuple(item)
    return list
def snippet_cmp(a, b):
    """Three-way comparison of two snippets.

    Each argument is a 2-items sequence whose second item is an iterable of
    objects having an `expression` attribute. Snippets are compared as
    ``(first item, [expressions])``.

    Return -1, 0 or 1 according to `a` being lower than, equal to or greater
    than `b`, as the python 2 `cmp` builtin does.
    """
    a = (a[0], [e.expression for e in a[1]])
    b = (b[0], [e.expression for e in b[1]])
    # portable spelling of cmp(a, b): the `cmp` builtin is gone in python 3
    return (a > b) - (a < b)
def test_plan(self, rql, expected, kwargs=None):
    """Build the execution plan for `rql` and check its steps against
    `expected`.

    The plan is obtained from `self._prepare_plan` and built by
    `self.planner`. On assertion failure, the representation of the steps
    actually built is pretty-printed before the error is propagated.
    """
    plan = self._prepare_plan(rql, kwargs)
    self.planner.build_plan(plan)
    try:
        nb_steps, nb_expected = len(plan.steps), len(expected)
        self.assertEqual(nb_steps, nb_expected,
                         'expected %s steps, got %s' % (nb_expected, nb_steps))
        # step order is important
        for index, step in enumerate(plan.steps):
            compare_steps(self, step.test_repr(), expected[index])
    except AssertionError:
        pprint([step.test_repr() for step in plan.steps])
        raise
def compare_steps(self, step, expected):
try:
self.assertEqual(step[0], expected[0], 'expected step type %s, got %s' % (expected[0], step[0]))
if len(step) > 2 and isinstance(step[1], list) and isinstance(expected[1], list):
queries, equeries = step[1], expected[1]
self.assertEqual(len(queries), len(equeries),
'expected %s queries, got %s' % (len(equeries), len(queries)))
for i, (rql, sol) in enumerate(queries):
self.assertEqual(rql, equeries[i][0])
self.assertEqual(sorted(sol), sorted(equeries[i][1]))
idx = 2
else:
idx = 1
self.assertEqual(step[idx:-1], expected[idx:-1],
'expected step characteristic \n%s\n, got\n%s' % (expected[1:-1], step[1:-1]))
self.assertEqual(len(step[-1]), len(expected[-1]),
'got %s child steps, expected %s' % (len(step[-1]), len(expected[-1])))
except AssertionError:
print 'error on step ',
pprint(step[:-1])
raise
children = step[-1]
if step[0] in ('UnionFetchStep', 'UnionStep'):
# sort children
children = sorted(children)
expectedchildren = sorted(expected[-1])
else:
expectedchildren = expected[-1]
for i, substep in enumerate(children):
compare_steps(self, substep, expectedchildren[i])
class DumbOrderedDict(list):
    """List of (key, value) pairs exposing a minimal dictionary-like api.

    Iteration yields keys in the order of the underlying list, and item
    access does a linear search of the given key.
    """

    def _pairs(self):
        """return an iterator on the underlying (key, value) pairs"""
        return list.__iter__(self)

    def __iter__(self):
        return self.iterkeys()

    def __contains__(self, key):
        return key in self.iterkeys()

    def __getitem__(self, key):
        """return the value of the first pair having `key` as key, raising
        KeyError if there is no such pair
        """
        for candidate, value in self._pairs():
            if key == candidate:
                return value
        raise KeyError(key)

    def iterkeys(self):
        return (pair_key for pair_key, _ in self._pairs())

    def iteritems(self):
        return (pair for pair in self._pairs())

    def items(self):
        return list(self._pairs())
class DumbOrderedDict2(object):
    """Proxy on a dictionary, iterating on its keys sorted using `sortkey`.

    Every other attribute is looked up on the wrapped dictionary.
    """

    def __init__(self, origdict, sortkey):
        self.origdict = origdict
        self.sortkey = sortkey

    def __getattr__(self, attr):
        # only called for attributes not found on the proxy itself
        return getattr(self.origdict, attr)

    def __iter__(self):
        orderedkeys = sorted(self.origdict, key=self.sortkey)
        return iter(orderedkeys)
def schema_eids_idx(schema):
    """return a dictionary mapping schema types to their eids so we can reread
    it from the fs instead of the db (too costly) between tests

    Entity and relation schemas are used as keys as is, while relation
    definitions are indexed by their (subject, rtype, object) triple.
    """
    eids = dict((eschema, eschema.eid) for eschema in schema.entities())
    for rschema in schema.relations():
        eids[rschema] = rschema.eid
        for rdef in rschema.rdefs.itervalues():
            rdefkey = (rdef.subject, rdef.rtype, rdef.object)
            eids[rdefkey] = rdef.eid
    return eids
def restore_schema_eids_idx(schema, schema_eids):
    """rebuild schema eid index

    `schema_eids` is a mapping using the same keys as the one built by
    `schema_eids_idx`: entity and relation schemas, and (subject, rtype,
    object) triples for relation definitions. The eid of each schema object
    is set from it, and the object is registered in `schema._eid_index`.
    """
    def restore(schemaobj, key):
        schemaobj.eid = schema_eids[key]
        schema._eid_index[schemaobj.eid] = schemaobj

    for eschema in schema.entities():
        restore(eschema, eschema)
    for rschema in schema.relations():
        restore(rschema, rschema)
        for rdef in rschema.rdefs.itervalues():
            restore(rdef, (rdef.subject, rdef.rtype, rdef.object))
from logilab.common.testlib import TestCase, mock_object
from logilab.database import get_db_helper
from rql import RQLHelper
from cubicweb.devtools.fake import FakeRepo, FakeSession
from cubicweb.server import set_debug, debugged
from cubicweb.server.querier import QuerierHelper
from cubicweb.server.session import Session
from cubicweb.server.sources.rql2sql import SQLGenerator, remove_unused_solutions
class RQLGeneratorTC(TestCase):
    """Base class for test cases working on a fake repository built from
    `schema`, with a SQL generator for `backend` when one is specified.
    """
    # set these in concrete class
    schema = None
    backend = None

    @classmethod
    def setUpClass(cls):
        """get the database helper for the backend, if any. Tests are skipped
        if this raises an ImportError.
        """
        if cls.backend is None:
            return
        try:
            cls.dbhelper = get_db_helper(cls.backend)
        except ImportError as ex:
            raise SkipTest(str(ex))

    def setUp(self):
        """set up the fake repository and rql helpers, and install the
        permission checking / principal selection replacements
        """
        schema, backend = self.schema, self.backend
        self.repo = FakeRepo(schema)
        self.repo.system_source = mock_object(dbdriver=backend)
        special_relations = {'eid': 'uid', 'has_text': 'fti'}
        self.rqlhelper = RQLHelper(schema, special_relations=special_relations,
                                   backend=backend)
        self.qhelper = QuerierHelper(self.repo, schema)
        ExecutionPlan._check_permissions = _dummy_check_permissions
        rqlannotation._select_principal = _select_principal
        if backend is not None:
            self.o = SQLGenerator(schema, self.dbhelper)

    def tearDown(self):
        """put back the implementations replaced by `setUp`"""
        rqlannotation._select_principal = _orig_select_principal
        ExecutionPlan._check_permissions = _orig_check_permissions

    def set_debug(self, debug):
        """shortcut to cubicweb.server.set_debug"""
        set_debug(debug)

    def debugged(self, debug):
        """shortcut to cubicweb.server.debugged"""
        return debugged(debug)

    def _prepare(self, rql):
        """return the syntax tree for `rql`, once parsed, simplified and
        preprocessed by a plan, with the solutions of each select sorted
        """
        helper = self.rqlhelper
        union = helper.parse(rql)
        helper.compute_solutions(union)
        helper.simplify(union)
        session = FakeSession(self.repo)
        plan = self.qhelper.plan_factory(union, {}, session)
        plan.preprocess(union)
        for select in union.children:
            select.solutions.sort()
        return union
class BaseQuerierTC(TestCase):
    """Base class for test cases running queries through the querier of the
    repository set as `repo`.
    """
    repo = None # set this in concrete class

    def setUp(self):
        """grab the querier and a session from the repository, remember the
        current max eid and install monkey patches
        """
        repo = self.repo
        self.o = repo.querier
        self.session = repo._sessions.values()[0]
        self.ueid = self.session.user.eid
        assert self.ueid != -1
        repo._type_source_cache = {} # clear cache
        self.cnxset = self.session.set_cnxset()
        self.maxeid = self.get_max_eid()
        do_monkey_patch()
        self._dumb_sessions = []

    def get_max_eid(self):
        """return the first cell of the 'Any MAX(X)' query result"""
        rset = self.session.execute('Any MAX(X)')
        return rset[0][0]

    def cleanup(self):
        """delete entities whose eid is greater than the max eid recorded by
        `setUp`
        """
        session = self.session
        session.set_cnxset()
        session.execute('DELETE Any X WHERE X eid > %s' % self.maxeid)

    def tearDown(self):
        """remove monkey patches, clean up the data added by the test and
        release sessions and connections set
        """
        undo_monkey_patch()
        self.session.rollback()
        self.cleanup()
        self.commit()
        # properly close dumb sessions
        for dumb_session in self._dumb_sessions:
            dumb_session.rollback()
            dumb_session.close()
        self.repo._free_cnxset(self.cnxset)
        assert self.session.user.eid != -1

    def set_debug(self, debug):
        """shortcut to cubicweb.server.set_debug"""
        set_debug(debug)

    def debugged(self, debug):
        """shortcut to cubicweb.server.debugged"""
        return debugged(debug)

    def _rqlhelper(self):
        """return the rql helper of the registry, with uid functions of its
        analyser reset
        """
        helper = self.repo.vreg.rqlhelper
        # reset uid_func so it doesn't try to get type from eids
        helper._analyser.uid_func = None
        helper._analyser.uid_func_mapping = {}
        return helper

    def _prepare_plan(self, rql, kwargs=None, simplify=True):
        """return a plan for `rql`, whose syntax tree has solutions computed
        and sorted, and has been simplified unless `simplify` is false
        """
        helper = self._rqlhelper()
        rqlst = helper.parse(rql)
        helper.compute_solutions(rqlst, kwargs=kwargs)
        if simplify:
            helper.simplify(rqlst)
        for select in rqlst.children:
            select.solutions.sort()
        return self.o.plan_factory(rqlst, kwargs, self.session)

    def _prepare(self, rql, kwargs=None):
        """return the first select of the preprocessed (but not simplified)
        syntax tree for `rql`, with unused solutions removed
        """
        plan = self._prepare_plan(rql, kwargs, simplify=False)
        plan.preprocess(plan.rqlst)
        select = plan.rqlst.children[0]
        cleaned = remove_unused_solutions(select, select.solutions, {},
                                          self.repo.schema)
        select.solutions = cleaned[0]
        return select

    def user_groups_session(self, *groups):
        """lightweight session using the current user with hi-jacked groups"""
        # use self.session.user.eid to get correct owned_by relation, unless explicit eid
        user = self.repo._build_user(self.session, self.session.user.eid)
        user._groups = set(groups)
        dumb_session = Session(user, self.repo)
        dumb_session._tx.cnxset = self.cnxset
        dumb_session._tx.ctx_count = 1
        # register session to ensure it gets closed
        self._dumb_sessions.append(dumb_session)
        return dumb_session

    def execute(self, rql, args=None, build_descr=True):
        """execute `rql` through the querier, using the test session"""
        return self.o.execute(self.session, rql, args, build_descr)

    def commit(self):
        """commit the test session, then set a connections set on it again"""
        session = self.session
        session.commit()
        session.set_cnxset()
class BasePlannerTC(BaseQuerierTC):
    """Base class for planner test cases, allowing to plug additional
    sources into the repository for the duration of a test.
    """
    newsources = ()

    def setup(self):
        """initialize the test environment.

        NOTE(review): this is named `setup`, not `setUp`, so unittest won't
        call it by itself; presumably concrete test cases call it explicitly
        from their own `setUp` -- confirm against callers.
        """
        # each cache only needs to be cleared once
        for cachename in ('rel_type_sources',
                          'can_cross_relation',
                          'is_multi_sources_relation'):
            clear_cache(self.repo, cachename)
        # XXX source_defs
        self.o = self.repo.querier
        self.session = self.repo._sessions.values()[0]
        self.cnxset = self.session.set_cnxset()
        self.schema = self.o.schema
        self.sources = self.o._repo.sources
        self.system = self.sources[-1]
        do_monkey_patch()
        self._dumb_sessions = [] # by hi-jacked parent setup
        self.repo.vreg.rqlhelper.backend = 'postgres' # so FTIRANK is considered
        # per-test list shadowing the (immutable) class level default
        self.newsources = []

    def add_source(self, sourcecls, uri):
        """instantiate `sourcecls` and register it in the repository under
        `uri`. The source is also made available as an attribute of the test
        case named after `uri`.
        """
        source = sourcecls(self.repo, {'uri': uri, 'type': 'whatever'})
        if not source.copy_based_source:
            self.sources.append(source)
        self.newsources.append(source)
        self.repo.sources_by_uri[uri] = source
        setattr(self, uri, source)

    def tearDown(self):
        """unregister sources added by `add_source`, remove monkey patches and
        close dumb sessions
        """
        for source in self.newsources:
            if not source.copy_based_source:
                self.sources.remove(source)
            del self.repo.sources_by_uri[source.uri]
        undo_monkey_patch()
        for session in self._dumb_sessions:
            # detach the connections set from the session before closing it
            session._tx.cnxset = None
            session.close()

    def _prepare_plan(self, rql, kwargs=None):
        """return a plan for `rql`, whose syntax tree has been parsed by the
        querier and has sorted solutions
        """
        rqlst = self.o.parse(rql, annotate=True)
        self.o.solutions(self.session, rqlst, kwargs)
        if rqlst.TYPE == 'select':
            self.repo.vreg.rqlhelper.annotate(rqlst)
            for select in rqlst.children:
                select.solutions.sort()
        else:
            rqlst.solutions.sort()
        return self.o.plan_factory(rqlst, kwargs, self.session)
# monkey patch some methods to get predictable results ########################
from cubicweb import rqlrewrite
# keep references to the genuine rqlrewrite implementations, so that
# undo_monkey_patch() can put them back
_orig_iter_relations = rqlrewrite.iter_relations
_orig_insert_snippets = rqlrewrite.RQLRewriter.insert_snippets
_orig_build_variantes = rqlrewrite.RQLRewriter.build_variantes
def _insert_snippets(self, snippets, varexistsmap=None):
    """replacement for RQLRewriter.insert_snippets, calling the original
    implementation with snippets sorted using `snippet_cmp`
    """
    orderedsnippets = sorted(snippets, cmp=snippet_cmp)
    _orig_insert_snippets(self, orderedsnippets, varexistsmap)
def _build_variantes(self, newsolutions):
    """replacement for RQLRewriter.build_variantes, returning variantes built
    by the original implementation in a predictable order.

    Each variante is turned into a DumbOrderedDict whose items are sorted on
    (second item of the key, third item of the key, value), and variantes
    themselves are sorted on their list of such triples.
    """
    def sort_key(item):
        key, value = item
        return (key[1], key[2], value)

    decorated = []
    for variante in _orig_build_variantes(self, newsolutions):
        ordereditems = sorted(variante.iteritems(), key=sort_key)
        orderedkeys = [sort_key(item) for item in ordereditems]
        decorated.append((orderedkeys, DumbOrderedDict(ordereditems)))
    decorated.sort()
    return [variante for _, variante in decorated]
from cubicweb.server.querier import ExecutionPlan
# keep references to the genuine ExecutionPlan methods, so that they can be
# put back once tests are done
_orig_check_permissions = ExecutionPlan._check_permissions
_orig_init_temp_table = ExecutionPlan.init_temp_table
def _check_permissions(*args, **kwargs):
    """replacement for ExecutionPlan._check_permissions, turning the
    dictionary returned by the original implementation into a
    DumbOrderedDict whose items are sorted by value
    """
    res, restricted = _orig_check_permissions(*args, **kwargs)
    ordereditems = sorted(res.iteritems(), key=lambda item: item[1])
    return DumbOrderedDict(ordereditems), restricted
def _dummy_check_permissions(self, rqlst):
    """replacement for ExecutionPlan._check_permissions which doesn't check
    anything: return a dictionary associating every solution of `rqlst` to
    the empty tuple, and an empty set
    """
    solutions = {(): rqlst.solutions}
    restricted = set()
    return solutions, restricted
def _init_temp_table(self, table, selection, solution):
    """replacement for ExecutionPlan.init_temp_table, recording in
    `self.tablesinorder` a 'table<N>' name for each table seen, N being the
    number of tables recorded before it, then calling the original
    implementation with unchanged arguments
    """
    tablesinorder = self.tablesinorder
    if tablesinorder is None:
        # first table for this plan: create the instance level mapping
        tablesinorder = self.tablesinorder = {}
    if table not in tablesinorder:
        tablesinorder[table] = 'table%s' % len(tablesinorder)
    return _orig_init_temp_table(self, table, selection, solution)
from cubicweb.server import rqlannotation
# keep a reference to the genuine function, so that it can be put back once
# tests are done
_orig_select_principal = rqlannotation._select_principal
def _select_principal(scope, relations):
    """replacement for rqlannotation._select_principal, calling the original
    implementation with a `_sort` function which sorts relations by relation
    type
    """
    def sort_key(something):
        # either an object with a r_type attribute, or a sequence whose first
        # item has one
        try:
            return something.r_type
        except AttributeError:
            return (something[0].r_type, something[1])

    def sort_relations(rels):
        return sorted(rels, key=sort_key)

    return _orig_select_principal(scope, relations, _sort=sort_relations)
try:
    from cubicweb.server.msplanner import PartPlanInformation
except ImportError:
    class PartPlanInformation(object):
        """Do-nothing stand-in used when the msplanner module can't be
        imported, so that the monkey patching machinery of this module still
        has something to patch.
        """
        def merge_input_maps(self, *args, **kwargs):
            pass
        def _choose_term(self, source, sourceterms):
            # same arguments as the ones given by the replacement installed
            # by do_monkey_patch(), which calls the original implementation
            # with (self, source, sourceterms)
            pass
# keep references to the genuine methods, so that they can be put back once
# tests are done
_orig_merge_input_maps = PartPlanInformation.merge_input_maps
_orig_choose_term = PartPlanInformation._choose_term
def _merge_input_maps(*args, **kwargs):
    """replacement for PartPlanInformation.merge_input_maps, returning the
    result of the original implementation as a sorted list
    """
    inputmaps = _orig_merge_input_maps(*args, **kwargs)
    return sorted(inputmaps)
def _choose_term(self, source, sourceterms):
    """replacement for PartPlanInformation._choose_term, calling the original
    implementation with `sourceterms` wrapped so that iterating on it yields
    terms in a predictable order
    """
    # predictable order for test purpose
    def get_key(x):
        # variable (name), then relation (r_type), else const (value)
        for attr in ('name', 'r_type'):
            try:
                return getattr(x, attr)
            except AttributeError:
                continue
        return x.value

    orderedterms = DumbOrderedDict2(sourceterms, get_key)
    return _orig_choose_term(self, source, orderedterms)
from cubicweb.server.sources.pyrorql import PyroRQLSource
# keep a reference to the genuine method, so that undo_monkey_patch() can put
# it back
_orig_syntax_tree_search = PyroRQLSource.syntax_tree_search
def _syntax_tree_search(*args, **kwargs):
    """replacement for PyroRQLSource.syntax_tree_search, returning a deep
    copy of the result of the original implementation
    """
    result = _orig_syntax_tree_search(*args, **kwargs)
    return deepcopy(result)
def _ordered_iter_relations(stinfo):
    """replacement for rqlrewrite.iter_relations, returning relations given
    by the original implementation as a list sorted by relation type
    """
    relations = list(_orig_iter_relations(stinfo))
    relations.sort(key=lambda relation: relation.r_type)
    return relations
def do_monkey_patch():
    """install the replacements defined in this module on the rql rewriter,
    the execution plan, the multi-sources planner and the pyro source, and
    reset `ExecutionPlan.tablesinorder`
    """
    rewriter = rqlrewrite.RQLRewriter
    patches = (
        (rqlrewrite, 'iter_relations', _ordered_iter_relations),
        (rewriter, 'insert_snippets', _insert_snippets),
        (rewriter, 'build_variantes', _build_variantes),
        (ExecutionPlan, '_check_permissions', _check_permissions),
        (ExecutionPlan, 'tablesinorder', None),
        (ExecutionPlan, 'init_temp_table', _init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _merge_input_maps),
        (PartPlanInformation, '_choose_term', _choose_term),
        (PyroRQLSource, 'syntax_tree_search', _syntax_tree_search),
    )
    for target, attrname, replacement in patches:
        setattr(target, attrname, replacement)
def undo_monkey_patch():
    """put back the original implementations of everything replaced by
    `do_monkey_patch` (`ExecutionPlan.tablesinorder` is left untouched)
    """
    rewriter = rqlrewrite.RQLRewriter
    originals = (
        (rqlrewrite, 'iter_relations', _orig_iter_relations),
        (rewriter, 'insert_snippets', _orig_insert_snippets),
        (rewriter, 'build_variantes', _orig_build_variantes),
        (ExecutionPlan, '_check_permissions', _orig_check_permissions),
        (ExecutionPlan, 'init_temp_table', _orig_init_temp_table),
        (PartPlanInformation, 'merge_input_maps', _orig_merge_input_maps),
        (PartPlanInformation, '_choose_term', _orig_choose_term),
        (PyroRQLSource, 'syntax_tree_search', _orig_syntax_tree_search),
    )
    for target, attrname, original in originals:
        setattr(target, attrname, original)