# HG changeset patch
# User Sylvain Thenault
# Date 1229963655 -3600
# Node ID 4c7d3af7e94d7d401b71f0959d97a143e15c38f7
# Parent 3dbee583526cd3f2a1e589bfcb12c36f9c789512
restore multi-sources capabilities

diff -r 3dbee583526c -r 4c7d3af7e94d server/msplanner.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/msplanner.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,1212 @@
+"""plan execution of rql queries on multiple sources
+
+the best way to understand what we are trying to achieve here is to read
+the unit-tests in unittest_querier_planner.py
+
+
+
+Split and execution specifications
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+For a system source and an ldap user source (only EUser and its attributes
+are supported, no group or such):
+
+
+:EUser X:
+1. fetch EUser X from both sources and return concatenation of results
+
+
+:EUser X WHERE X in_group G, G name 'users':
+* catch 1
+  1. fetch EUser X from both sources, store concatenation of results
+     into a temporary table
+  2. return the result of TMP X WHERE X in_group G, G name 'users' from
+     the system source
+
+* catch 2
+  1. return the result of EUser X WHERE X in_group G, G name 'users'
+     from the system source, that's enough (optimization of the sql querier
+     will avoid a join on EUser, so we will directly get local eids)
+
+
+:EUser X,L WHERE X in_group G, X login L, G name 'users':
+1. fetch Any X,L WHERE X is EUser, X login L from both sources, store
+   concatenation of results into a temporary table
+2. return the result of Any X, L WHERE X is TMP, X login L, X in_group G,
+   G name 'users' from the system source
+
+
+:Any X WHERE X owned_by Y:
+* catch 1
+  1. fetch EUser X from both sources, store concatenation of results
+     into a temporary table
+  2. return the result of Any X WHERE X owned_by Y, Y is TMP from
+     the system source
+
+* catch 2
+  1. return the result of Any X WHERE X owned_by Y
+     from the system source, that's enough (optimization of the sql querier
+     will avoid a join on EUser, so we will directly get local eids)
+
+
+:organization: Logilab
+:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr +""" +__docformat__ = "restructuredtext en" + +from itertools import imap, ifilterfalse + +from logilab.common.compat import any +from logilab.common.decorators import cached + +from rql.stmts import Union, Select +from rql.nodes import VariableRef, Comparison, Relation, Constant, Exists, Variable + +from cubicweb import server +from cubicweb.common.utils import make_uid +from cubicweb.server.utils import cleanup_solutions +from cubicweb.server.ssplanner import SSPlanner, OneFetchStep, add_types_restriction +from cubicweb.server.mssteps import * +from cubicweb.server.sources import AbstractSource + +Variable._ms_table_key = lambda x: x.name +Relation._ms_table_key = lambda x: x.r_type +# str() Constant.value to ensure generated table name won't be unicode +Constant._ms_table_key = lambda x: str(x.value) + +AbstractSource.dont_cross_relations = () + +def allequals(solutions): + """return true if all solutions are identical""" + sol = solutions.next() + for sol_ in solutions: + if sol_ != sol: + return False + return True + +def need_aggr_step(select, sources, stepdefs=None): + """return True if a temporary table is necessary to store some partial + results to execute the given query + """ + if len(sources) == 1: + # can do everything at once with a single source + return False + if select.orderby or select.groupby or select.has_aggregat: + # if more than one source, we need a temp table to deal with sort / + # groups / aggregat if : + # * the rqlst won't be splitted (in the other case the last query + # using partial temporary table can do sort/groups/aggregat without + # the need for a later AggrStep) + # * the rqlst is splitted in multiple steps and there are more than one + # final step + if stepdefs is None: + return True + has_one_final = False + fstepsolindices = set() + for stepdef in stepdefs: + if stepdef[-1]: + if has_one_final or frozenset(stepdef[2]) != fstepsolindices: + return True + has_one_final = True + else: + fstepsolindices.update(stepdef[2]) + return False + +def copy_node(newroot, node, subparts=()): + newnode = node.__class__(*node.initargs(newroot)) + for part in subparts: + newnode.append(part) + return newnode + +def same_scope(var): + """return true if the variable is always used in the same scope""" + try: + return var.stinfo['samescope'] + except KeyError: + for rel in var.stinfo['relations']: + if not rel.scope is var.scope: + var.stinfo['samescope'] = False + return False + var.stinfo['samescope'] = True + return True + +def select_group_sort(select): # XXX something similar done in rql2sql + # add variables used in groups and sort terms to the selection + # if necessary + if select.groupby: + for vref in select.groupby: + if not vref in select.selection: + select.append_selected(vref.copy(select)) + for sortterm in select.orderby: + for vref in sortterm.iget_nodes(VariableRef): + if not vref in select.get_selected_variables(): + # we can't directly insert sortterm.term because it references + # a variable of the select before the copy. + # XXX if constant term are used to define sort, their value + # may necessite a decay + select.append_selected(vref.copy(select)) + if select.groupby and not vref in select.groupby: + select.add_group_var(vref.copy(select)) + + +class PartPlanInformation(object): + """regroups necessary information to execute some part of a "global" rql + query ("global" means as received by the querier, which may result in + several internal queries, e.g. 
parts, due to security insertions) + + it exposes as well some methods helping in executing this part on a + multi-sources repository, modifying its internal structure during the + process + + :attr solutions: a list of mappings (varname -> vartype) + :attr sourcesvars: + a dictionnary telling for each source which variable/solution are + supported, of the form {source : {varname: [solution index, ]}} + """ + def __init__(self, plan, rqlst, rqlhelper=None): + self.needsplit = False + self.temptable = None + self.finaltable = None + self.plan = plan + self.rqlst = rqlst + self._session = plan.session + self._solutions = rqlst.solutions + self._solindices = range(len(self._solutions)) + # source : {varname: [solution index, ]} + self._sourcesvars = {} + # dictionnary of variables which are linked to each other using a non + # final relation which is supported by multiple sources + self._linkedvars = {} + # processing + self._compute_sourcesvars() + self._remove_invalid_sources() + #if server.DEBUG: + # print 'planner sources vars', self._sourcesvars + self._compute_needsplit() + self._inputmaps = {} + if rqlhelper is not None: # else test + self._insert_identity_variable = rqlhelper._annotator.rewrite_shared_optional + + def copy_solutions(self, solindices): + return [self._solutions[solidx].copy() for solidx in solindices] + + @property + @cached + def part_sources(self): + if self._sourcesvars: + return tuple(sorted(self._sourcesvars)) + return (self._session.repo.system_source,) + + @property + @cached + def _sys_source_set(self): + return frozenset((self._session.repo.system_source, solindex) + for solindex in self._solindices) + + @cached + def _norel_support_set(self, rtype): + """return a set of (source, solindex) where source doesn't support the + relation + """ + return frozenset((source, solidx) for source in self._session.repo.sources + for solidx in self._solindices + if not (source.support_relation(rtype) + or rtype in source.dont_cross_relations)) + + def _compute_sourcesvars(self): + """compute for each variable/solution in the rqlst which sources support + them + """ + repo = self._session.repo + eschema = repo.schema.eschema + sourcesvars = self._sourcesvars + # find for each source which variable/solution are supported + for varname, varobj in self.rqlst.defined_vars.items(): + # if variable has an eid specified, we can get its source directly + # NOTE: use uidrels and not constnode to deal with "X eid IN(1,2,3,4)" + if varobj.stinfo['uidrels']: + vrels = varobj.stinfo['relations'] - varobj.stinfo['uidrels'] + for rel in varobj.stinfo['uidrels']: + if rel.neged(strict=True) or rel.operator() != '=': + continue + for const in rel.children[1].get_nodes(Constant): + eid = const.eval(self.plan.args) + source = self._session.source_from_eid(eid) + if vrels and not any(source.support_relation(r.r_type) + for r in vrels): + self._set_source_for_var(repo.system_source, varobj) + else: + self._set_source_for_var(source, varobj) + continue + rels = varobj.stinfo['relations'] + if not rels and not varobj.stinfo['typerels']: + # (rare) case where the variable has no type specified nor + # relation accessed ex. 
"Any MAX(X)" + self._set_source_for_var(repo.system_source, varobj) + continue + for i, sol in enumerate(self._solutions): + vartype = sol[varname] + # skip final variable + if eschema(vartype).is_final(): + break + for source in repo.sources: + if source.support_entity(vartype): + # the source support the entity type, though we will + # actually have to fetch from it only if + # * the variable isn't invariant + # * at least one supported relation specified + if not varobj._q_invariant or \ + any(imap(source.support_relation, + (r.r_type for r in rels if r.r_type != 'eid'))): + sourcesvars.setdefault(source, {}).setdefault(varobj, set()).add(i) + # if variable is not invariant and is used by a relation + # not supported by this source, we'll have to split the + # query + if not varobj._q_invariant and any(ifilterfalse( + source.support_relation, (r.r_type for r in rels))): + self.needsplit = True + + def _remove_invalid_sources(self): + """removes invalid sources from `sourcesvars` member""" + repo = self._session.repo + rschema = repo.schema.rschema + vsources = {} + for rel in self.rqlst.iget_nodes(Relation): + # process non final relations only + # note: don't try to get schema for 'is' relation (not available + # during bootstrap) + if not rel.is_types_restriction() and not rschema(rel.r_type).is_final(): + # nothing to do if relation is not supported by multiple sources + relsources = [source for source in repo.sources + if source.support_relation(rel.r_type) + or rel.r_type in source.dont_cross_relations] + if len(relsources) < 2: + if relsources:# and not relsources[0] in self._sourcesvars: + # this means the relation is using a variable inlined as + # a constant and another unsupported variable, in which + # case we put the relation in sourcesvars + self._sourcesvars.setdefault(relsources[0], {})[rel] = set(self._solindices) + continue + lhs, rhs = rel.get_variable_parts() + lhsv, rhsv = getattr(lhs, 'variable', lhs), getattr(rhs, 'variable', rhs) + # update dictionnary of sources supporting lhs and rhs vars + if not lhsv in vsources: + vsources[lhsv] = self._term_sources(lhs) + if not rhsv in vsources: + vsources[rhsv] = self._term_sources(rhs) + self._linkedvars.setdefault(lhsv, set()).add((rhsv, rel)) + self._linkedvars.setdefault(rhsv, set()).add((lhsv, rel)) + for term in self._linkedvars: + self._remove_sources_until_stable(term, vsources) + if len(self._sourcesvars) > 1 and hasattr(self.plan.rqlst, 'main_relations'): + # the querier doesn't annotate write queries, need to do it here + self.plan.annotate_rqlst() + # insert/update/delete queries, we may get extra information from + # the main relation (eg relations to the left of the WHERE + if self.plan.rqlst.TYPE == 'insert': + inserted = dict((vref.variable, etype) + for etype, vref in self.plan.rqlst.main_variables) + else: + inserted = {} + for rel in self.plan.rqlst.main_relations: + if not rschema(rel.r_type).is_final(): + # nothing to do if relation is not supported by multiple sources + relsources = [source for source in repo.sources + if source.support_relation(rel.r_type) + or rel.r_type in source.dont_cross_relations] + if len(relsources) < 2: + continue + lhs, rhs = rel.get_variable_parts() + try: + lhsv = self._extern_term(lhs, vsources, inserted) + rhsv = self._extern_term(rhs, vsources, inserted) + except KeyError, ex: + continue + norelsup = self._norel_support_set(rel.r_type) + self._remove_var_sources(lhsv, norelsup, rhsv, vsources) + self._remove_var_sources(rhsv, norelsup, lhsv, vsources) + # cleanup 
linked var + for var, linkedrelsinfo in self._linkedvars.iteritems(): + self._linkedvars[var] = frozenset(x[0] for x in linkedrelsinfo) + # if there are other sources than the system source, consider simplified + # variables'source + if self._sourcesvars and self._sourcesvars.keys() != [self._session.repo.system_source]: + # add source for rewritten constants to sourcesvars + for vconsts in self.rqlst.stinfo['rewritten'].itervalues(): + const = vconsts[0] + eid = const.eval(self.plan.args) + source = self._session.source_from_eid(eid) + if source is self._session.repo.system_source: + for const in vconsts: + self._set_source_for_var(source, const) + elif source in self._sourcesvars: + source_scopes = frozenset(v.scope for v in self._sourcesvars[source]) + for const in vconsts: + if const.scope in source_scopes: + self._set_source_for_var(source, const) + + def _extern_term(self, term, vsources, inserted): + var = term.variable + if var.stinfo['constnode']: + termv = var.stinfo['constnode'] + vsources[termv] = self._term_sources(termv) + elif var in inserted: + termv = var + source = self._session.repo.locate_etype_source(inserted[var]) + vsources[termv] = set((source, solindex) for solindex in self._solindices) + else: + termv = self.rqlst.defined_vars[var.name] + if not termv in vsources: + vsources[termv] = self._term_sources(termv) + return termv + + def _remove_sources_until_stable(self, var, vsources): + for ovar, rel in self._linkedvars.get(var, ()): + if not var.scope is ovar.scope and rel.scope.neged(strict=True): + # can't get information from relation inside a NOT exists + # where variables don't belong to the same scope + continue + if rel.neged(strict=True): + # neged relation doesn't allow to infer variable sources + continue + norelsup = self._norel_support_set(rel.r_type) + # compute invalid sources for variables and remove them + self._remove_var_sources(var, norelsup, ovar, vsources) + self._remove_var_sources(ovar, norelsup, var, vsources) + + def _remove_var_sources(self, var, norelsup, ovar, vsources): + """remove invalid sources for var according to ovar's sources and the + relation between those two variables. 
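+
+        A hypothetical illustration (S and L are made-up names standing for
+        the system source and an ldap source, with a single solution of
+        index 0):
+
+          vsources[var]  = set([(S, 0), (L, 0)])
+          vsources[ovar] = set([(S, 0)])
+          norelsup       = frozenset()  # both sources support the relation
+
+        invalid_sources is then set([(L, 0)]): var can't be fetched from L
+        since ovar isn't available there, so that pair is removed.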
+        """
+        varsources = vsources[var]
+        invalid_sources = varsources - (vsources[ovar] | norelsup)
+        if invalid_sources:
+            self._remove_sources(var, invalid_sources)
+            varsources -= invalid_sources
+            self._remove_sources_until_stable(var, vsources)
+
+    def _compute_needsplit(self):
+        """tell according to sourcesvars if the rqlst has to be split for
+        execution among multiple sources
+
+        the execution has to be split if
+        * a source supports an entity (non invariant) but doesn't support a
+          relation on it
+        * a source supports an entity which is accessed by an optional relation
+        * there is more than one source and either all sources' supported
+          variables/solutions are not equivalent or multiple variables have to
+          be fetched from some source
+        """
+        # NOTE: < 2 since may be 0 on queries such as Any X WHERE X eid 2
+        if len(self._sourcesvars) < 2:
+            self.needsplit = False
+        elif not self.needsplit:
+            if not allequals(self._sourcesvars.itervalues()):
+                self.needsplit = True
+            else:
+                sample = self._sourcesvars.itervalues().next()
+                if len(sample) > 1 and any(v for v in sample
+                                           if not v in self._linkedvars):
+                    self.needsplit = True
+
+    def _set_source_for_var(self, source, var):
+        self._sourcesvars.setdefault(source, {})[var] = set(self._solindices)
+
+    def _term_sources(self, term):
+        """return the set of possible (source, solution index) pairs for `term`"""
+        if isinstance(term, Constant):
+            source = self._session.source_from_eid(term.eval(self.plan.args))
+            return set((source, solindex) for solindex in self._solindices)
+        else:
+            var = getattr(term, 'variable', term)
+            sources = [source for source, varobjs in self._sourcesvars.iteritems()
+                       if var in varobjs]
+            return set((source, solindex) for source in sources
+                       for solindex in self._sourcesvars[source][var])
+
+    def _remove_sources(self, var, sources):
+        """remove invalid sources (`sources`) from `sourcesvars`
+
+        :param sources: the list of (source, solution index) pairs to remove
+        :param var: the analyzed variable
+        """
+        sourcesvars = self._sourcesvars
+        for source, solindex in sources:
+            try:
+                sourcesvars[source][var].remove(solindex)
+            except KeyError:
+                return # may occur with subquery column alias
+            if not sourcesvars[source][var]:
+                del sourcesvars[source][var]
+            if not sourcesvars[source]:
+                del sourcesvars[source]
+
+    def part_steps(self):
+        """precompute necessary part steps before generating actual rql for
+        each step. This is necessary to know whether an aggregate step will be
+        necessary or not.
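+
+        Each step definition appended here is a 6-uple. A hypothetical
+        example for a non final part handled by the system source and an
+        ldap source (all names made up):
+
+          ([system, ldap],   # sources to query
+           [X],              # variables (and relations/constants) handled
+           set([0]),         # supported solution indices
+           <Select node>,    # scope of the step
+           set(['L']),       # variable names to additionally select
+           False)            # not a final step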
+ """ + steps = [] + select = self.rqlst + rschema = self.plan.schema.rschema + for source in self.part_sources: + sourcevars = self._sourcesvars[source] + while sourcevars: + # take a variable randomly, and all variables supporting the + # same solutions + var, solindices = self._choose_var(sourcevars) + if source.uri == 'system': + # ensure all variables are available for the latest step + # (missing one will be available from temporary tables + # of previous steps) + scope = select + variables = scope.defined_vars.values() + scope.aliases.values() + sourcevars.clear() + else: + scope = var.scope + variables = self._expand_vars(var, sourcevars, scope, solindices) + if not sourcevars: + del self._sourcesvars[source] + # find which sources support the same variables/solutions + sources = self._expand_sources(source, variables, solindices) + # suppose this is a final step until the contrary is proven + final = scope is select + # set of variables which should be additionaly selected when + # possible + needsel = set() + # add attribute variables and mark variables which should be + # additionaly selected when possible + for var in select.defined_vars.itervalues(): + if not var in variables: + stinfo = var.stinfo + for ovar, rtype in stinfo['attrvars']: + if ovar in variables: + needsel.add(var.name) + variables.append(var) + break + else: + needsel.add(var.name) + final = False + if final and source.uri != 'system': + # check rewritten constants + for vconsts in select.stinfo['rewritten'].itervalues(): + const = vconsts[0] + eid = const.eval(self.plan.args) + _source = self._session.source_from_eid(eid) + if len(sources) > 1 or not _source in sources: + # if constant is only used by an identity relation, + # skip + for c in vconsts: + rel = c.relation() + if rel is None or not rel.neged(strict=True): + final = False + break + break + # check where all relations are supported by the sources + for rel in scope.iget_nodes(Relation): + if rel.is_types_restriction(): + continue + # take care not overwriting the existing "source" identifier + for _source in sources: + if not _source.support_relation(rel.r_type): + for vref in rel.iget_nodes(VariableRef): + needsel.add(vref.name) + final = False + break + else: + if not scope is select: + self._exists_relation(rel, variables, needsel) + # if relation is supported by all sources and some of + # its lhs/rhs variable isn't in "variables", and the + # other end *is* in "variables", mark it have to be + # selected + if source.uri != 'system' and not rschema(rel.r_type).is_final(): + lhs, rhs = rel.get_variable_parts() + try: + lhsvar = lhs.variable + except AttributeError: + lhsvar = lhs + try: + rhsvar = rhs.variable + except AttributeError: + rhsvar = rhs + if lhsvar in variables and not rhsvar in variables: + needsel.add(lhsvar.name) + elif rhsvar in variables and not lhsvar in variables: + needsel.add(rhsvar.name) + if final: + self._cleanup_sourcesvars(sources, solindices) + # XXX rename: variables may contain Relation and Constant nodes... + steps.append( (sources, variables, solindices, scope, needsel, + final) ) + return steps + + def _exists_relation(self, rel, variables, needsel): + rschema = self.plan.schema.rschema(rel.r_type) + lhs, rhs = rel.get_variable_parts() + try: + lhsvar, rhsvar = lhs.variable, rhs.variable + except AttributeError: + pass + else: + # supported relation with at least one end supported, check the + # other end is in as well. 
If not this usually means the + # variable is refed by an outer scope and should be substituted + # using an 'identity' relation (else we'll get a conflict of + # temporary tables) + if rhsvar in variables and not lhsvar in variables: + self._identity_substitute(rel, lhsvar, variables, needsel) + elif lhsvar in variables and not rhsvar in variables: + self._identity_substitute(rel, rhsvar, variables, needsel) + + def _identity_substitute(self, relation, var, variables, needsel): + newvar = self._insert_identity_variable(relation.scope, var) + if newvar is not None: + # ensure relation is using '=' operator, else we rely on a + # sqlgenerator side effect (it won't insert an inequality operator + # in this case) + relation.children[1].operator = '=' + variables.append(newvar) + needsel.add(newvar.name) + #self.insertedvars.append((var.name, self.schema['identity'], + # newvar.name)) + + def _choose_var(self, sourcevars): + secondchoice = None + if len(self._sourcesvars) > 1: + # priority to variable from subscopes + for var in sourcevars: + if not var.scope is self.rqlst: + if isinstance(var, Variable): + return var, sourcevars.pop(var) + secondchoice = var + else: + # priority to variable outer scope + for var in sourcevars: + if var.scope is self.rqlst: + if isinstance(var, Variable): + return var, sourcevars.pop(var) + secondchoice = var + if secondchoice is not None: + return secondchoice, sourcevars.pop(secondchoice) + # priority to variable + for var in sourcevars: + if isinstance(var, Variable): + return var, sourcevars.pop(var) + # whatever + var = iter(sourcevars).next() + return var, sourcevars.pop(var) + + def _expand_vars(self, var, sourcevars, scope, solindices): + variables = [var] + nbunlinked = 1 + linkedvars = self._linkedvars + # variable has to belong to the same scope if there is more + # than the system source remaining + if len(self._sourcesvars) > 1 and not scope is self.rqlst: + candidates = (v for v in sourcevars.keys() if scope is v.scope) + else: + candidates = sourcevars #.iterkeys() + candidates = [v for v in candidates + if isinstance(v, Constant) or + (solindices.issubset(sourcevars[v]) and v in linkedvars)] + # repeat until no variable can't be added, since addition of a new + # variable may permit to another one to be added + modified = True + while modified and candidates: + modified = False + for var in candidates[:]: + # we only want one unlinked variable in each generated query + if isinstance(var, Constant) or \ + any(v for v in variables if v in linkedvars[var]): + variables.append(var) + # constant nodes should be systematically deleted + if isinstance(var, Constant): + del sourcevars[var] + # variable nodes should be deleted once all possible solution + # indices have been consumed + else: + sourcevars[var] -= solindices + if not sourcevars[var]: + del sourcevars[var] + candidates.remove(var) + modified = True + return variables + + def _expand_sources(self, selected_source, vars, solindices): + sources = [selected_source] + sourcesvars = self._sourcesvars + for source in sourcesvars: + if source is selected_source: + continue + for var in vars: + if not (var in sourcesvars[source] and + solindices.issubset(sourcesvars[source][var])): + break + else: + sources.append(source) + if source.uri != 'system': + for var in vars: + varsolindices = sourcesvars[source][var] + varsolindices -= solindices + if not varsolindices: + del sourcesvars[source][var] + + return sources + + def _cleanup_sourcesvars(self, sources, solindices): + """on final parts, remove 
solutions so we know they are already processed""" + for source in sources: + try: + sourcevar = self._sourcesvars[source] + except KeyError: + continue + for var, varsolindices in sourcevar.items(): + varsolindices -= solindices + if not varsolindices: + del sourcevar[var] + + def merge_input_maps(self, allsolindices): + """inputmaps is a dictionary with tuple of solution indices as key with an + associateed input map as value. This function compute for each solution + its necessary input map and return them grouped + + ex: + inputmaps = {(0, 1, 2): {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1'}, + (1,): {'X': 't2.C0', 'T': 't2.C1'}} + return : [([1], {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1', + 'X': 't2.C0', 'T': 't2.C1'}), + ([0,2], {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1'})] + """ + if not self._inputmaps: + return [(allsolindices, None)] + mapbysol = {} + # compute a single map for each solution + for solindices, basemap in self._inputmaps.iteritems(): + for solindex in solindices: + solmap = mapbysol.setdefault(solindex, {}) + solmap.update(basemap) + try: + allsolindices.remove(solindex) + except KeyError: + continue # already removed + # group results by identical input map + result = [] + for solindex, solmap in mapbysol.iteritems(): + for solindices, commonmap in result: + if commonmap == solmap: + solindices.append(solindex) + break + else: + result.append( ([solindex], solmap) ) + if allsolindices: + result.append( (list(allsolindices), None) ) + return result + + def build_final_part(self, select, solindices, inputmap, sources, + insertedvars): + plan = self.plan + rqlst = plan.finalize(select, [self._solutions[i] for i in solindices], + insertedvars) + if self.temptable is None and self.finaltable is None: + return OneFetchStep(plan, rqlst, sources, inputmap=inputmap) + table = self.temptable or self.finaltable + return FetchStep(plan, rqlst, sources, table, True, inputmap) + + def build_non_final_part(self, select, solindices, sources, insertedvars, + table): + """non final step, will have to store results in a temporary table""" + plan = self.plan + rqlst = plan.finalize(select, [self._solutions[i] for i in solindices], + insertedvars) + step = FetchStep(plan, rqlst, sources, table, False) + # update input map for following steps, according to processed solutions + inputmapkey = tuple(sorted(solindices)) + inputmap = self._inputmaps.setdefault(inputmapkey, {}) + inputmap.update(step.outputmap) + plan.add_step(step) + + +class MSPlanner(SSPlanner): + """MultiSourcesPlanner: build execution plan for rql queries + + decompose the RQL query according to sources'schema + """ + + def build_select_plan(self, plan, rqlst): + """build execution plan for a SELECT RQL query + + the rqlst should not be tagged at this point + """ + if server.DEBUG: + print '-'*80 + print 'PLANNING', rqlst + for select in rqlst.children: + if len(select.solutions) > 1: + hasmultiplesols = True + break + else: + hasmultiplesols = False + # preprocess deals with security insertion and returns a new syntax tree + # which have to be executed to fulfill the query: according + # to permissions for variable's type, different rql queries may have to + # be executed + plan.preprocess(rqlst) + ppis = [PartPlanInformation(plan, select, self.rqlhelper) + for select in rqlst.children] + steps = self._union_plan(plan, rqlst, ppis) + if server.DEBUG: + from pprint import pprint + for step in plan.steps: + pprint(step.test_repr()) + pprint(steps[0].test_repr()) + return steps + + 
def _ppi_subqueries(self, ppi): + # part plan info for subqueries + plan = ppi.plan + inputmap = {} + for subquery in ppi.rqlst.with_[:]: + sppis = [PartPlanInformation(plan, select) + for select in subquery.query.children] + for sppi in sppis: + if sppi.needsplit or sppi.part_sources != ppi.part_sources: + temptable = 'T%s' % make_uid(id(subquery)) + sstep = self._union_plan(plan, subquery.query, sppis, temptable)[0] + break + else: + sstep = None + if sstep is not None: + ppi.rqlst.with_.remove(subquery) + for i, colalias in enumerate(subquery.aliases): + inputmap[colalias.name] = '%s.C%s' % (temptable, i) + ppi.plan.add_step(sstep) + return inputmap + + def _union_plan(self, plan, union, ppis, temptable=None): + tosplit, cango, allsources = [], {}, set() + for planinfo in ppis: + if planinfo.needsplit: + tosplit.append(planinfo) + else: + cango.setdefault(planinfo.part_sources, []).append(planinfo) + for source in planinfo.part_sources: + allsources.add(source) + # first add steps for query parts which doesn't need to splitted + steps = [] + for sources, cppis in cango.iteritems(): + byinputmap = {} + for ppi in cppis: + select = ppi.rqlst + if sources != (plan.session.repo.system_source,): + add_types_restriction(self.schema, select) + # part plan info for subqueries + inputmap = self._ppi_subqueries(ppi) + aggrstep = need_aggr_step(select, sources) + if aggrstep: + atemptable = 'T%s' % make_uid(id(select)) + sunion = Union() + sunion.append(select) + selected = select.selection[:] + select_group_sort(select) + step = AggrStep(plan, selected, select, atemptable, temptable) + step.set_limit_offset(select.limit, select.offset) + select.limit = None + select.offset = 0 + fstep = FetchStep(plan, sunion, sources, atemptable, True, inputmap) + step.children.append(fstep) + steps.append(step) + else: + byinputmap.setdefault(tuple(inputmap.iteritems()), []).append( (select) ) + for inputmap, queries in byinputmap.iteritems(): + inputmap = dict(inputmap) + sunion = Union() + for select in queries: + sunion.append(select) + if temptable: + steps.append(FetchStep(plan, sunion, sources, temptable, True, inputmap)) + else: + steps.append(OneFetchStep(plan, sunion, sources, inputmap)) + # then add steps for splitted query parts + for planinfo in tosplit: + steps.append(self.split_part(planinfo, temptable)) + if len(steps) > 1: + if temptable: + step = UnionFetchStep(plan) + else: + step = UnionStep(plan) + step.children = steps + return (step,) + return steps + + # internal methods for multisources decomposition ######################### + + def split_part(self, ppi, temptable): + ppi.finaltable = temptable + plan = ppi.plan + select = ppi.rqlst + subinputmap = self._ppi_subqueries(ppi) + stepdefs = ppi.part_steps() + if need_aggr_step(select, ppi.part_sources, stepdefs): + atemptable = 'T%s' % make_uid(id(select)) + selection = select.selection[:] + select_group_sort(select) + else: + atemptable = None + selection = select.selection + ppi.temptable = atemptable + vfilter = VariablesFiltererVisitor(self.schema, ppi) + steps = [] + for sources, variables, solindices, scope, needsel, final in stepdefs: + # extract an executable query using only the specified variables + if sources[0].uri == 'system': + # in this case we have to merge input maps before call to + # filter so already processed restriction are correctly + # removed + solsinputmaps = ppi.merge_input_maps(solindices) + for solindices, inputmap in solsinputmaps: + minrqlst, insertedvars = vfilter.filter( + sources, variables, scope, 
set(solindices), needsel, final) + if inputmap is None: + inputmap = subinputmap + else: + inputmap.update(subinputmap) + steps.append(ppi.build_final_part(minrqlst, solindices, inputmap, + sources, insertedvars)) + else: + # this is a final part (i.e. retreiving results for the + # original query part) if all variable / sources have been + # treated or if this is the last shot for used solutions + minrqlst, insertedvars = vfilter.filter( + sources, variables, scope, solindices, needsel, final) + if final: + solsinputmaps = ppi.merge_input_maps(solindices) + for solindices, inputmap in solsinputmaps: + if inputmap is None: + inputmap = subinputmap + else: + inputmap.update(subinputmap) + steps.append(ppi.build_final_part(minrqlst, solindices, inputmap, + sources, insertedvars)) + else: + table = '_T%s%s' % (''.join(sorted(v._ms_table_key() for v in variables)), + ''.join(sorted(str(i) for i in solindices))) + ppi.build_non_final_part(minrqlst, solindices, sources, + insertedvars, table) + # finally: join parts, deal with aggregat/group/sorts if necessary + if atemptable is not None: + step = AggrStep(plan, selection, select, atemptable, temptable) + step.children = steps + elif len(steps) > 1: + if temptable: + step = UnionFetchStep(plan) + else: + step = UnionStep(plan) + step.children = steps + else: + step = steps[0] + if select.limit is not None or select.offset: + step.set_limit_offset(select.limit, select.offset) + return step + + +class UnsupportedBranch(Exception): + pass + + +class VariablesFiltererVisitor(object): + def __init__(self, schema, ppi): + self.schema = schema + self.ppi = ppi + self.skip = {} + self.hasaggrstep = self.ppi.temptable + self.extneedsel = frozenset(vref.name for sortterm in ppi.rqlst.orderby + for vref in sortterm.iget_nodes(VariableRef)) + + def _rqlst_accept(self, rqlst, node, newroot, variables, setfunc=None): + try: + newrestr, node_ = node.accept(self, newroot, variables[:]) + except UnsupportedBranch: + return rqlst + if setfunc is not None and newrestr is not None: + setfunc(newrestr) + if not node_ is node: + rqlst = node.parent + return rqlst + + def filter(self, sources, variables, rqlst, solindices, needsel, final): + if server.DEBUG: + print 'filter', final and 'final' or '', sources, variables, rqlst, solindices, needsel + newroot = Select() + self.sources = sources + self.solindices = solindices + self.final = final + # variables which appear in unsupported branches + needsel |= self.extneedsel + self.needsel = needsel + # variables which appear in supported branches + self.mayneedsel = set() + # new inserted variables + self.insertedvars = [] + # other structures (XXX document) + self.mayneedvar, self.hasvar = {}, {} + self.use_only_defined = False + self.scopes = {rqlst: newroot} + if rqlst.where: + rqlst = self._rqlst_accept(rqlst, rqlst.where, newroot, variables, + newroot.set_where) + if isinstance(rqlst, Select): + self.use_only_defined = True + if rqlst.groupby: + groupby = [] + for node in rqlst.groupby: + rqlst = self._rqlst_accept(rqlst, node, newroot, variables, + groupby.append) + if groupby: + newroot.set_groupby(groupby) + if rqlst.having: + having = [] + for node in rqlst.having: + rqlst = self._rqlst_accept(rqlst, node, newroot, variables, + having.append) + if having: + newroot.set_having(having) + if final and rqlst.orderby and not self.hasaggrstep: + orderby = [] + for node in rqlst.orderby: + rqlst = self._rqlst_accept(rqlst, node, newroot, variables, + orderby.append) + if orderby: + newroot.set_orderby(orderby) + 
self.process_selection(newroot, variables, rqlst) + elif not newroot.where: + # no restrictions have been copied, just select variables and add + # type restriction (done later by add_types_restriction) + for v in variables: + if not isinstance(v, Variable): + continue + newroot.append_selected(VariableRef(newroot.get_variable(v.name))) + solutions = self.ppi.copy_solutions(solindices) + cleanup_solutions(newroot, solutions) + newroot.set_possible_types(solutions) + if final: + if self.hasaggrstep: + self.add_necessary_selection(newroot, self.mayneedsel & self.extneedsel) + newroot.distinct = rqlst.distinct + else: + self.add_necessary_selection(newroot, self.mayneedsel & self.needsel) + # insert vars to fetch constant values when needed + for (varname, rschema), reldefs in self.mayneedvar.iteritems(): + for rel, ored in reldefs: + if not (varname, rschema) in self.hasvar: + self.hasvar[(varname, rschema)] = None # just to avoid further insertion + cvar = newroot.make_variable() + for sol in newroot.solutions: + sol[cvar.name] = rschema.objects(sol[varname])[0] + # if the current restriction is not used in a OR branch, + # we can keep it, else we have to drop the constant + # restriction (or we may miss some results) + if not ored: + rel = rel.copy(newroot) + newroot.add_restriction(rel) + # add a relation to link the variable + newroot.remove_node(rel.children[1]) + cmp = Comparison('=') + rel.append(cmp) + cmp.append(VariableRef(cvar)) + self.insertedvars.append((varname, rschema, cvar.name)) + newroot.append_selected(VariableRef(newroot.get_variable(cvar.name))) + # NOTE: even if the restriction is done by this query, we have + # to let it in the original rqlst so that it appears anyway in + # the "final" query, else we may change the meaning of the query + # if there are NOT somewhere : + # 'NOT X relation Y, Y name "toto"' means X WHERE X isn't related + # to Y whose name is toto while + # 'NOT X relation Y' means X WHERE X has no 'relation' (whatever Y) + elif ored: + newroot.remove_node(rel) + add_types_restriction(self.schema, rqlst, newroot, solutions) + if server.DEBUG: + print '--->', newroot + return newroot, self.insertedvars + + def visit_and(self, node, newroot, variables): + subparts = [] + for i in xrange(len(node.children)): + child = node.children[i] + try: + newchild, child_ = child.accept(self, newroot, variables) + if not child_ is child: + node = child_.parent + if newchild is None: + continue + subparts.append(newchild) + except UnsupportedBranch: + continue + if not subparts: + return None, node + if len(subparts) == 1: + return subparts[0], node + return copy_node(newroot, node, subparts), node + + visit_or = visit_and + + def _relation_supported(self, rtype): + for source in self.sources: + if not source.support_relation(rtype): + return False + return True + + def visit_relation(self, node, newroot, variables): + if not node.is_types_restriction(): + if node in self.skip and self.solindices.issubset(self.skip[node]): + if not self.schema.rschema(node.r_type).is_final(): + # can't really skip the relation if one variable is selected and only + # referenced by this relation + for vref in node.iget_nodes(VariableRef): + stinfo = vref.variable.stinfo + if stinfo['selected'] and len(stinfo['relations']) == 1: + break + else: + return None, node + else: + return None, node + if not self._relation_supported(node.r_type): + raise UnsupportedBranch() + # don't copy type restriction unless this is the only relation for the + # rhs variable, else they'll be reinserted 
later as needed (else we may + # copy a type restriction while the variable is not actually used) + elif not any(self._relation_supported(rel.r_type) + for rel in node.children[0].variable.stinfo['relations']): + rel, node = self.visit_default(node, newroot, variables) + return rel, node + else: + raise UnsupportedBranch() + rschema = self.schema.rschema(node.r_type) + res = self.visit_default(node, newroot, variables)[0] + ored = node.ored() + if rschema.is_final() or rschema.inlined: + vrefs = node.children[1].get_nodes(VariableRef) + if not vrefs: + if not ored: + self.skip.setdefault(node, set()).update(self.solindices) + else: + self.mayneedvar.setdefault((node.children[0].name, rschema), []).append( (res, ored) ) + + else: + assert len(vrefs) == 1 + vref = vrefs[0] + # XXX check operator ? + self.hasvar[(node.children[0].name, rschema)] = vref + if self._may_skip_attr_rel(rschema, node, vref, ored, variables, res): + self.skip.setdefault(node, set()).update(self.solindices) + elif not ored: + self.skip.setdefault(node, set()).update(self.solindices) + return res, node + + def _may_skip_attr_rel(self, rschema, rel, vref, ored, variables, res): + var = vref.variable + if ored: + return False + if var.name in self.extneedsel or var.stinfo['selected']: + return False + if not same_scope(var): + return False + if any(v for v,_ in var.stinfo['attrvars'] if not v.name in variables): + return False + return True + + def visit_exists(self, node, newroot, variables): + newexists = node.__class__() + self.scopes = {node: newexists} + subparts, node = self._visit_children(node, newroot, variables) + if not subparts: + return None, node + newexists.set_where(subparts[0]) + return newexists, node + + def visit_not(self, node, newroot, variables): + subparts, node = self._visit_children(node, newroot, variables) + if not subparts: + return None, node + return copy_node(newroot, node, subparts), node + + def visit_group(self, node, newroot, variables): + if not self.final: + return None, node + return self.visit_default(node, newroot, variables) + + def visit_variableref(self, node, newroot, variables): + if self.use_only_defined: + if not node.variable.name in newroot.defined_vars: + raise UnsupportedBranch(node.name) + elif not node.variable in variables: + raise UnsupportedBranch(node.name) + self.mayneedsel.add(node.name) + # set scope so we can insert types restriction properly + newvar = newroot.get_variable(node.name) + newvar.stinfo['scope'] = self.scopes.get(node.variable.scope, newroot) + return VariableRef(newvar), node + + def visit_constant(self, node, newroot, variables): + return copy_node(newroot, node), node + + def visit_default(self, node, newroot, variables): + subparts, node = self._visit_children(node, newroot, variables) + return copy_node(newroot, node, subparts), node + + visit_comparison = visit_mathexpression = visit_constant = visit_function = visit_default + visit_sort = visit_sortterm = visit_default + + def _visit_children(self, node, newroot, variables): + subparts = [] + for i in xrange(len(node.children)): + child = node.children[i] + newchild, child_ = child.accept(self, newroot, variables) + if not child is child_: + node = child_.parent + if newchild is not None: + subparts.append(newchild) + return subparts, node + + def process_selection(self, newroot, variables, rqlst): + if self.final: + for term in rqlst.selection: + newroot.append_selected(term.copy(newroot)) + for vref in term.get_nodes(VariableRef): + self.needsel.add(vref.name) + return + for term in 
rqlst.selection: + vrefs = term.get_nodes(VariableRef) + if vrefs: + supportedvars = [] + for vref in vrefs: + var = vref.variable + if var in variables: + supportedvars.append(vref) + continue + else: + self.needsel.add(vref.name) + break + else: + for vref in vrefs: + newroot.append_selected(vref.copy(newroot)) + supportedvars = [] + for vref in supportedvars: + if not vref in newroot.get_selected_variables(): + newroot.append_selected(VariableRef(newroot.get_variable(vref.name))) + + def add_necessary_selection(self, newroot, variables): + selected = tuple(newroot.get_selected_variables()) + for varname in variables: + var = newroot.defined_vars[varname] + for vref in var.references(): + rel = vref.relation() + if rel is None and vref in selected: + # already selected + break + else: + selvref = VariableRef(var) + newroot.append_selected(selvref) + if newroot.groupby: + newroot.add_group_var(VariableRef(selvref.variable, noautoref=1)) diff -r 3dbee583526c -r 4c7d3af7e94d server/mssteps.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/server/mssteps.py Mon Dec 22 17:34:15 2008 +0100 @@ -0,0 +1,275 @@ +"""Defines the diferent querier steps usable in plans. + +FIXME : this code needs refactoring. Some problems : +* get data from the parent plan, the latest step, temporary table... +* each step has is own members (this is not necessarily bad, but a bit messy + for now) + +:organization: Logilab +:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr +""" +__docformat__ = "restructuredtext en" + +from rql.nodes import VariableRef, Variable, Function + +from cubicweb.server.ssplanner import (LimitOffsetMixIn, Step, OneFetchStep, + varmap_test_repr, offset_result) + +AGGR_TRANSFORMS = {'COUNT':'SUM', 'MIN':'MIN', 'MAX':'MAX', 'SUM': 'SUM'} + +def remove_clauses(union, keepgroup): + clauses = [] + for select in union.children: + if keepgroup: + having, orderby = select.having, select.orderby + select.having, select.orderby = None, None + clauses.append( (having, orderby) ) + else: + groupby, having, orderby = select.groupby, select.having, select.orderby + select.groupby, select.having, select.orderby = None, None, None + clauses.append( (groupby, having, orderby) ) + return clauses + +def restore_clauses(union, keepgroup, clauses): + for i, select in enumerate(union.children): + if keepgroup: + select.having, select.orderby = clauses[i] + else: + select.groupby, select.having, select.orderby = clauses[i] + + +class FetchStep(OneFetchStep): + """step consisting in fetching data from sources, and storing result in + a temporary table + """ + def __init__(self, plan, union, sources, table, keepgroup, inputmap=None): + OneFetchStep.__init__(self, plan, union, sources) + # temporary table to store step result + self.table = table + # should groupby clause be kept or not + self.keepgroup = keepgroup + # variables mapping to use as input + self.inputmap = inputmap + # output variable mapping + srqlst = union.children[0] # sample select node + # add additional information to the output mapping + self.outputmap = plan.init_temp_table(table, srqlst.selection, + srqlst.solutions[0]) + for vref in srqlst.selection: + if not isinstance(vref, VariableRef): + continue + var = vref.variable + if var.stinfo['attrvars']: + for lhsvar, rtype in var.stinfo['attrvars']: + if lhsvar.name in srqlst.defined_vars: + key = '%s.%s' % (lhsvar.name, rtype) + self.outputmap[key] = self.outputmap[var.name] + else: + rschema = 
self.plan.schema.rschema + for rel in var.stinfo['rhsrelations']: + if rschema(rel.r_type).inlined: + lhsvar = rel.children[0] + if lhsvar.name in srqlst.defined_vars: + key = '%s.%s' % (lhsvar.name, rel.r_type) + self.outputmap[key] = self.outputmap[var.name] + + def execute(self): + """execute this step""" + self.execute_children() + plan = self.plan + plan.create_temp_table(self.table) + union = self.union + # XXX 2.5 use "with" + clauses = remove_clauses(union, self.keepgroup) + for source in self.sources: + source.flying_insert(self.table, plan.session, union, plan.args, + self.inputmap) + restore_clauses(union, self.keepgroup, clauses) + + def mytest_repr(self): + """return a representation of this step suitable for test""" + clauses = remove_clauses(self.union, self.keepgroup) + try: + inputmap = varmap_test_repr(self.inputmap, self.plan.tablesinorder) + outputmap = varmap_test_repr(self.outputmap, self.plan.tablesinorder) + except AttributeError: + inputmap = self.inputmap + outputmap = self.outputmap + try: + return (self.__class__.__name__, + sorted((r.as_string(kwargs=self.plan.args), r.solutions) + for r in self.union.children), + sorted(self.sources), inputmap, outputmap) + finally: + restore_clauses(self.union, self.keepgroup, clauses) + + +class AggrStep(LimitOffsetMixIn, Step): + """step consisting in making aggregat from temporary data in the system + source + """ + def __init__(self, plan, selection, select, table, outputtable=None): + Step.__init__(self, plan) + # original selection + self.selection = selection + # original Select RQL tree + self.select = select + # table where are located temporary results + self.table = table + # optional table where to write results + self.outputtable = outputtable + if outputtable is not None: + plan.init_temp_table(outputtable, selection, select.solutions[0]) + + #self.inputmap = inputmap + + def mytest_repr(self): + """return a representation of this step suitable for test""" + sel = self.select.selection + restr = self.select.where + self.select.selection = self.selection + self.select.where = None + rql = self.select.as_string(kwargs=self.plan.args) + self.select.selection = sel + self.select.where = restr + try: + # rely on a monkey patch (cf unittest_querier) + table = self.plan.tablesinorder[self.table] + outputtable = self.outputtable and self.plan.tablesinorder[self.outputtable] + except AttributeError: + # not monkey patched + table = self.table + outputtable = self.outputtable + return (self.__class__.__name__, rql, self.limit, self.offset, table, + outputtable) + + def execute(self): + """execute this step""" + self.execute_children() + self.inputmap = inputmap = self.children[-1].outputmap + # get the select clause + clause = [] + for i, term in enumerate(self.selection): + try: + var_name = inputmap[term.as_string()] + except KeyError: + var_name = 'C%s' % i + if isinstance(term, Function): + # we have to translate some aggregat function + # (for instance COUNT -> SUM) + orig_name = term.name + try: + term.name = AGGR_TRANSFORMS[term.name] + # backup and reduce children + orig_children = term.children + term.children = [VariableRef(Variable(var_name))] + clause.append(term.accept(self)) + # restaure the tree XXX necessary? 
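+                    # note: the AGGR_TRANSFORMS rewrite above is what makes,
+                    # for instance, a per-source COUNT(X) stored in some
+                    # temporary table column (say C0, an illustrative name)
+                    # become SUM(C0): summing the partial counts coming from
+                    # each source yields the global count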
+ term.name = orig_name + term.children = orig_children + except KeyError: + clause.append(var_name) + else: + clause.append(var_name) + for vref in term.iget_nodes(VariableRef): + inputmap[vref.name] = var_name + # XXX handle distinct with non selected sort term + if self.select.distinct: + sql = ['SELECT DISTINCT %s' % ', '.join(clause)] + else: + sql = ['SELECT %s' % ', '.join(clause)] + sql.append("FROM %s" % self.table) + # get the group/having clauses + if self.select.groupby: + clause = [inputmap[var.name] for var in self.select.groupby] + grouped = set(var.name for var in self.select.groupby) + sql.append('GROUP BY %s' % ', '.join(clause)) + else: + grouped = None + if self.select.having: + clause = [term.accept(self) for term in self.select.having] + sql.append('HAVING %s' % ', '.join(clause)) + # get the orderby clause + if self.select.orderby: + clause = [] + for sortterm in self.select.orderby: + sqlterm = sortterm.term.accept(self) + if sortterm.asc: + clause.append(sqlterm) + else: + clause.append('%s DESC' % sqlterm) + if grouped is not None: + for vref in sortterm.iget_nodes(VariableRef): + if not vref.name in grouped: + sql[-1] += ', ' + self.inputmap[vref.name] + grouped.add(vref.name) + sql.append('ORDER BY %s' % ', '.join(clause)) + if self.limit: + sql.append('LIMIT %s' % self.limit) + if self.offset: + sql.append('OFFSET %s' % self.offset) + #print 'DATA', plan.sqlexec('SELECT * FROM %s' % self.table, None) + sql = ' '.join(sql) + if self.outputtable: + self.plan.create_temp_table(self.outputtable) + sql = 'INSERT INTO %s %s' % (self.outputtable, sql) + return self.plan.sqlexec(sql, self.plan.args) + + def visit_function(self, function): + """generate SQL name for a function""" + return '%s(%s)' % (function.name, + ','.join(c.accept(self) for c in function.children)) + + def visit_variableref(self, variableref): + """get the sql name for a variable reference""" + try: + return self.inputmap[variableref.name] + except KeyError: # XXX duh? explain + return variableref.variable.name + + def visit_constant(self, constant): + """generate SQL name for a constant""" + assert constant.type == 'Int' + return str(constant.value) + + +class UnionStep(LimitOffsetMixIn, Step): + """union results of child in-memory steps (e.g. OneFetchStep / AggrStep)""" + + def execute(self): + """execute this step""" + result = [] + limit = olimit = self.limit + offset = self.offset + assert offset != 0 + if offset is not None: + limit = limit + offset + for step in self.children: + if limit is not None: + if offset is None: + limit = olimit - len(result) + step.set_limit_offset(limit, None) + result_ = step.execute() + if offset is not None: + offset, result_ = offset_result(offset, result_) + result += result_ + if limit is not None: + if len(result) >= olimit: + return result[:olimit] + return result + + def mytest_repr(self): + """return a representation of this step suitable for test""" + return (self.__class__.__name__, self.limit, self.offset) + + +class UnionFetchStep(Step): + """union results of child steps using temporary tables (e.g. 
FetchStep)""" + + def execute(self): + """execute this step""" + self.execute_children() + + +__all__ = ('FetchStep', 'AggrStep', 'UnionStep', 'UnionFetchStep') diff -r 3dbee583526c -r 4c7d3af7e94d server/sources/extlite.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/server/sources/extlite.py Mon Dec 22 17:34:15 2008 +0100 @@ -0,0 +1,247 @@ +"""provide an abstract class for external sources using a sqlite database helper + +:organization: Logilab +:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr +""" +__docformat__ = "restructuredtext en" + + +import time +import threading +from os.path import join, exists + +from cubicweb import server +from cubicweb.server.sqlutils import sqlexec, SQLAdapterMixIn +from cubicweb.server.sources import AbstractSource, native +from cubicweb.server.sources.rql2sql import SQLGenerator + +def timeout_acquire(lock, timeout): + while not lock.acquire(False): + time.sleep(0.2) + timeout -= 0.2 + if timeout <= 0: + raise RuntimeError("svn source is busy, can't acquire connection lock") + +class ConnectionWrapper(object): + def __init__(self, source=None): + self.source = source + self._cnx = None + + @property + def cnx(self): + if self._cnx is None: + timeout_acquire(self.source._cnxlock, 5) + self._cnx = self.source._sqlcnx + return self._cnx + + def commit(self): + if self._cnx is not None: + self._cnx.commit() + + def rollback(self): + if self._cnx is not None: + self._cnx.rollback() + + def cursor(self): + return self.cnx.cursor() + + +class SQLiteAbstractSource(AbstractSource): + """an abstract class for external sources using a sqlite database helper + """ + sqlgen_class = SQLGenerator + @classmethod + def set_nonsystem_types(cls): + # those entities are only in this source, we don't want them in the + # system source + for etype in cls.support_entities: + native.NONSYSTEM_ETYPES.add(etype) + for rtype in cls.support_relations: + native.NONSYSTEM_RELATIONS.add(rtype) + + options = ( + ('helper-db-path', + {'type' : 'string', + 'default': None, + 'help': 'path to the sqlite database file used to do queries on the \ +repository.', + 'inputlevel': 2, + }), + ) + + def __init__(self, repo, appschema, source_config, *args, **kwargs): + # the helper db is used to easy querying and will store everything but + # actual file content + dbpath = source_config.get('helper-db-path') + if dbpath is None: + dbpath = join(repo.config.appdatahome, + '%(uri)s.sqlite' % source_config) + self.dbpath = dbpath + self.sqladapter = SQLAdapterMixIn({'db-driver': 'sqlite', + 'db-name': dbpath}) + # those attributes have to be initialized before ancestor's __init__ + # which will call set_schema + self._need_sql_create = not exists(dbpath) + self._need_full_import = self._need_sql_create + AbstractSource.__init__(self, repo, appschema, source_config, + *args, **kwargs) + # sql database can only be accessed by one connection at a time, and a + # connection can only be used by the thread which created it so: + # * create the connection when needed + # * use a lock to be sure only one connection is used + self._cnxlock = threading.Lock() + + @property + def _sqlcnx(self): + # XXX: sqlite connections can only be used in the same thread, so + # create a new one each time necessary. 
If it appears to be time + # consuming, find another way + return self.sqladapter.get_connection() + + def _is_schema_complete(self): + for etype in self.support_entities: + if not etype in self.schema: + self.warning('not ready to generate %s database, %s support missing from schema', + self.uri, etype) + return False + for rtype in self.support_relations: + if not rtype in self.schema: + self.warning('not ready to generate %s database, %s support missing from schema', + self.uri, rtype) + return False + return True + + def _create_database(self): + from yams.schema2sql import eschema2sql, rschema2sql + from cubicweb.toolsutils import restrict_perms_to_user + self.warning('initializing sqlite database for %s source' % self.uri) + cnx = self._sqlcnx + cu = cnx.cursor() + schema = self.schema + for etype in self.support_entities: + eschema = schema.eschema(etype) + createsqls = eschema2sql(self.sqladapter.dbhelper, eschema, + skip_relations=('data',)) + sqlexec(createsqls, cu, withpb=False) + for rtype in self.support_relations: + rschema = schema.rschema(rtype) + if not rschema.inlined: + sqlexec(rschema2sql(rschema), cu, withpb=False) + cnx.commit() + cnx.close() + self._need_sql_create = False + if self.repo.config['uid']: + from logilab.common.shellutils import chown + # database file must be owned by the uid of the server process + self.warning('set %s as owner of the database file', + self.repo.config['uid']) + chown(self.dbpath, self.repo.config['uid']) + restrict_perms_to_user(self.dbpath, self.info) + + def set_schema(self, schema): + super(SQLiteAbstractSource, self).set_schema(schema) + if self._need_sql_create and self._is_schema_complete(): + self._create_database() + self.rqlsqlgen = self.sqlgen_class(schema, self.sqladapter.dbhelper) + + def get_connection(self): + return ConnectionWrapper(self) + + def check_connection(self, cnx): + """check connection validity, return None if the connection is still valid + else a new connection (called when the pool using the given connection is + being attached to a session) + + always return the connection to reset eventually cached cursor + """ + return cnx + + def pool_reset(self, cnx): + """the pool using the given connection is being reseted from its current + attached session: release the connection lock if the connection wrapper + has a connection set + """ + if cnx._cnx is not None: + try: + cnx._cnx.close() + cnx._cnx = None + finally: + self._cnxlock.release() + + def syntax_tree_search(self, session, union, + args=None, cachekey=None, varmap=None, debug=0): + """return result from this source for a rql query (actually from a rql + syntax tree and a solution dictionary mapping each used variable to a + possible type). If cachekey is given, the query necessary to fetch the + results (but not the results themselves) may be cached using this key. + """ + if self._need_sql_create: + return [] + sql, query_args = self.rqlsqlgen.generate(union, args) + if server.DEBUG: + print self.uri, 'SOURCE RQL', union.as_string() + print 'GENERATED SQL', sql + args = self.sqladapter.merge_args(args, query_args) + cursor = session.pool[self.uri] + cursor.execute(sql, args) + return self.sqladapter.process_result(cursor) + + def local_add_entity(self, session, entity): + """insert the entity in the local database. 
+
+        This is not provided as the add_entity implementation since sources
+        usually don't want to simply do this; add_entity raises
+        NotImplementedError by default and source implementors may call this
+        method when appropriate
+        """
+        cu = session.pool[self.uri]
+        attrs = self.sqladapter.preprocess_entity(entity)
+        sql = self.sqladapter.sqlgen.insert(str(entity.e_schema), attrs)
+        cu.execute(sql, attrs)
+
+    def add_entity(self, session, entity):
+        """add a new entity to the source"""
+        raise NotImplementedError()
+
+    def local_update_entity(self, session, entity):
+        """update an entity in the local database.
+
+        This is not provided as the update_entity implementation since sources
+        usually don't want to simply do this; update_entity raises
+        NotImplementedError by default and source implementors may call this
+        method when appropriate
+        """
+        cu = session.pool[self.uri]
+        attrs = self.sqladapter.preprocess_entity(entity)
+        sql = self.sqladapter.sqlgen.update(str(entity.e_schema), attrs, ['eid'])
+        cu.execute(sql, attrs)
+
+    def update_entity(self, session, entity):
+        """update an entity in the source"""
+        raise NotImplementedError()
+
+    def delete_entity(self, session, etype, eid):
+        """delete an entity from the source
+
+        this does not delete anything in the external system (e.g. a file in a
+        svn repository) but deletes entities from this source's database. Main
+        usage is to delete source content when the entity pointing to it (e.g.
+        a Repository entity) is deleted.
+        """
+        sqlcursor = session.pool[self.uri]
+        attrs = {'eid': eid}
+        sql = self.sqladapter.sqlgen.delete(etype, attrs)
+        sqlcursor.execute(sql, attrs)
+
+    def delete_relation(self, session, subject, rtype, object):
+        """delete a relation from the source"""
+        rschema = self.schema.rschema(rtype)
+        if rschema.inlined:
+            if subject in session.query_data('pendingeids', ()):
+                return
+            etype = session.describe(subject)[0]
+            sql = 'UPDATE %s SET %s=NULL WHERE eid=%%(eid)s' % (etype, rtype)
+            attrs = {'eid' : subject}
+        else:
+            attrs = {'eid_from': subject, 'eid_to': object}
+            sql = self.sqladapter.sqlgen.delete('%s_relation' % rtype, attrs)
+        sqlcursor = session.pool[self.uri]
+        sqlcursor.execute(sql, attrs)
diff -r 3dbee583526c -r 4c7d3af7e94d server/sources/ldapuser.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sources/ldapuser.py	Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,685 @@
+"""cubicweb ldap user source
+
+this source is for now limited to a read-only EUser source
+
+:organization: Logilab
+:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+
+
+Part of the code is coming from Zope's LDAPUserFolder
+
+Copyright (c) 2004 Jens Vagelpohl.
+All Rights Reserved.
+
+This software is subject to the provisions of the Zope Public License,
+Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+FOR A PARTICULAR PURPOSE.
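+
+As an illustration, a typical configuration for this source uses the options
+declared below (sample values only, the actual sources file layout is
+instance dependent)::
+
+    host=ldap.example.com
+    user-base-dn=ou=People,dc=example,dc=com
+    user-scope=ONELEVEL
+    user-classes=top,posixAccount
+    user-login-attr=uid
+    user-default-group=users
+    user-attrs-map=uid:login,gecos:email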
+""" + +from mx.DateTime import now, DateTimeDelta + +from logilab.common.textutils import get_csv +from rql.nodes import Relation, VariableRef, Constant, Function + +import ldap +from ldap.ldapobject import ReconnectLDAPObject +from ldap.filter import filter_format, escape_filter_chars +from ldapurl import LDAPUrl + +from cubicweb.common import AuthenticationError, UnknownEid, RepositoryError +from cubicweb.server.sources import AbstractSource, TrFunc, GlobTrFunc, ConnectionWrapper +from cubicweb.server.utils import cartesian_product + +# search scopes +BASE = ldap.SCOPE_BASE +ONELEVEL = ldap.SCOPE_ONELEVEL +SUBTREE = ldap.SCOPE_SUBTREE + +# XXX only for edition ?? +## password encryption possibilities +#ENCRYPTIONS = ('SHA', 'CRYPT', 'MD5', 'CLEAR') # , 'SSHA' + +# mode identifier : (port, protocol) +MODES = { + 0: (389, 'ldap'), + 1: (636, 'ldaps'), + 2: (0, 'ldapi'), + } + +class TimedCache(dict): + def __init__(self, ttlm, ttls=0): + # time to live in minutes + self.ttl = DateTimeDelta(0, 0, ttlm, ttls) + + def __setitem__(self, key, value): + dict.__setitem__(self, key, (now(), value)) + + def __getitem__(self, key): + return dict.__getitem__(self, key)[1] + + def clear_expired(self): + now_ = now() + ttl = self.ttl + for key, (timestamp, value) in self.items(): + if now_ - timestamp > ttl: + del self[key] + +class LDAPUserSource(AbstractSource): + """LDAP read-only EUser source""" + support_entities = {'EUser': False} + + port = None + + cnx_mode = 0 + cnx_dn = '' + cnx_pwd = '' + + options = ( + ('host', + {'type' : 'string', + 'default': 'ldap', + 'help': 'ldap host', + 'group': 'ldap-source', 'inputlevel': 1, + }), + ('user-base-dn', + {'type' : 'string', + 'default': 'ou=People,dc=logilab,dc=fr', + 'help': 'base DN to lookup for users', + 'group': 'ldap-source', 'inputlevel': 0, + }), + ('user-scope', + {'type' : 'choice', + 'default': 'ONELEVEL', + 'choices': ('BASE', 'ONELEVEL', 'SUBTREE'), + 'help': 'user search scope', + 'group': 'ldap-source', 'inputlevel': 1, + }), + ('user-classes', + {'type' : 'csv', + 'default': ('top', 'posixAccount'), + 'help': 'classes of user', + 'group': 'ldap-source', 'inputlevel': 1, + }), + ('user-login-attr', + {'type' : 'string', + 'default': 'uid', + 'help': 'attribute used as login on authentication', + 'group': 'ldap-source', 'inputlevel': 1, + }), + ('user-default-group', + {'type' : 'csv', + 'default': ('users',), + 'help': 'name of a group in which ldap users will be by default. 
You can set multiple groups by separating them by a comma.',
+          'group': 'ldap-source', 'inputlevel': 1,
+          }),
+        ('user-attrs-map',
+         {'type' : 'named',
+          'default': {'uid': 'login', 'gecos': 'email'},
+          'help': 'map from ldap user attributes to cubicweb attributes',
+          'group': 'ldap-source', 'inputlevel': 1,
+          }),
+
+        ('synchronization-interval',
+         {'type' : 'int',
+          'default': 24*60*60,
+          'help': 'interval between synchronizations with the ldap \
+directory (defaults to once a day).',
+          'group': 'ldap-source', 'inputlevel': 2,
+          }),
+        ('cache-life-time',
+         {'type' : 'int',
+          'default': 2*60,
+          'help': 'life time of the query cache in minutes (defaults to two hours).',
+          'group': 'ldap-source', 'inputlevel': 2,
+          }),
+
+    )
+
+    def __init__(self, repo, appschema, source_config, *args, **kwargs):
+        AbstractSource.__init__(self, repo, appschema, source_config,
+                                *args, **kwargs)
+        self.host = source_config['host']
+        self.user_base_dn = source_config['user-base-dn']
+        self.user_base_scope = globals()[source_config['user-scope']]
+        self.user_classes = get_csv(source_config['user-classes'])
+        self.user_login_attr = source_config['user-login-attr']
+        self.user_default_groups = get_csv(source_config['user-default-group'])
+        self.user_attrs = dict(v.split(':', 1) for v in get_csv(source_config['user-attrs-map']))
+        self.user_rev_attrs = {'eid': 'dn'}
+        for ldapattr, cwattr in self.user_attrs.items():
+            self.user_rev_attrs[cwattr] = ldapattr
+        self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
+                             for o in self.user_classes]
+        self._conn = None
+        self._cache = {}
+        ttlm = int(source_config.get('cache-life-time', 2*60))
+        self._query_cache = TimedCache(ttlm)
+        self._interval = int(source_config.get('synchronization-interval',
+                                               24*60*60))
+
+    def reset_caches(self):
+        """method called during tests to reset potential source caches"""
+        self._query_cache = TimedCache(2*60)
+
+    def init(self):
+        """method called by the repository once ready to handle requests"""
+        self.repo.looping_task(self._interval, self.synchronize)
+        self.repo.looping_task(self._query_cache.ttl.seconds/10,
+                               self._query_cache.clear_expired)
+
+    def synchronize(self):
+        """synchronize content known by this repository with content in the
+        external repository
+        """
+        self.info('synchronizing ldap source %s', self.uri)
+        session = self.repo.internal_session()
+        try:
+            cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
+                                        "source='%s'" % self.uri)
+            for eid, extid in cursor.fetchall():
+                # if no result is found, _search automatically deletes entity
+                # information
+                res = self._search(session, extid, BASE)
+                if res:
+                    ldapemailaddr = res[0].get(self.user_rev_attrs['email'])
+                    if ldapemailaddr:
+                        rset = session.execute('EmailAddress X,A WHERE '
+                                               'U use_email X, U eid %(u)s',
+                                               {'u': eid})
+                        ldapemailaddr = unicode(ldapemailaddr)
+                        for emaileid, emailaddr in rset:
+                            if emailaddr == ldapemailaddr:
+                                break
+                        else:
+                            self.info('updating email address of user %s to %s',
+                                      extid, ldapemailaddr)
+                            if rset:
+                                session.execute('SET X address %(addr)s WHERE '
+                                                'U primary_email X, U eid %(u)s',
+                                                {'addr': ldapemailaddr, 'u': eid})
+                            else:
+                                # no email found, create it
+                                _insert_email(session, ldapemailaddr, eid)
+        finally:
+            session.commit()
+            session.close()
+
+    def get_connection(self):
+        """open and return a connection to the source"""
+        if self._conn is None:
+            self._connect()
+        return ConnectionWrapper(self._conn)
+
+    def authenticate(self, session, login, password):
+        """return the EUser eid for the given login/password if this
+        account is defined in this source, else raise `AuthenticationError`
+
+        two queries are needed since passwords are stored encrypted, so we
+        have to fetch the salt first
+        """
+        assert login, 'no login!'
+        searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
+        searchfilter.extend([filter_format('(%s=%s)', ('objectClass', o))
+                             for o in self.user_classes])
+        searchstr = '(&%s)' % ''.join(searchfilter)
+        # first search for the user
+        try:
+            user = self._search(session, self.user_base_dn,
+                                self.user_base_scope, searchstr)[0]
+        except IndexError:
+            # no such user
+            raise AuthenticationError()
+        # check password by establishing an (unused) connection
+        try:
+            self._connect(user['dn'], password)
+        except:
+            # Something went wrong, most likely bad credentials
+            raise AuthenticationError()
+        return self.extid2eid(user['dn'], 'EUser', session)
+
+    def ldap_name(self, var):
+        if var.stinfo['relations']:
+            relname = iter(var.stinfo['relations']).next().r_type
+            return self.user_rev_attrs.get(relname)
+        return None
+
+    def prepare_columns(self, mainvars, rqlst):
+        """return two lists describing how to build the final results
+        from the results of an ldap search (i.e. a list of dictionaries)
+        """
+        columns = []
+        global_transforms = []
+        for i, term in enumerate(rqlst.selection):
+            if isinstance(term, Constant):
+                columns.append(term)
+                continue
+            if isinstance(term, Function): # LOWER, UPPER, COUNT...
+                var = term.get_nodes(VariableRef)[0]
+                var = var.variable
+                try:
+                    mainvar = var.stinfo['attrvar'].name
+                except AttributeError: # no attrvar set
+                    mainvar = var.name
+                assert mainvar in mainvars
+                trname = term.name
+                ldapname = self.ldap_name(var)
+                if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
+                    global_transforms.append(GlobTrFunc(trname, i, ldapname))
+                    columns.append((mainvar, ldapname))
+                    continue
+                if trname in ('LOWER', 'UPPER'):
+                    columns.append((mainvar, TrFunc(trname, i, ldapname)))
+                    continue
+                raise NotImplementedError('no support for %s function' % trname)
+            if term.name in mainvars:
+                columns.append((term.name, 'dn'))
+                continue
+            var = term.variable
+            mainvar = var.stinfo['attrvar'].name
+            columns.append((mainvar, self.ldap_name(var)))
+            #else:
+            #    # probably a bug in rql splitting if we arrive here
+            #    raise NotImplementedError
+        return columns, global_transforms
+
+    def syntax_tree_search(self, session, union,
+                           args=None, cachekey=None, varmap=None, debug=0):
+        """return result from this source for a rql query (actually from a rql
+        syntax tree and a solution dictionary mapping each used variable to a
+        possible type). If cachekey is given, the query necessary to fetch the
+        results (but not the results themselves) may be cached using this key.
+        """
+        # XXX not handled: transform/aggregat function, join on multiple users...
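+        # sketch of the TimedCache behaviour relied on below, assuming a
+        # 2 minutes time-to-live: values are stored with their insertion
+        # timestamp and dropped once older than the ttl, e.g.
+        #   cache = TimedCache(2)
+        #   cache['Any X WHERE X is EUser'] = results
+        #   results = cache['Any X WHERE X is EUser']  # until expiration
+        #   cache.clear_expired()  # called as a looping task, see init()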
+ assert len(union.children) == 1, 'union not supported' + rqlst = union.children[0] + assert not rqlst.with_, 'subquery not supported' + rqlkey = rqlst.as_string(kwargs=args) + try: + results = self._query_cache[rqlkey] + except KeyError: + results = self.rqlst_search(session, rqlst, args) + self._query_cache[rqlkey] = results + return results + + def rqlst_search(self, session, rqlst, args): + mainvars = [] + for varname in rqlst.defined_vars: + for sol in rqlst.solutions: + if sol[varname] == 'EUser': + mainvars.append(varname) + break + assert mainvars + columns, globtransforms = self.prepare_columns(mainvars, rqlst) + eidfilters = [] + allresults = [] + generator = RQL2LDAPFilter(self, session, args, mainvars) + for mainvar in mainvars: + # handle restriction + try: + eidfilters_, ldapfilter = generator.generate(rqlst, mainvar) + except GotDN, ex: + assert ex.dn, 'no dn!' + try: + res = [self._cache[ex.dn]] + except KeyError: + res = self._search(session, ex.dn, BASE) + except UnknownEid, ex: + # raised when we are looking for the dn of an eid which is not + # coming from this source + res = [] + else: + eidfilters += eidfilters_ + res = self._search(session, self.user_base_dn, + self.user_base_scope, ldapfilter) + allresults.append(res) + # 1. get eid for each dn and filter according to that eid if necessary + for i, res in enumerate(allresults): + filteredres = [] + for resdict in res: + # get sure the entity exists in the system table + eid = self.extid2eid(resdict['dn'], 'EUser', session) + for eidfilter in eidfilters: + if not eidfilter(eid): + break + else: + resdict['eid'] = eid + filteredres.append(resdict) + allresults[i] = filteredres + # 2. merge result for each "mainvar": cartesian product + allresults = cartesian_product(allresults) + # 3. build final result according to column definition + result = [] + for rawline in allresults: + rawline = dict(zip(mainvars, rawline)) + line = [] + for varname, ldapname in columns: + if ldapname is None: + value = None # no mapping available + elif ldapname == 'dn': + value = rawline[varname]['eid'] + elif isinstance(ldapname, Constant): + if ldapname.type == 'Substitute': + value = args[ldapname.value] + else: + value = ldapname.value + elif isinstance(ldapname, TrFunc): + value = ldapname.apply(rawline[varname]) + else: + value = rawline[varname].get(ldapname) + line.append(value) + result.append(line) + for trfunc in globtransforms: + result = trfunc.apply(result) + #print '--> ldap result', result + return result + + + def _connect(self, userdn=None, userpwd=None): + port, protocol = MODES[self.cnx_mode] + if protocol == 'ldapi': + hostport = self.host + else: + hostport = '%s:%s' % (self.host, self.port or port) + self.info('connecting %s://%s as %s', protocol, hostport, + userdn or 'anonymous') + url = LDAPUrl(urlscheme=protocol, hostport=hostport) + conn = ReconnectLDAPObject(url.initializeUrl()) + # Set the protocol version - version 3 is preferred + try: + conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3) + except ldap.LDAPError: # Invalid protocol version, fall back safely + conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2) + # Deny auto-chasing of referrals to be safe, we handle them instead + #try: + # connection.set_option(ldap.OPT_REFERRALS, 0) + #except ldap.LDAPError: # Cannot set referrals, so do nothing + # pass + #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout) + #conn.timeout = op_timeout + # Now bind with the credentials given. Let exceptions propagate out. 
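+        # two cases below: the initial service bind (userdn is None, bind
+        # with the configured cnx_dn/cnx_pwd and keep the connection as
+        # self._conn), or a per-user bind as issued by authenticate() to
+        # check credentials, e.g. (hypothetical dn and password):
+        #   conn.simple_bind_s('uid=bob,ou=People,dc=example,dc=com', 'secret')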
+ if userdn is None: + assert self._conn is None + self._conn = conn + userdn = self.cnx_dn + userpwd = self.cnx_pwd + conn.simple_bind_s(userdn, userpwd) + return conn + + def _search(self, session, base, scope, + searchstr='(objectClass=*)', attrs=()): + """make an ldap query""" + cnx = session.pool.connection(self.uri).cnx + try: + res = cnx.search_s(base, scope, searchstr, attrs) + except ldap.PARTIAL_RESULTS: + res = cnx.result(all=0)[1] + except ldap.NO_SUCH_OBJECT: + eid = self.extid2eid(base, 'EUser', session, insert=False) + if eid: + self.warning('deleting ldap user with eid %s and dn %s', + eid, base) + self.repo.delete_info(session, eid) + self._cache.pop(base, None) + return [] +## except ldap.REFERRAL, e: +## cnx = self.handle_referral(e) +## try: +## res = cnx.search_s(base, scope, searchstr, attrs) +## except ldap.PARTIAL_RESULTS: +## res_type, res = cnx.result(all=0) + result = [] + for rec_dn, rec_dict in res: + # When used against Active Directory, "rec_dict" may not be + # be a dictionary in some cases (instead, it can be a list) + # An example of a useless "res" entry that can be ignored + # from AD is + # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL']) + # This appears to be some sort of internal referral, but + # we can't handle it, so we need to skip over it. + try: + items = rec_dict.items() + except AttributeError: + # 'items' not found on rec_dict, skip + continue + for key, value in items: # XXX syt: huuum ? + if not isinstance(value, str): + try: + for i in range(len(value)): + value[i] = unicode(value[i], 'utf8') + except: + pass + if isinstance(value, list) and len(value) == 1: + rec_dict[key] = value = value[0] + rec_dict['dn'] = rec_dn + self._cache[rec_dn] = rec_dict + result.append(rec_dict) + #print '--->', result + return result + + def before_entity_insertion(self, session, lid, etype, eid): + """called by the repository when an eid has been attributed for an + entity stored here but the entity has not been inserted in the system + table yet. + + This method must return the an Entity instance representation of this + entity. + """ + entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid) + res = self._search(session, lid, BASE)[0] + for attr in entity.e_schema.indexable_attributes(): + entity[attr] = res[self.user_rev_attrs[attr]] + return entity + + def after_entity_insertion(self, session, dn, entity): + """called by the repository after an entity stored here has been + inserted in the system table. 
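+
+        As a sketch of the effect: for each name in `user-default-group`
+        (e.g. 'users'), the new entity X gets an `X in_group G, G name
+        "users"` relation, and the email address found in the cached ldap
+        entry, if any, becomes its primary_email.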
+ """ + super(LDAPUserSource, self).after_entity_insertion(session, dn, entity) + for group in self.user_default_groups: + session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s', + {'x': entity.eid, 'group': group}, 'x') + # search for existant email first + try: + emailaddr = self._cache[dn][self.user_rev_attrs['email']] + except KeyError: + return + rset = session.execute('EmailAddress X WHERE X address %(addr)s', + {'addr': emailaddr}) + if rset: + session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s', + {'x': rset[0][0], 'u': entity.eid}, 'u') + else: + # not found, create it + _insert_email(session, emailaddr, entity.eid) + + def update_entity(self, session, entity): + """replace an entity in the source""" + raise RepositoryError('this source is read only') + + def delete_entity(self, session, etype, eid): + """delete an entity from the source""" + raise RepositoryError('this source is read only') + +def _insert_email(session, emailaddr, ueid): + session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X ' + 'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x') + +class GotDN(Exception): + """exception used when a dn localizing the searched user has been found""" + def __init__(self, dn): + self.dn = dn + + +class RQL2LDAPFilter(object): + """generate an LDAP filter for a rql query""" + def __init__(self, source, session, args=None, mainvars=()): + self.source = source + self._ldap_attrs = source.user_rev_attrs + self._base_filters = source.base_filters + self._session = session + if args is None: + args = {} + self._args = args + self.mainvars = mainvars + + def generate(self, selection, mainvarname): + self._filters = res = self._base_filters[:] + self._mainvarname = mainvarname + self._eidfilters = [] + self._done_not = set() + restriction = selection.where + if isinstance(restriction, Relation): + # only a single relation, need to append result here (no AND/OR) + filter = restriction.accept(self) + if filter is not None: + res.append(filter) + elif restriction: + restriction.accept(self) + if len(res) > 1: + return self._eidfilters, '(&%s)' % ''.join(res) + return self._eidfilters, res[0] + + def visit_and(self, et): + """generate filter for a AND subtree""" + for c in et.children: + part = c.accept(self) + if part: + self._filters.append(part) + + def visit_or(self, ou): + """generate filter for a OR subtree""" + res = [] + for c in ou.children: + part = c.accept(self) + if part: + res.append(part) + if res: + if len(res) > 1: + part = '(|%s)' % ''.join(res) + else: + part = res[0] + self._filters.append(part) + + def visit_not(self, node): + """generate filter for a OR subtree""" + part = node.children[0].accept(self) + if part: + self._filters.append('(!(%s))'% part) + + def visit_relation(self, relation): + """generate filter for a relation""" + rtype = relation.r_type + # don't care of type constraint statement (i.e. 
relation_type = 'is') + if rtype == 'is': + return '' + lhs, rhs = relation.get_parts() + # attribute relation + if self.source.schema.rschema(rtype).is_final(): + # dunno what to do here, don't pretend anything else + if lhs.name != self._mainvarname: + if lhs.name in self.mainvars: + # XXX check we don't have variable as rhs + return + raise NotImplementedError + rhs_vars = rhs.get_nodes(VariableRef) + if rhs_vars: + if len(rhs_vars) > 1: + raise NotImplementedError + # selected variable, nothing to do here + return + # no variables in the RHS + if isinstance(rhs.children[0], Function): + res = rhs.children[0].accept(self) + elif rtype != 'has_text': + res = self._visit_attribute_relation(relation) + else: + raise NotImplementedError(relation) + # regular relation XXX todo: in_group + else: + raise NotImplementedError(relation) + return res + + def _visit_attribute_relation(self, relation): + """generate filter for an attribute relation""" + lhs, rhs = relation.get_parts() + lhsvar = lhs.variable + if relation.r_type == 'eid': + # XXX hack + # skip comparison sign + eid = int(rhs.children[0].accept(self)) + if relation.neged(strict=True): + self._done_not.add(relation.parent) + self._eidfilters.append(lambda x: not x == eid) + return + if rhs.operator != '=': + filter = {'>': lambda x: x > eid, + '>=': lambda x: x >= eid, + '<': lambda x: x < eid, + '<=': lambda x: x <= eid, + }[rhs.operator] + self._eidfilters.append(filter) + return + dn = self.source.eid2extid(eid, self._session) + raise GotDN(dn) + try: + filter = '(%s%s)' % (self._ldap_attrs[relation.r_type], + rhs.accept(self)) + except KeyError: + assert relation.r_type == 'password' # 2.38 migration + raise UnknownEid # trick to return no result + return filter + + def visit_comparison(self, cmp): + """generate filter for a comparaison""" + return '%s%s'% (cmp.operator, cmp.children[0].accept(self)) + + def visit_mathexpression(self, mexpr): + """generate filter for a mathematic expression""" + raise NotImplementedError + + def visit_function(self, function): + """generate filter name for a function""" + if function.name == 'IN': + return self.visit_in(function) + raise NotImplementedError + + def visit_in(self, function): + grandpapa = function.parent.parent + ldapattr = self._ldap_attrs[grandpapa.r_type] + res = [] + for c in function.children: + part = c.accept(self) + if part: + res.append(part) + if res: + if len(res) > 1: + part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res) + else: + part = '(%s=%s)' % (ldapattr, res[0]) + return part + + def visit_constant(self, constant): + """generate filter name for a constant""" + value = constant.value + if constant.type is None: + raise NotImplementedError + if constant.type == 'Date': + raise NotImplementedError + #value = self.keyword_map[value]() + elif constant.type == 'Substitute': + value = self._args[constant.value] + else: + value = constant.value + if isinstance(value, unicode): + value = value.encode('utf8') + else: + value = str(value) + return escape_filter_chars(value) + + def visit_variableref(self, variableref): + """get the sql name for a variable reference""" + pass + diff -r 3dbee583526c -r 4c7d3af7e94d server/sources/pyrorql.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/server/sources/pyrorql.py Mon Dec 22 17:34:15 2008 +0100 @@ -0,0 +1,553 @@ +"""Source to query another RQL repository using pyro + +:organization: Logilab +:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. 
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr +""" +__docformat__ = "restructuredtext en" + +import threading +from os.path import join + +from mx.DateTime import DateTimeFromTicks + +from Pyro.errors import PyroError, ConnectionClosedError + +from logilab.common.configuration import REQUIRED + +from rql.nodes import Constant +from rql.utils import rqlvar_maker + +from cubicweb import dbapi, server +from cubicweb import BadConnectionId, UnknownEid, ConnectionError +from cubicweb.cwconfig import register_persistent_options +from cubicweb.server.sources import AbstractSource, ConnectionWrapper + +class ReplaceByInOperator: + def __init__(self, eids): + self.eids = eids + +class PyroRQLSource(AbstractSource): + """External repository source, using Pyro connection""" + + # boolean telling if modification hooks should be called when something is + # modified in this source + should_call_hooks = False + # boolean telling if the repository should connect to this source during + # migration + connect_for_migration = False + + support_entities = None + + options = ( + # XXX pyro-ns host/port + ('pyro-ns-id', + {'type' : 'string', + 'default': REQUIRED, + 'help': 'identifier of the repository in the pyro name server', + 'group': 'pyro-source', 'inputlevel': 0, + }), + ('mapping-file', + {'type' : 'string', + 'default': REQUIRED, + 'help': 'path to a python file with the schema mapping definition', + 'group': 'pyro-source', 'inputlevel': 1, + }), + ('cubicweb-user', + {'type' : 'string', + 'default': REQUIRED, + 'help': 'user to use for connection on the distant repository', + 'group': 'pyro-source', 'inputlevel': 0, + }), + ('cubicweb-password', + {'type' : 'password', + 'default': '', + 'help': 'user to use for connection on the distant repository', + 'group': 'pyro-source', 'inputlevel': 0, + }), + ('base-url', + {'type' : 'string', + 'default': '', + 'help': 'url of the web site for the distant repository, if you want ' + 'to generate external link to entities from this repository', + 'group': 'pyro-source', 'inputlevel': 1, + }), + ('pyro-ns-host', + {'type' : 'string', + 'default': None, + 'help': 'Pyro name server\'s host. If not set, default to the value \ +from all_in_one.conf.', + 'group': 'pyro-source', 'inputlevel': 1, + }), + ('pyro-ns-port', + {'type' : 'int', + 'default': None, + 'help': 'Pyro name server\'s listening port. If not set, default to \ +the value from all_in_one.conf.', + 'group': 'pyro-source', 'inputlevel': 1, + }), + ('pyro-ns-group', + {'type' : 'string', + 'default': None, + 'help': 'Pyro name server\'s group where the repository will be \ +registered. 
If not set, default to the value from all_in_one.conf.', + 'group': 'pyro-source', 'inputlevel': 1, + }), + ('synchronization-interval', + {'type' : 'int', + 'default': 5*60, + 'help': 'interval between synchronization with the external \ +repository (default to 5 minutes).', + 'group': 'pyro-source', 'inputlevel': 2, + }), + + ) + + PUBLIC_KEYS = AbstractSource.PUBLIC_KEYS + ('base-url',) + _conn = None + + def __init__(self, repo, appschema, source_config, *args, **kwargs): + AbstractSource.__init__(self, repo, appschema, source_config, + *args, **kwargs) + mappingfile = source_config['mapping-file'] + if not mappingfile[0] == '/': + mappingfile = join(repo.config.apphome, mappingfile) + mapping = {} + execfile(mappingfile, mapping) + self.support_entities = mapping['support_entities'] + self.support_relations = mapping.get('support_relations', {}) + self.dont_cross_relations = mapping.get('dont_cross_relations', ()) + baseurl = source_config.get('base-url') + if baseurl and not baseurl.endswith('/'): + source_config['base-url'] += '/' + self.config = source_config + myoptions = (('%s.latest-update-time' % self.uri, + {'type' : 'int', 'sitewide': True, + 'default': 0, + 'help': _('timestamp of the latest source synchronization.'), + 'group': 'sources', + }),) + register_persistent_options(myoptions) + + def last_update_time(self): + pkey = u'sources.%s.latest-update-time' % self.uri + rql = 'Any V WHERE X is EProperty, X value V, X pkey %(k)s' + session = self.repo.internal_session() + try: + rset = session.execute(rql, {'k': pkey}) + if not rset: + # insert it + session.execute('INSERT EProperty X: X pkey %(k)s, X value %(v)s', + {'k': pkey, 'v': u'0'}) + session.commit() + timestamp = 0 + else: + assert len(rset) == 1 + timestamp = int(rset[0][0]) + return DateTimeFromTicks(timestamp) + finally: + session.close() + + def init(self): + """method called by the repository once ready to handle request""" + interval = int(self.config.get('synchronization-interval', 5*60)) + self.repo.looping_task(interval, self.synchronize) + + def synchronize(self, mtime=None): + """synchronize content known by this repository with content in the + external repository + """ + self.info('synchronizing pyro source %s', self.uri) + extrepo = self.get_connection()._repo + etypes = self.support_entities.keys() + if mtime is None: + mtime = self.last_update_time() + updatetime, modified, deleted = extrepo.entities_modified_since(etypes, + mtime) + repo = self.repo + session = repo.internal_session() + try: + for etype, extid in modified: + try: + eid = self.extid2eid(extid, etype, session) + rset = session.eid_rset(eid, etype) + entity = rset.get_entity(0, 0) + entity.complete(entity.e_schema.indexable_attributes()) + repo.index_entity(session, entity) + except: + self.exception('while updating %s with external id %s of source %s', + etype, extid, self.uri) + continue + for etype, extid in deleted: + try: + eid = self.extid2eid(extid, etype, session, insert=False) + # entity has been deleted from external repository but is not known here + if eid is not None: + repo.delete_info(session, eid) + except: + self.exception('while updating %s with external id %s of source %s', + etype, extid, self.uri) + continue + session.execute('SET X value %(v)s WHERE X pkey %(k)s', + {'k': u'sources.%s.latest-update-time' % self.uri, + 'v': unicode(int(updatetime.ticks()))}) + session.commit() + finally: + session.close() + + def _get_connection(self): + """open and return a connection to the source""" + nshost = 
self.config.get('pyro-ns-host') or self.repo.config['pyro-ns-host'] + nsport = self.config.get('pyro-ns-port') or self.repo.config['pyro-ns-port'] + nsgroup = self.config.get('pyro-ns-group') or self.repo.config['pyro-ns-group'] + #cnxprops = ConnectionProperties(cnxtype=self.config['cnx-type']) + return dbapi.connect(database=self.config['pyro-ns-id'], + user=self.config['cubicweb-user'], + password=self.config['cubicweb-password'], + host=nshost, port=nsport, group=nsgroup, + setvreg=False) #cnxprops=cnxprops) + + def get_connection(self): + try: + return self._get_connection() + except (ConnectionError, PyroError): + self.critical("can't get connection to source %s", self.uri, + exc_info=1) + return ConnectionWrapper() + + def check_connection(self, cnx): + """check connection validity, return None if the connection is still valid + else a new connection + """ + # we have to transfer manually thread ownership. This can be done safely + # since the pool to which belong the connection is affected to one + # session/thread and can't be called simultaneously + try: + cnx._repo._transferThread(threading.currentThread()) + except AttributeError: + # inmemory connection + pass + if not isinstance(cnx, ConnectionWrapper): + try: + cnx.check() + return # ok + except (BadConnectionId, ConnectionClosedError): + pass + # try to reconnect + return self.get_connection() + + + def syntax_tree_search(self, session, union, args=None, cachekey=None, + varmap=None): + """return result from this source for a rql query (actually from a rql + syntax tree and a solution dictionary mapping each used variable to a + possible type). If cachekey is given, the query necessary to fetch the + results (but not the results themselves) may be cached using this key. + """ + if not args is None: + args = args.copy() + if server.DEBUG: + print 'RQL FOR PYRO SOURCE', self.uri + print union.as_string() + if args: print 'ARGS', args + print 'SOLUTIONS', ','.join(str(s.solutions) for s in union.children) + # get cached cursor anyway + cu = session.pool[self.uri] + if cu is None: + # this is a ConnectionWrapper instance + msg = session._("can't connect to source %s, some data may be missing") + session.set_shared_data('sources_error', msg % self.uri) + return [] + try: + rql, cachekey = RQL2RQL(self).generate(session, union, args) + except UnknownEid, ex: + if server.DEBUG: + print 'unknown eid', ex, 'no results' + return [] + if server.DEBUG: + print 'TRANSLATED RQL', rql + try: + rset = cu.execute(rql, args, cachekey) + except Exception, ex: + self.exception(str(ex)) + msg = session._("error while querying source %s, some data may be missing") + session.set_shared_data('sources_error', msg % self.uri) + return [] + descr = rset.description + if rset: + needtranslation = [] + for i, etype in enumerate(descr[0]): + if (etype is None or not self.schema.eschema(etype).is_final() or + getattr(union.locate_subquery(i, etype, args).selection[i], 'uidtype', None)): + needtranslation.append(i) + if needtranslation: + for rowindex, row in enumerate(rset): + for colindex in needtranslation: + if row[colindex] is not None: # optional variable + etype = descr[rowindex][colindex] + eid = self.extid2eid(row[colindex], etype, session) + row[colindex] = eid + results = rset.rows + else: + results = [] + if server.DEBUG: + if len(results)>10: + print '--------------->', results[:10], '...', len(results) + else: + print '--------------->', results + return results + + def _entity_relations_and_kwargs(self, session, entity): + relations = [] + 
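+        # build restriction snippets and query arguments for a SET query on
+        # the distant repository, e.g. for an entity carrying a login
+        # attribute (illustrative values):
+        #   relations = ['X login %(login)s']
+        #   kwargs = {'x': <external id>, 'login': u'bob'}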
kwargs = {'x': self.eid2extid(entity.eid, session)} + for key, val in entity.iteritems(): + relations.append('X %s %%(%s)s' % (key, key)) + kwargs[key] = val + return relations, kwargs + + def add_entity(self, session, entity): + """add a new entity to the source""" + raise NotImplementedError() + + def update_entity(self, session, entity): + """update an entity in the source""" + relations, kwargs = self._entity_relations_and_kwargs(session, entity) + cu = session.pool[self.uri] + cu.execute('SET %s WHERE X eid %%(x)s' % ','.join(relations), + kwargs, 'x') + + def delete_entity(self, session, etype, eid): + """delete an entity from the source""" + cu = session.pool[self.uri] + cu.execute('DELETE %s X WHERE X eid %%(x)s' % etype, + {'x': self.eid2extid(eid, session)}, 'x') + + def add_relation(self, session, subject, rtype, object): + """add a relation to the source""" + cu = session.pool[self.uri] + cu.execute('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype, + {'x': self.eid2extid(subject, session), + 'y': self.eid2extid(object, session)}, ('x', 'y')) + + def delete_relation(self, session, subject, rtype, object): + """delete a relation from the source""" + cu = session.pool[self.uri] + cu.execute('DELETE X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype, + {'x': self.eid2extid(subject, session), + 'y': self.eid2extid(object, session)}, ('x', 'y')) + + +class RQL2RQL(object): + """translate a local rql query to be executed on a distant repository""" + def __init__(self, source): + self.source = source + + def _accept_children(self, node): + res = [] + for child in node.children: + rql = child.accept(self) + if rql is not None: + res.append(rql) + return res + + def generate(self, session, rqlst, args): + self._session = session + self.kwargs = args + self.cachekey = [] + self.need_translation = False + return self.visit_union(rqlst), self.cachekey + + def visit_union(self, node): + s = self._accept_children(node) + if len(s) > 1: + return ' UNION '.join('(%s)' % q for q in s) + return s[0] + + def visit_select(self, node): + """return the tree as an encoded rql string""" + self._varmaker = rqlvar_maker(defined=node.defined_vars.copy()) + self._const_var = {} + if node.distinct: + base = 'DISTINCT Any' + else: + base = 'Any' + s = ['%s %s' % (base, ','.join(v.accept(self) for v in node.selection))] + if node.groupby: + s.append('GROUPBY %s' % ', '.join(group.accept(self) + for group in node.groupby)) + if node.orderby: + s.append('ORDERBY %s' % ', '.join(self.visit_sortterm(term) + for term in node.orderby)) + if node.limit is not None: + s.append('LIMIT %s' % node.limit) + if node.offset: + s.append('OFFSET %s' % node.offset) + restrictions = [] + if node.where is not None: + nr = node.where.accept(self) + if nr is not None: + restrictions.append(nr) + if restrictions: + s.append('WHERE %s' % ','.join(restrictions)) + + if node.having: + s.append('HAVING %s' % ', '.join(term.accept(self) + for term in node.having)) + subqueries = [] + for subquery in node.with_: + subqueries.append('%s BEING (%s)' % (','.join(ca.name for ca in subquery.aliases), + self.visit_union(subquery.query))) + if subqueries: + s.append('WITH %s' % (','.join(subqueries))) + return ' '.join(s) + + def visit_and(self, node): + res = self._accept_children(node) + if res: + return ', '.join(res) + return + + def visit_or(self, node): + res = self._accept_children(node) + if len(res) > 1: + return ' OR '.join('(%s)' % rql for rql in res) + elif res: + return res[0] + return + + def visit_not(self, node): + rql = 
node.children[0].accept(self) + if rql: + return 'NOT (%s)' % rql + return + + def visit_exists(self, node): + return 'EXISTS(%s)' % node.children[0].accept(self) + + def visit_relation(self, node): + try: + if isinstance(node.children[0], Constant): + # simplified rqlst, reintroduce eid relation + restr, lhs = self.process_eid_const(node.children[0]) + else: + lhs = node.children[0].accept(self) + restr = None + except UnknownEid: + # can safely skip not relation with an unsupported eid + if node.neged(strict=True): + return + # XXX what about optional relation or outer NOT EXISTS() + raise + if node.optional in ('left', 'both'): + lhs += '?' + if node.r_type == 'eid' or not self.source.schema.rschema(node.r_type).is_final(): + self.need_translation = True + self.current_operator = node.operator() + if isinstance(node.children[0], Constant): + self.current_etypes = (node.children[0].uidtype,) + else: + self.current_etypes = node.children[0].variable.stinfo['possibletypes'] + try: + rhs = node.children[1].accept(self) + except UnknownEid: + # can safely skip not relation with an unsupported eid + if node.neged(strict=True): + return + # XXX what about optional relation or outer NOT EXISTS() + raise + except ReplaceByInOperator, ex: + rhs = 'IN (%s)' % ','.join(str(eid) for eid in ex.eids) + self.need_translation = False + self.current_operator = None + if node.optional in ('right', 'both'): + rhs += '?' + if restr is not None: + return '%s %s %s, %s' % (lhs, node.r_type, rhs, restr) + return '%s %s %s' % (lhs, node.r_type, rhs) + + def visit_comparison(self, node): + if node.operator in ('=', 'IS'): + return node.children[0].accept(self) + return '%s %s' % (node.operator.encode(), + node.children[0].accept(self)) + + def visit_mathexpression(self, node): + return '(%s %s %s)' % (node.children[0].accept(self), + node.operator.encode(), + node.children[1].accept(self)) + + def visit_function(self, node): + #if node.name == 'IN': + res = [] + for child in node.children: + try: + rql = child.accept(self) + except UnknownEid, ex: + continue + res.append(rql) + if not res: + raise ex + return '%s(%s)' % (node.name, ', '.join(res)) + + def visit_constant(self, node): + if self.need_translation or node.uidtype: + if node.type == 'Int': + return str(self.eid2extid(node.value)) + if node.type == 'Substitute': + key = node.value + # ensure we have not yet translated the value... 
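+                # e.g. a query argument {'x': 12} holding a local eid is
+                # rewritten in place, only once, to the external id known
+                # by the distant source, and 'x' is added to the cachekey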
+                if not key in self._const_var:
+                    self.kwargs[key] = self.eid2extid(self.kwargs[key])
+                    self.cachekey.append(key)
+                    self._const_var[key] = None
+        return node.as_string()
+
+    def visit_variableref(self, node):
+        """get the rql name for a variable reference"""
+        return node.name
+
+    def visit_sortterm(self, node):
+        if node.asc:
+            return node.term.accept(self)
+        return '%s DESC' % node.term.accept(self)
+
+    def process_eid_const(self, const):
+        value = const.eval(self.kwargs)
+        try:
+            return None, self._const_var[value]
+        except KeyError:
+            var = self._varmaker.next()
+            self.need_translation = True
+            restr = '%s eid %s' % (var, self.visit_constant(const))
+            self.need_translation = False
+            self._const_var[value] = var
+            return restr, var
+
+    def eid2extid(self, eid):
+        try:
+            return self.source.eid2extid(eid, self._session)
+        except UnknownEid:
+            operator = self.current_operator
+            if operator is not None and operator != '=':
+                # deal with a query like `X eid > 12`
+                #
+                # The problem is that eid ordering in the external source may
+                # differ from the local source's one.
+                #
+                # So search locally for all eids from this source matching the
+                # condition, then replace the `> 12` branch by `IN (eids)`
+                # (XXX we may have to insert a huge number of eids...)
+                sql = "SELECT extid FROM entities WHERE source='%s' AND type IN (%s) AND eid%s%s"
+                etypes = ','.join("'%s'" % etype for etype in self.current_etypes)
+                cu = self._session.system_sql(sql % (self.source.uri, etypes,
+                                                     operator, eid))
+                # XXX cu.rowcount is buggy and may be zero while there are
+                # some results
+                rows = cu.fetchall()
+                if rows:
+                    raise ReplaceByInOperator((r[0] for r in rows))
+            raise
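+
+# minimal translation sketch (hypothetical names, mirroring the call made in
+# PyroRQLSource.syntax_tree_search above):
+#   rql, cachekey = RQL2RQL(source).generate(session, union, args)
+#   rset = cu.execute(rql, args, cachekey)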