--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/msplanner.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,1212 @@
+"""plan execution of rql queries on multiple sources
+
+the best way to understand what we are trying to achieve here is to read
+the unit tests in unittest_querier_planner.py
+
+
+
+Split and execution specifications
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+For a system source and an ldap user source (only EUser and its attributes
+are supported, no groups or such):
+
+
+:EUser X:
+1. fetch EUser X from both sources and return concatenation of results
+
+
+:EUser X WHERE X in_group G, G name 'users':
+* case 1
+ 1. fetch EUser X from both sources, store concatenation of results
+ into a temporary table
+ 2. return the result of TMP X WHERE X in_group G, G name 'users' from
+ the system source
+
+* case 2
+ 1. return the result of EUser X WHERE X in_group G, G name 'users'
+ from system source, that's enough (optimization of the sql querier
+ will avoid join on EUser, so we will directly get local eids)
+
+
+:EUser X,L WHERE X in_group G, X login L, G name 'users':
+1. fetch Any X,L WHERE X is EUser, X login L from both sources, store
+ concatenation of results into a temporary table
+2. return the result of Any X, L WHERE X is TMP, X login L, X in_group G,
+   G name 'users' from the system source
+
+
+:Any X WHERE X owned_by Y:
+* case 1
+ 1. fetch EUser X from both sources, store concatenation of results
+ into a temporary table
+ 2. return the result of Any X WHERE X owned_by Y, Y is TMP from
+ the system source
+
+* case 2
+ 1. return the result of Any X WHERE X owned_by Y
+ from system source, that's enough (optimization of the sql querier
+ will avoid join on EUser, so we will directly get local eids)
+
+
+:organization: Logilab
+:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+from itertools import imap, ifilterfalse
+
+from logilab.common.compat import any
+from logilab.common.decorators import cached
+
+from rql.stmts import Union, Select
+from rql.nodes import VariableRef, Comparison, Relation, Constant, Exists, Variable
+
+from cubicweb import server
+from cubicweb.common.utils import make_uid
+from cubicweb.server.utils import cleanup_solutions
+from cubicweb.server.ssplanner import SSPlanner, OneFetchStep, add_types_restriction
+from cubicweb.server.mssteps import *
+from cubicweb.server.sources import AbstractSource
+
+Variable._ms_table_key = lambda x: x.name
+Relation._ms_table_key = lambda x: x.r_type
+# str() Constant.value to ensure generated table name won't be unicode
+Constant._ms_table_key = lambda x: str(x.value)
+
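+# relations which should not be crossed with another source; empty by
+# default, sources may override this class attribute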
+AbstractSource.dont_cross_relations = ()
+
+def allequals(solutions):
+ """return true if all solutions are identical"""
+ sol = solutions.next()
+ for sol_ in solutions:
+ if sol_ != sol:
+ return False
+ return True
+
+def need_aggr_step(select, sources, stepdefs=None):
+ """return True if a temporary table is necessary to store some partial
+ results to execute the given query
+ """
+ if len(sources) == 1:
+ # can do everything at once with a single source
+ return False
+    if select.orderby or select.groupby or select.has_aggregat:
+        # if there is more than one source, we need a temp table to deal with
+        # sort / groups / aggregates if:
+        # * the rqlst won't be split (in the other case the last query using
+        #   the partial temporary table can do sort/groups/aggregates without
+        #   the need for a later AggrStep)
+        # * the rqlst is split into multiple steps and there is more than one
+        #   final step
+ if stepdefs is None:
+ return True
+ has_one_final = False
+ fstepsolindices = set()
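+        # each stepdef is a (sources, variables, solindices, scope, needsel,
+        # final) tuple, as built by PartPlanInformation.part_steps()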
+ for stepdef in stepdefs:
+ if stepdef[-1]:
+ if has_one_final or frozenset(stepdef[2]) != fstepsolindices:
+ return True
+ has_one_final = True
+ else:
+ fstepsolindices.update(stepdef[2])
+ return False
+
+def copy_node(newroot, node, subparts=()):
+ newnode = node.__class__(*node.initargs(newroot))
+ for part in subparts:
+ newnode.append(part)
+ return newnode
+
+def same_scope(var):
+ """return true if the variable is always used in the same scope"""
+ try:
+ return var.stinfo['samescope']
+ except KeyError:
+ for rel in var.stinfo['relations']:
+ if not rel.scope is var.scope:
+ var.stinfo['samescope'] = False
+ return False
+ var.stinfo['samescope'] = True
+ return True
+
+def select_group_sort(select): # XXX something similar done in rql2sql
+ # add variables used in groups and sort terms to the selection
+ # if necessary
+ if select.groupby:
+ for vref in select.groupby:
+ if not vref in select.selection:
+ select.append_selected(vref.copy(select))
+ for sortterm in select.orderby:
+ for vref in sortterm.iget_nodes(VariableRef):
+ if not vref in select.get_selected_variables():
+ # we can't directly insert sortterm.term because it references
+ # a variable of the select before the copy.
+                # XXX if constant terms are used to define the sort, their
+                # value (e.g. a column index) may need to be shifted
+ select.append_selected(vref.copy(select))
+ if select.groupby and not vref in select.groupby:
+ select.add_group_var(vref.copy(select))
+
+
+class PartPlanInformation(object):
+ """regroups necessary information to execute some part of a "global" rql
+ query ("global" means as received by the querier, which may result in
+ several internal queries, e.g. parts, due to security insertions)
+
+ it exposes as well some methods helping in executing this part on a
+ multi-sources repository, modifying its internal structure during the
+ process
+
+ :attr solutions: a list of mappings (varname -> vartype)
+ :attr sourcesvars:
+ a dictionnary telling for each source which variable/solution are
+ supported, of the form {source : {varname: [solution index, ]}}
+ """
+ def __init__(self, plan, rqlst, rqlhelper=None):
+ self.needsplit = False
+ self.temptable = None
+ self.finaltable = None
+ self.plan = plan
+ self.rqlst = rqlst
+ self._session = plan.session
+ self._solutions = rqlst.solutions
+ self._solindices = range(len(self._solutions))
+ # source : {varname: [solution index, ]}
+ self._sourcesvars = {}
+        # dictionary of variables which are linked to each other by a non
+        # final relation supported by multiple sources
+ self._linkedvars = {}
+ # processing
+ self._compute_sourcesvars()
+ self._remove_invalid_sources()
+ #if server.DEBUG:
+ # print 'planner sources vars', self._sourcesvars
+ self._compute_needsplit()
+ self._inputmaps = {}
+ if rqlhelper is not None: # else test
+ self._insert_identity_variable = rqlhelper._annotator.rewrite_shared_optional
+
+ def copy_solutions(self, solindices):
+ return [self._solutions[solidx].copy() for solidx in solindices]
+
+ @property
+ @cached
+ def part_sources(self):
+ if self._sourcesvars:
+ return tuple(sorted(self._sourcesvars))
+ return (self._session.repo.system_source,)
+
+ @property
+ @cached
+ def _sys_source_set(self):
+ return frozenset((self._session.repo.system_source, solindex)
+ for solindex in self._solindices)
+
+ @cached
+ def _norel_support_set(self, rtype):
+ """return a set of (source, solindex) where source doesn't support the
+ relation
+ """
+ return frozenset((source, solidx) for source in self._session.repo.sources
+ for solidx in self._solindices
+ if not (source.support_relation(rtype)
+ or rtype in source.dont_cross_relations))
+
+ def _compute_sourcesvars(self):
+ """compute for each variable/solution in the rqlst which sources support
+ them
+ """
+ repo = self._session.repo
+ eschema = repo.schema.eschema
+ sourcesvars = self._sourcesvars
+ # find for each source which variable/solution are supported
+ for varname, varobj in self.rqlst.defined_vars.items():
+ # if variable has an eid specified, we can get its source directly
+ # NOTE: use uidrels and not constnode to deal with "X eid IN(1,2,3,4)"
+ if varobj.stinfo['uidrels']:
+ vrels = varobj.stinfo['relations'] - varobj.stinfo['uidrels']
+ for rel in varobj.stinfo['uidrels']:
+ if rel.neged(strict=True) or rel.operator() != '=':
+ continue
+ for const in rel.children[1].get_nodes(Constant):
+ eid = const.eval(self.plan.args)
+ source = self._session.source_from_eid(eid)
+ if vrels and not any(source.support_relation(r.r_type)
+ for r in vrels):
+ self._set_source_for_var(repo.system_source, varobj)
+ else:
+ self._set_source_for_var(source, varobj)
+ continue
+ rels = varobj.stinfo['relations']
+ if not rels and not varobj.stinfo['typerels']:
+                # (rare) case where the variable has no type specified nor
+                # any relation accessed, e.g. "Any MAX(X)"
+ self._set_source_for_var(repo.system_source, varobj)
+ continue
+ for i, sol in enumerate(self._solutions):
+ vartype = sol[varname]
+ # skip final variable
+ if eschema(vartype).is_final():
+ break
+ for source in repo.sources:
+ if source.support_entity(vartype):
+                        # the source supports the entity type, though we will
+                        # actually have to fetch from it only if
+                        # * the variable isn't invariant
+                        # * at least one supported relation is specified
+ if not varobj._q_invariant or \
+ any(imap(source.support_relation,
+ (r.r_type for r in rels if r.r_type != 'eid'))):
+ sourcesvars.setdefault(source, {}).setdefault(varobj, set()).add(i)
+ # if variable is not invariant and is used by a relation
+ # not supported by this source, we'll have to split the
+ # query
+ if not varobj._q_invariant and any(ifilterfalse(
+ source.support_relation, (r.r_type for r in rels))):
+ self.needsplit = True
+
+ def _remove_invalid_sources(self):
+ """removes invalid sources from `sourcesvars` member"""
+ repo = self._session.repo
+ rschema = repo.schema.rschema
+ vsources = {}
+ for rel in self.rqlst.iget_nodes(Relation):
+ # process non final relations only
+ # note: don't try to get schema for 'is' relation (not available
+ # during bootstrap)
+ if not rel.is_types_restriction() and not rschema(rel.r_type).is_final():
+ # nothing to do if relation is not supported by multiple sources
+ relsources = [source for source in repo.sources
+ if source.support_relation(rel.r_type)
+ or rel.r_type in source.dont_cross_relations]
+ if len(relsources) < 2:
+ if relsources:# and not relsources[0] in self._sourcesvars:
+ # this means the relation is using a variable inlined as
+ # a constant and another unsupported variable, in which
+ # case we put the relation in sourcesvars
+ self._sourcesvars.setdefault(relsources[0], {})[rel] = set(self._solindices)
+ continue
+ lhs, rhs = rel.get_variable_parts()
+ lhsv, rhsv = getattr(lhs, 'variable', lhs), getattr(rhs, 'variable', rhs)
+                # update dictionary of sources supporting lhs and rhs vars
+ if not lhsv in vsources:
+ vsources[lhsv] = self._term_sources(lhs)
+ if not rhsv in vsources:
+ vsources[rhsv] = self._term_sources(rhs)
+ self._linkedvars.setdefault(lhsv, set()).add((rhsv, rel))
+ self._linkedvars.setdefault(rhsv, set()).add((lhsv, rel))
+ for term in self._linkedvars:
+ self._remove_sources_until_stable(term, vsources)
+ if len(self._sourcesvars) > 1 and hasattr(self.plan.rqlst, 'main_relations'):
+ # the querier doesn't annotate write queries, need to do it here
+ self.plan.annotate_rqlst()
+            # insert/update/delete queries, we may get extra information from
+            # the main relations (e.g. relations to the left of the WHERE)
+ if self.plan.rqlst.TYPE == 'insert':
+ inserted = dict((vref.variable, etype)
+ for etype, vref in self.plan.rqlst.main_variables)
+ else:
+ inserted = {}
+ for rel in self.plan.rqlst.main_relations:
+ if not rschema(rel.r_type).is_final():
+ # nothing to do if relation is not supported by multiple sources
+ relsources = [source for source in repo.sources
+ if source.support_relation(rel.r_type)
+ or rel.r_type in source.dont_cross_relations]
+ if len(relsources) < 2:
+ continue
+ lhs, rhs = rel.get_variable_parts()
+ try:
+ lhsv = self._extern_term(lhs, vsources, inserted)
+ rhsv = self._extern_term(rhs, vsources, inserted)
+                except KeyError:
+ continue
+ norelsup = self._norel_support_set(rel.r_type)
+ self._remove_var_sources(lhsv, norelsup, rhsv, vsources)
+ self._remove_var_sources(rhsv, norelsup, lhsv, vsources)
+ # cleanup linked var
+ for var, linkedrelsinfo in self._linkedvars.iteritems():
+ self._linkedvars[var] = frozenset(x[0] for x in linkedrelsinfo)
+        # if there are other sources than the system source, consider the
+        # simplified variables' source
+ if self._sourcesvars and self._sourcesvars.keys() != [self._session.repo.system_source]:
+ # add source for rewritten constants to sourcesvars
+ for vconsts in self.rqlst.stinfo['rewritten'].itervalues():
+ const = vconsts[0]
+ eid = const.eval(self.plan.args)
+ source = self._session.source_from_eid(eid)
+ if source is self._session.repo.system_source:
+ for const in vconsts:
+ self._set_source_for_var(source, const)
+ elif source in self._sourcesvars:
+ source_scopes = frozenset(v.scope for v in self._sourcesvars[source])
+ for const in vconsts:
+ if const.scope in source_scopes:
+ self._set_source_for_var(source, const)
+
+ def _extern_term(self, term, vsources, inserted):
+ var = term.variable
+ if var.stinfo['constnode']:
+ termv = var.stinfo['constnode']
+ vsources[termv] = self._term_sources(termv)
+ elif var in inserted:
+ termv = var
+ source = self._session.repo.locate_etype_source(inserted[var])
+ vsources[termv] = set((source, solindex) for solindex in self._solindices)
+ else:
+ termv = self.rqlst.defined_vars[var.name]
+ if not termv in vsources:
+ vsources[termv] = self._term_sources(termv)
+ return termv
+
+ def _remove_sources_until_stable(self, var, vsources):
+ for ovar, rel in self._linkedvars.get(var, ()):
+ if not var.scope is ovar.scope and rel.scope.neged(strict=True):
+ # can't get information from relation inside a NOT exists
+ # where variables don't belong to the same scope
+ continue
+            if rel.neged(strict=True):
+                # a negated relation doesn't allow inferring variable sources
+                continue
+ norelsup = self._norel_support_set(rel.r_type)
+ # compute invalid sources for variables and remove them
+ self._remove_var_sources(var, norelsup, ovar, vsources)
+ self._remove_var_sources(ovar, norelsup, var, vsources)
+
+ def _remove_var_sources(self, var, norelsup, ovar, vsources):
+ """remove invalid sources for var according to ovar's sources and the
+ relation between those two variables.
+ """
+ varsources = vsources[var]
+ invalid_sources = varsources - (vsources[ovar] | norelsup)
+ if invalid_sources:
+ self._remove_sources(var, invalid_sources)
+ varsources -= invalid_sources
+ self._remove_sources_until_stable(var, vsources)
+
+ def _compute_needsplit(self):
+ """tell according to sourcesvars if the rqlst has to be splitted for
+ execution among multiple sources
+
+ the execution has to be split if
+ * a source support an entity (non invariant) but doesn't support a
+ relation on it
+ * a source support an entity which is accessed by an optional relation
+ * there is more than one sources and either all sources'supported
+ variable/solutions are not equivalent or multiple variables have to
+ be fetched from some source
+ """
+ # NOTE: < 2 since may be 0 on queries such as Any X WHERE X eid 2
+ if len(self._sourcesvars) < 2:
+ self.needsplit = False
+ elif not self.needsplit:
+ if not allequals(self._sourcesvars.itervalues()):
+ self.needsplit = True
+ else:
+ sample = self._sourcesvars.itervalues().next()
+ if len(sample) > 1 and any(v for v in sample
+ if not v in self._linkedvars):
+ self.needsplit = True
+
+ def _set_source_for_var(self, source, var):
+ self._sourcesvars.setdefault(source, {})[var] = set(self._solindices)
+
+ def _term_sources(self, term):
+ """returns possible sources for terms `term`"""
+ if isinstance(term, Constant):
+ source = self._session.source_from_eid(term.eval(self.plan.args))
+ return set((source, solindex) for solindex in self._solindices)
+ else:
+ var = getattr(term, 'variable', term)
+ sources = [source for source, varobjs in self._sourcesvars.iteritems()
+ if var in varobjs]
+ return set((source, solindex) for source in sources
+ for solindex in self._sourcesvars[source][var])
+
+ def _remove_sources(self, var, sources):
+ """removes invalid sources (`sources`) from `sourcesvars`
+
+ :param sources: the list of sources to remove
+ :param var: the analyzed variable
+ """
+ sourcesvars = self._sourcesvars
+ for source, solindex in sources:
+ try:
+ sourcesvars[source][var].remove(solindex)
+ except KeyError:
+ return # may occur with subquery column alias
+ if not sourcesvars[source][var]:
+ del sourcesvars[source][var]
+ if not sourcesvars[source]:
+ del sourcesvars[source]
+
+ def part_steps(self):
+ """precompute necessary part steps before generating actual rql for
+ each step. This is necessary to know if an aggregate step will be
+ necessary or not.
+ """
+ steps = []
+ select = self.rqlst
+ rschema = self.plan.schema.rschema
+ for source in self.part_sources:
+ sourcevars = self._sourcesvars[source]
+ while sourcevars:
+ # take a variable randomly, and all variables supporting the
+ # same solutions
+ var, solindices = self._choose_var(sourcevars)
+ if source.uri == 'system':
+                    # ensure all variables are available for the last step
+                    # (missing ones will be available from temporary tables
+                    # of previous steps)
+ scope = select
+ variables = scope.defined_vars.values() + scope.aliases.values()
+ sourcevars.clear()
+ else:
+ scope = var.scope
+ variables = self._expand_vars(var, sourcevars, scope, solindices)
+ if not sourcevars:
+ del self._sourcesvars[source]
+ # find which sources support the same variables/solutions
+ sources = self._expand_sources(source, variables, solindices)
+ # suppose this is a final step until the contrary is proven
+ final = scope is select
+                # set of variables which should be additionally selected when
+                # possible
+                needsel = set()
+                # add attribute variables and mark variables which should be
+                # additionally selected when possible
+ for var in select.defined_vars.itervalues():
+ if not var in variables:
+ stinfo = var.stinfo
+ for ovar, rtype in stinfo['attrvars']:
+ if ovar in variables:
+ needsel.add(var.name)
+ variables.append(var)
+ break
+ else:
+ needsel.add(var.name)
+ final = False
+ if final and source.uri != 'system':
+ # check rewritten constants
+ for vconsts in select.stinfo['rewritten'].itervalues():
+ const = vconsts[0]
+ eid = const.eval(self.plan.args)
+ _source = self._session.source_from_eid(eid)
+ if len(sources) > 1 or not _source in sources:
+ # if constant is only used by an identity relation,
+ # skip
+ for c in vconsts:
+ rel = c.relation()
+ if rel is None or not rel.neged(strict=True):
+ final = False
+ break
+ break
+                # check whether all relations are supported by the sources
+ for rel in scope.iget_nodes(Relation):
+ if rel.is_types_restriction():
+ continue
+                    # take care not to overwrite the existing "source" identifier
+ for _source in sources:
+ if not _source.support_relation(rel.r_type):
+ for vref in rel.iget_nodes(VariableRef):
+ needsel.add(vref.name)
+ final = False
+ break
+ else:
+ if not scope is select:
+ self._exists_relation(rel, variables, needsel)
+                        # if the relation is supported by all sources and one
+                        # of its lhs/rhs variables isn't in "variables" while
+                        # the other end *is*, mark that it has to be selected
+ if source.uri != 'system' and not rschema(rel.r_type).is_final():
+ lhs, rhs = rel.get_variable_parts()
+ try:
+ lhsvar = lhs.variable
+ except AttributeError:
+ lhsvar = lhs
+ try:
+ rhsvar = rhs.variable
+ except AttributeError:
+ rhsvar = rhs
+ if lhsvar in variables and not rhsvar in variables:
+ needsel.add(lhsvar.name)
+ elif rhsvar in variables and not lhsvar in variables:
+ needsel.add(rhsvar.name)
+ if final:
+ self._cleanup_sourcesvars(sources, solindices)
+ # XXX rename: variables may contain Relation and Constant nodes...
+ steps.append( (sources, variables, solindices, scope, needsel,
+ final) )
+ return steps
+
+ def _exists_relation(self, rel, variables, needsel):
+ rschema = self.plan.schema.rschema(rel.r_type)
+ lhs, rhs = rel.get_variable_parts()
+ try:
+ lhsvar, rhsvar = lhs.variable, rhs.variable
+ except AttributeError:
+ pass
+ else:
+            # supported relation with at least one end supported; check that
+            # the other end is in as well. If not, this usually means the
+            # variable is referenced by an outer scope and should be
+            # substituted using an 'identity' relation (else we'll get a
+            # conflict of temporary tables)
+ if rhsvar in variables and not lhsvar in variables:
+ self._identity_substitute(rel, lhsvar, variables, needsel)
+ elif lhsvar in variables and not rhsvar in variables:
+ self._identity_substitute(rel, rhsvar, variables, needsel)
+
+ def _identity_substitute(self, relation, var, variables, needsel):
+ newvar = self._insert_identity_variable(relation.scope, var)
+ if newvar is not None:
+ # ensure relation is using '=' operator, else we rely on a
+ # sqlgenerator side effect (it won't insert an inequality operator
+ # in this case)
+ relation.children[1].operator = '='
+ variables.append(newvar)
+ needsel.add(newvar.name)
+ #self.insertedvars.append((var.name, self.schema['identity'],
+ # newvar.name))
+
+ def _choose_var(self, sourcevars):
+ secondchoice = None
+ if len(self._sourcesvars) > 1:
+ # priority to variable from subscopes
+ for var in sourcevars:
+ if not var.scope is self.rqlst:
+ if isinstance(var, Variable):
+ return var, sourcevars.pop(var)
+ secondchoice = var
+ else:
+ # priority to variable outer scope
+ for var in sourcevars:
+ if var.scope is self.rqlst:
+ if isinstance(var, Variable):
+ return var, sourcevars.pop(var)
+ secondchoice = var
+ if secondchoice is not None:
+ return secondchoice, sourcevars.pop(secondchoice)
+        # priority to Variable nodes over Relation/Constant nodes
+ for var in sourcevars:
+ if isinstance(var, Variable):
+ return var, sourcevars.pop(var)
+ # whatever
+ var = iter(sourcevars).next()
+ return var, sourcevars.pop(var)
+
+ def _expand_vars(self, var, sourcevars, scope, solindices):
+ variables = [var]
+ nbunlinked = 1
+ linkedvars = self._linkedvars
+ # variable has to belong to the same scope if there is more
+ # than the system source remaining
+ if len(self._sourcesvars) > 1 and not scope is self.rqlst:
+ candidates = (v for v in sourcevars.keys() if scope is v.scope)
+ else:
+ candidates = sourcevars #.iterkeys()
+ candidates = [v for v in candidates
+ if isinstance(v, Constant) or
+ (solindices.issubset(sourcevars[v]) and v in linkedvars)]
+        # repeat until no more variables can be added, since the addition of
+        # a new variable may allow another one to be added
+ modified = True
+ while modified and candidates:
+ modified = False
+ for var in candidates[:]:
+ # we only want one unlinked variable in each generated query
+ if isinstance(var, Constant) or \
+ any(v for v in variables if v in linkedvars[var]):
+ variables.append(var)
+ # constant nodes should be systematically deleted
+ if isinstance(var, Constant):
+ del sourcevars[var]
+ # variable nodes should be deleted once all possible solution
+ # indices have been consumed
+ else:
+ sourcevars[var] -= solindices
+ if not sourcevars[var]:
+ del sourcevars[var]
+ candidates.remove(var)
+ modified = True
+ return variables
+
+ def _expand_sources(self, selected_source, vars, solindices):
+ sources = [selected_source]
+ sourcesvars = self._sourcesvars
+ for source in sourcesvars:
+ if source is selected_source:
+ continue
+ for var in vars:
+ if not (var in sourcesvars[source] and
+ solindices.issubset(sourcesvars[source][var])):
+ break
+ else:
+ sources.append(source)
+ if source.uri != 'system':
+ for var in vars:
+ varsolindices = sourcesvars[source][var]
+ varsolindices -= solindices
+ if not varsolindices:
+ del sourcesvars[source][var]
+
+ return sources
+
+ def _cleanup_sourcesvars(self, sources, solindices):
+ """on final parts, remove solutions so we know they are already processed"""
+ for source in sources:
+ try:
+ sourcevar = self._sourcesvars[source]
+ except KeyError:
+ continue
+ for var, varsolindices in sourcevar.items():
+ varsolindices -= solindices
+ if not varsolindices:
+ del sourcevar[var]
+
+ def merge_input_maps(self, allsolindices):
+ """inputmaps is a dictionary with tuple of solution indices as key with an
+ associateed input map as value. This function compute for each solution
+ its necessary input map and return them grouped
+
+ ex:
+ inputmaps = {(0, 1, 2): {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1'},
+ (1,): {'X': 't2.C0', 'T': 't2.C1'}}
+ return : [([1], {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1',
+ 'X': 't2.C0', 'T': 't2.C1'}),
+ ([0,2], {'A': 't1.login1', 'U': 't1.C0', 'U.login': 't1.login1'})]
+ """
+ if not self._inputmaps:
+ return [(allsolindices, None)]
+ mapbysol = {}
+ # compute a single map for each solution
+ for solindices, basemap in self._inputmaps.iteritems():
+ for solindex in solindices:
+ solmap = mapbysol.setdefault(solindex, {})
+ solmap.update(basemap)
+ try:
+ allsolindices.remove(solindex)
+ except KeyError:
+ continue # already removed
+ # group results by identical input map
+ result = []
+ for solindex, solmap in mapbysol.iteritems():
+ for solindices, commonmap in result:
+ if commonmap == solmap:
+ solindices.append(solindex)
+ break
+ else:
+ result.append( ([solindex], solmap) )
+ if allsolindices:
+ result.append( (list(allsolindices), None) )
+ return result
+
+ def build_final_part(self, select, solindices, inputmap, sources,
+ insertedvars):
+ plan = self.plan
+ rqlst = plan.finalize(select, [self._solutions[i] for i in solindices],
+ insertedvars)
+ if self.temptable is None and self.finaltable is None:
+ return OneFetchStep(plan, rqlst, sources, inputmap=inputmap)
+ table = self.temptable or self.finaltable
+ return FetchStep(plan, rqlst, sources, table, True, inputmap)
+
+ def build_non_final_part(self, select, solindices, sources, insertedvars,
+ table):
+ """non final step, will have to store results in a temporary table"""
+ plan = self.plan
+ rqlst = plan.finalize(select, [self._solutions[i] for i in solindices],
+ insertedvars)
+ step = FetchStep(plan, rqlst, sources, table, False)
+ # update input map for following steps, according to processed solutions
+ inputmapkey = tuple(sorted(solindices))
+ inputmap = self._inputmaps.setdefault(inputmapkey, {})
+ inputmap.update(step.outputmap)
+ plan.add_step(step)
+
+
+class MSPlanner(SSPlanner):
+ """MultiSourcesPlanner: build execution plan for rql queries
+
+    decompose the RQL query according to the sources' schemas
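+
+    a minimal usage sketch (assuming a `plan` built by the querier, as with
+    SSPlanner which this class extends):
+
+        planner = MSPlanner(schema, rqlhelper)
+        steps = planner.build_select_plan(plan, rqlst)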
+ """
+
+ def build_select_plan(self, plan, rqlst):
+ """build execution plan for a SELECT RQL query
+
+ the rqlst should not be tagged at this point
+ """
+ if server.DEBUG:
+ print '-'*80
+ print 'PLANNING', rqlst
+ for select in rqlst.children:
+ if len(select.solutions) > 1:
+ hasmultiplesols = True
+ break
+ else:
+ hasmultiplesols = False
+        # preprocess deals with security insertion and returns a new syntax
+        # tree which has to be executed to fulfill the query: according to
+        # permissions on variables' types, different rql queries may have to
+        # be executed
+ plan.preprocess(rqlst)
+ ppis = [PartPlanInformation(plan, select, self.rqlhelper)
+ for select in rqlst.children]
+ steps = self._union_plan(plan, rqlst, ppis)
+ if server.DEBUG:
+ from pprint import pprint
+ for step in plan.steps:
+ pprint(step.test_repr())
+ pprint(steps[0].test_repr())
+ return steps
+
+ def _ppi_subqueries(self, ppi):
+ # part plan info for subqueries
+ plan = ppi.plan
+ inputmap = {}
+ for subquery in ppi.rqlst.with_[:]:
+ sppis = [PartPlanInformation(plan, select)
+ for select in subquery.query.children]
+ for sppi in sppis:
+ if sppi.needsplit or sppi.part_sources != ppi.part_sources:
+ temptable = 'T%s' % make_uid(id(subquery))
+ sstep = self._union_plan(plan, subquery.query, sppis, temptable)[0]
+ break
+ else:
+ sstep = None
+ if sstep is not None:
+ ppi.rqlst.with_.remove(subquery)
+ for i, colalias in enumerate(subquery.aliases):
+ inputmap[colalias.name] = '%s.C%s' % (temptable, i)
+ ppi.plan.add_step(sstep)
+ return inputmap
+
+ def _union_plan(self, plan, union, ppis, temptable=None):
+ tosplit, cango, allsources = [], {}, set()
+ for planinfo in ppis:
+ if planinfo.needsplit:
+ tosplit.append(planinfo)
+ else:
+ cango.setdefault(planinfo.part_sources, []).append(planinfo)
+ for source in planinfo.part_sources:
+ allsources.add(source)
+        # first add steps for the query parts which don't need to be split
+ steps = []
+ for sources, cppis in cango.iteritems():
+ byinputmap = {}
+ for ppi in cppis:
+ select = ppi.rqlst
+ if sources != (plan.session.repo.system_source,):
+ add_types_restriction(self.schema, select)
+ # part plan info for subqueries
+ inputmap = self._ppi_subqueries(ppi)
+ aggrstep = need_aggr_step(select, sources)
+ if aggrstep:
+ atemptable = 'T%s' % make_uid(id(select))
+ sunion = Union()
+ sunion.append(select)
+ selected = select.selection[:]
+ select_group_sort(select)
+ step = AggrStep(plan, selected, select, atemptable, temptable)
+ step.set_limit_offset(select.limit, select.offset)
+ select.limit = None
+ select.offset = 0
+ fstep = FetchStep(plan, sunion, sources, atemptable, True, inputmap)
+ step.children.append(fstep)
+ steps.append(step)
+ else:
+                    byinputmap.setdefault(tuple(inputmap.iteritems()), []).append(select)
+ for inputmap, queries in byinputmap.iteritems():
+ inputmap = dict(inputmap)
+ sunion = Union()
+ for select in queries:
+ sunion.append(select)
+ if temptable:
+ steps.append(FetchStep(plan, sunion, sources, temptable, True, inputmap))
+ else:
+ steps.append(OneFetchStep(plan, sunion, sources, inputmap))
+        # then add steps for the split query parts
+ for planinfo in tosplit:
+ steps.append(self.split_part(planinfo, temptable))
+ if len(steps) > 1:
+ if temptable:
+ step = UnionFetchStep(plan)
+ else:
+ step = UnionStep(plan)
+ step.children = steps
+ return (step,)
+ return steps
+
+ # internal methods for multisources decomposition #########################
+
+ def split_part(self, ppi, temptable):
+ ppi.finaltable = temptable
+ plan = ppi.plan
+ select = ppi.rqlst
+ subinputmap = self._ppi_subqueries(ppi)
+ stepdefs = ppi.part_steps()
+ if need_aggr_step(select, ppi.part_sources, stepdefs):
+ atemptable = 'T%s' % make_uid(id(select))
+ selection = select.selection[:]
+ select_group_sort(select)
+ else:
+ atemptable = None
+ selection = select.selection
+ ppi.temptable = atemptable
+ vfilter = VariablesFiltererVisitor(self.schema, ppi)
+ steps = []
+ for sources, variables, solindices, scope, needsel, final in stepdefs:
+ # extract an executable query using only the specified variables
+ if sources[0].uri == 'system':
+                # in this case we have to merge input maps before calling
+                # filter so that already processed restrictions are correctly
+                # removed
+ solsinputmaps = ppi.merge_input_maps(solindices)
+ for solindices, inputmap in solsinputmaps:
+ minrqlst, insertedvars = vfilter.filter(
+ sources, variables, scope, set(solindices), needsel, final)
+ if inputmap is None:
+ inputmap = subinputmap
+ else:
+ inputmap.update(subinputmap)
+ steps.append(ppi.build_final_part(minrqlst, solindices, inputmap,
+ sources, insertedvars))
+ else:
+                # this is a final part (i.e. retrieving results for the
+                # original query part) if all variables / sources have been
+                # treated or if this is the last shot for the used solutions
+ minrqlst, insertedvars = vfilter.filter(
+ sources, variables, scope, solindices, needsel, final)
+ if final:
+ solsinputmaps = ppi.merge_input_maps(solindices)
+ for solindices, inputmap in solsinputmaps:
+ if inputmap is None:
+ inputmap = subinputmap
+ else:
+ inputmap.update(subinputmap)
+ steps.append(ppi.build_final_part(minrqlst, solindices, inputmap,
+ sources, insertedvars))
+ else:
+ table = '_T%s%s' % (''.join(sorted(v._ms_table_key() for v in variables)),
+ ''.join(sorted(str(i) for i in solindices)))
+ ppi.build_non_final_part(minrqlst, solindices, sources,
+ insertedvars, table)
+        # finally: join parts, deal with aggregates/groups/sorts if necessary
+ if atemptable is not None:
+ step = AggrStep(plan, selection, select, atemptable, temptable)
+ step.children = steps
+ elif len(steps) > 1:
+ if temptable:
+ step = UnionFetchStep(plan)
+ else:
+ step = UnionStep(plan)
+ step.children = steps
+ else:
+ step = steps[0]
+ if select.limit is not None or select.offset:
+ step.set_limit_offset(select.limit, select.offset)
+ return step
+
+
+class UnsupportedBranch(Exception):
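+    """raised when a branch of the syntax tree can't be handled by the
+    sources currently considered (cf. VariablesFiltererVisitor)"""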
+ pass
+
+
+class VariablesFiltererVisitor(object):
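+    """visit an rql syntax tree to extract a new tree containing only the
+    given variables and the relations supported by the target sources; used
+    by MSPlanner.split_part to build the query of each step
+    """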
+ def __init__(self, schema, ppi):
+ self.schema = schema
+ self.ppi = ppi
+ self.skip = {}
+ self.hasaggrstep = self.ppi.temptable
+ self.extneedsel = frozenset(vref.name for sortterm in ppi.rqlst.orderby
+ for vref in sortterm.iget_nodes(VariableRef))
+
+ def _rqlst_accept(self, rqlst, node, newroot, variables, setfunc=None):
+ try:
+ newrestr, node_ = node.accept(self, newroot, variables[:])
+ except UnsupportedBranch:
+ return rqlst
+ if setfunc is not None and newrestr is not None:
+ setfunc(newrestr)
+ if not node_ is node:
+ rqlst = node.parent
+ return rqlst
+
+ def filter(self, sources, variables, rqlst, solindices, needsel, final):
+ if server.DEBUG:
+ print 'filter', final and 'final' or '', sources, variables, rqlst, solindices, needsel
+ newroot = Select()
+ self.sources = sources
+ self.solindices = solindices
+ self.final = final
+ # variables which appear in unsupported branches
+ needsel |= self.extneedsel
+ self.needsel = needsel
+ # variables which appear in supported branches
+ self.mayneedsel = set()
+ # new inserted variables
+ self.insertedvars = []
+ # other structures (XXX document)
+ self.mayneedvar, self.hasvar = {}, {}
+ self.use_only_defined = False
+ self.scopes = {rqlst: newroot}
+ if rqlst.where:
+ rqlst = self._rqlst_accept(rqlst, rqlst.where, newroot, variables,
+ newroot.set_where)
+ if isinstance(rqlst, Select):
+ self.use_only_defined = True
+ if rqlst.groupby:
+ groupby = []
+ for node in rqlst.groupby:
+ rqlst = self._rqlst_accept(rqlst, node, newroot, variables,
+ groupby.append)
+ if groupby:
+ newroot.set_groupby(groupby)
+ if rqlst.having:
+ having = []
+ for node in rqlst.having:
+ rqlst = self._rqlst_accept(rqlst, node, newroot, variables,
+ having.append)
+ if having:
+ newroot.set_having(having)
+ if final and rqlst.orderby and not self.hasaggrstep:
+ orderby = []
+ for node in rqlst.orderby:
+ rqlst = self._rqlst_accept(rqlst, node, newroot, variables,
+ orderby.append)
+ if orderby:
+ newroot.set_orderby(orderby)
+ self.process_selection(newroot, variables, rqlst)
+ elif not newroot.where:
+ # no restrictions have been copied, just select variables and add
+ # type restriction (done later by add_types_restriction)
+ for v in variables:
+ if not isinstance(v, Variable):
+ continue
+ newroot.append_selected(VariableRef(newroot.get_variable(v.name)))
+ solutions = self.ppi.copy_solutions(solindices)
+ cleanup_solutions(newroot, solutions)
+ newroot.set_possible_types(solutions)
+ if final:
+ if self.hasaggrstep:
+ self.add_necessary_selection(newroot, self.mayneedsel & self.extneedsel)
+ newroot.distinct = rqlst.distinct
+ else:
+ self.add_necessary_selection(newroot, self.mayneedsel & self.needsel)
+ # insert vars to fetch constant values when needed
+ for (varname, rschema), reldefs in self.mayneedvar.iteritems():
+ for rel, ored in reldefs:
+ if not (varname, rschema) in self.hasvar:
+ self.hasvar[(varname, rschema)] = None # just to avoid further insertion
+ cvar = newroot.make_variable()
+ for sol in newroot.solutions:
+ sol[cvar.name] = rschema.objects(sol[varname])[0]
+ # if the current restriction is not used in a OR branch,
+ # we can keep it, else we have to drop the constant
+ # restriction (or we may miss some results)
+ if not ored:
+ rel = rel.copy(newroot)
+ newroot.add_restriction(rel)
+ # add a relation to link the variable
+ newroot.remove_node(rel.children[1])
+ cmp = Comparison('=')
+ rel.append(cmp)
+ cmp.append(VariableRef(cvar))
+ self.insertedvars.append((varname, rschema, cvar.name))
+ newroot.append_selected(VariableRef(newroot.get_variable(cvar.name)))
+                    # NOTE: even if the restriction is handled by this query,
+                    # we have to leave it in the original rqlst so that it
+                    # appears anyway in the "final" query, else we may change
+                    # the meaning of the query if there are NOT somewhere:
+                    # 'NOT X relation Y, Y name "toto"' means X WHERE X isn't
+                    # related to Y whose name is toto, while
+                    # 'NOT X relation Y' means X WHERE X has no 'relation'
+                    # (whatever Y)
+ elif ored:
+ newroot.remove_node(rel)
+ add_types_restriction(self.schema, rqlst, newroot, solutions)
+ if server.DEBUG:
+ print '--->', newroot
+ return newroot, self.insertedvars
+
+ def visit_and(self, node, newroot, variables):
+ subparts = []
+ for i in xrange(len(node.children)):
+ child = node.children[i]
+ try:
+ newchild, child_ = child.accept(self, newroot, variables)
+ if not child_ is child:
+ node = child_.parent
+ if newchild is None:
+ continue
+ subparts.append(newchild)
+ except UnsupportedBranch:
+ continue
+ if not subparts:
+ return None, node
+ if len(subparts) == 1:
+ return subparts[0], node
+ return copy_node(newroot, node, subparts), node
+
+ visit_or = visit_and
+
+ def _relation_supported(self, rtype):
+ for source in self.sources:
+ if not source.support_relation(rtype):
+ return False
+ return True
+
+ def visit_relation(self, node, newroot, variables):
+ if not node.is_types_restriction():
+ if node in self.skip and self.solindices.issubset(self.skip[node]):
+ if not self.schema.rschema(node.r_type).is_final():
+ # can't really skip the relation if one variable is selected and only
+ # referenced by this relation
+ for vref in node.iget_nodes(VariableRef):
+ stinfo = vref.variable.stinfo
+ if stinfo['selected'] and len(stinfo['relations']) == 1:
+ break
+ else:
+ return None, node
+ else:
+ return None, node
+ if not self._relation_supported(node.r_type):
+ raise UnsupportedBranch()
+        # don't copy the type restriction unless this is the only relation
+        # for the rhs variable; it will be reinserted later as needed
+        # (otherwise we may copy a type restriction while the variable is not
+        # actually used)
+ elif not any(self._relation_supported(rel.r_type)
+ for rel in node.children[0].variable.stinfo['relations']):
+ rel, node = self.visit_default(node, newroot, variables)
+ return rel, node
+ else:
+ raise UnsupportedBranch()
+ rschema = self.schema.rschema(node.r_type)
+ res = self.visit_default(node, newroot, variables)[0]
+ ored = node.ored()
+ if rschema.is_final() or rschema.inlined:
+ vrefs = node.children[1].get_nodes(VariableRef)
+ if not vrefs:
+ if not ored:
+ self.skip.setdefault(node, set()).update(self.solindices)
+ else:
+ self.mayneedvar.setdefault((node.children[0].name, rschema), []).append( (res, ored) )
+
+ else:
+ assert len(vrefs) == 1
+ vref = vrefs[0]
+ # XXX check operator ?
+ self.hasvar[(node.children[0].name, rschema)] = vref
+ if self._may_skip_attr_rel(rschema, node, vref, ored, variables, res):
+ self.skip.setdefault(node, set()).update(self.solindices)
+ elif not ored:
+ self.skip.setdefault(node, set()).update(self.solindices)
+ return res, node
+
+ def _may_skip_attr_rel(self, rschema, rel, vref, ored, variables, res):
+ var = vref.variable
+ if ored:
+ return False
+ if var.name in self.extneedsel or var.stinfo['selected']:
+ return False
+ if not same_scope(var):
+ return False
+ if any(v for v,_ in var.stinfo['attrvars'] if not v.name in variables):
+ return False
+ return True
+
+ def visit_exists(self, node, newroot, variables):
+ newexists = node.__class__()
+ self.scopes = {node: newexists}
+ subparts, node = self._visit_children(node, newroot, variables)
+ if not subparts:
+ return None, node
+ newexists.set_where(subparts[0])
+ return newexists, node
+
+ def visit_not(self, node, newroot, variables):
+ subparts, node = self._visit_children(node, newroot, variables)
+ if not subparts:
+ return None, node
+ return copy_node(newroot, node, subparts), node
+
+ def visit_group(self, node, newroot, variables):
+ if not self.final:
+ return None, node
+ return self.visit_default(node, newroot, variables)
+
+ def visit_variableref(self, node, newroot, variables):
+ if self.use_only_defined:
+ if not node.variable.name in newroot.defined_vars:
+ raise UnsupportedBranch(node.name)
+ elif not node.variable in variables:
+ raise UnsupportedBranch(node.name)
+ self.mayneedsel.add(node.name)
+ # set scope so we can insert types restriction properly
+ newvar = newroot.get_variable(node.name)
+ newvar.stinfo['scope'] = self.scopes.get(node.variable.scope, newroot)
+ return VariableRef(newvar), node
+
+ def visit_constant(self, node, newroot, variables):
+ return copy_node(newroot, node), node
+
+ def visit_default(self, node, newroot, variables):
+ subparts, node = self._visit_children(node, newroot, variables)
+ return copy_node(newroot, node, subparts), node
+
+    # note: visit_constant is explicitly defined above, don't alias it here
+    visit_comparison = visit_mathexpression = visit_function = visit_default
+    visit_sort = visit_sortterm = visit_default
+
+ def _visit_children(self, node, newroot, variables):
+ subparts = []
+ for i in xrange(len(node.children)):
+ child = node.children[i]
+ newchild, child_ = child.accept(self, newroot, variables)
+ if not child is child_:
+ node = child_.parent
+ if newchild is not None:
+ subparts.append(newchild)
+ return subparts, node
+
+ def process_selection(self, newroot, variables, rqlst):
+ if self.final:
+ for term in rqlst.selection:
+ newroot.append_selected(term.copy(newroot))
+ for vref in term.get_nodes(VariableRef):
+ self.needsel.add(vref.name)
+ return
+ for term in rqlst.selection:
+ vrefs = term.get_nodes(VariableRef)
+ if vrefs:
+ supportedvars = []
+ for vref in vrefs:
+ var = vref.variable
+ if var in variables:
+ supportedvars.append(vref)
+ continue
+ else:
+ self.needsel.add(vref.name)
+ break
+ else:
+ for vref in vrefs:
+ newroot.append_selected(vref.copy(newroot))
+ supportedvars = []
+ for vref in supportedvars:
+ if not vref in newroot.get_selected_variables():
+ newroot.append_selected(VariableRef(newroot.get_variable(vref.name)))
+
+ def add_necessary_selection(self, newroot, variables):
+ selected = tuple(newroot.get_selected_variables())
+ for varname in variables:
+ var = newroot.defined_vars[varname]
+ for vref in var.references():
+ rel = vref.relation()
+ if rel is None and vref in selected:
+ # already selected
+ break
+ else:
+ selvref = VariableRef(var)
+ newroot.append_selected(selvref)
+ if newroot.groupby:
+ newroot.add_group_var(VariableRef(selvref.variable, noautoref=1))
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/mssteps.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,275 @@
+"""Defines the diferent querier steps usable in plans.
+
+FIXME: this code needs refactoring. Some problems:
+* get data from the parent plan, the latest step, temporary table...
+* each step has its own members (this is not necessarily bad, but a bit
+  messy for now)
+
+:organization: Logilab
+:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+from rql.nodes import VariableRef, Variable, Function
+
+from cubicweb.server.ssplanner import (LimitOffsetMixIn, Step, OneFetchStep,
+ varmap_test_repr, offset_result)
+
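+# when an aggregate function has been computed on partial results, merging
+# those results may require a different function: e.g. a global COUNT is the
+# SUM of the partial COUNTs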
+AGGR_TRANSFORMS = {'COUNT':'SUM', 'MIN':'MIN', 'MAX':'MAX', 'SUM': 'SUM'}
+
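+# while fetching partial results into a temporary table, sort (and usually
+# group) clauses are temporarily removed from the select nodes: they will be
+# applied later on the merged data, e.g. by an AggrStep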
+def remove_clauses(union, keepgroup):
+ clauses = []
+ for select in union.children:
+ if keepgroup:
+ having, orderby = select.having, select.orderby
+ select.having, select.orderby = None, None
+ clauses.append( (having, orderby) )
+ else:
+ groupby, having, orderby = select.groupby, select.having, select.orderby
+ select.groupby, select.having, select.orderby = None, None, None
+ clauses.append( (groupby, having, orderby) )
+ return clauses
+
+def restore_clauses(union, keepgroup, clauses):
+ for i, select in enumerate(union.children):
+ if keepgroup:
+ select.having, select.orderby = clauses[i]
+ else:
+ select.groupby, select.having, select.orderby = clauses[i]
+
+
+class FetchStep(OneFetchStep):
+ """step consisting in fetching data from sources, and storing result in
+ a temporary table
+ """
+ def __init__(self, plan, union, sources, table, keepgroup, inputmap=None):
+ OneFetchStep.__init__(self, plan, union, sources)
+ # temporary table to store step result
+ self.table = table
+ # should groupby clause be kept or not
+ self.keepgroup = keepgroup
+ # variables mapping to use as input
+ self.inputmap = inputmap
+        # output variable mapping
+        srqlst = union.children[0] # sample select node
+        self.outputmap = plan.init_temp_table(table, srqlst.selection,
+                                              srqlst.solutions[0])
+        # add additional information to the output mapping
+ for vref in srqlst.selection:
+ if not isinstance(vref, VariableRef):
+ continue
+ var = vref.variable
+ if var.stinfo['attrvars']:
+ for lhsvar, rtype in var.stinfo['attrvars']:
+ if lhsvar.name in srqlst.defined_vars:
+ key = '%s.%s' % (lhsvar.name, rtype)
+ self.outputmap[key] = self.outputmap[var.name]
+ else:
+ rschema = self.plan.schema.rschema
+ for rel in var.stinfo['rhsrelations']:
+ if rschema(rel.r_type).inlined:
+ lhsvar = rel.children[0]
+ if lhsvar.name in srqlst.defined_vars:
+ key = '%s.%s' % (lhsvar.name, rel.r_type)
+ self.outputmap[key] = self.outputmap[var.name]
+
+ def execute(self):
+ """execute this step"""
+ self.execute_children()
+ plan = self.plan
+ plan.create_temp_table(self.table)
+ union = self.union
+ # XXX 2.5 use "with"
+ clauses = remove_clauses(union, self.keepgroup)
+ for source in self.sources:
+ source.flying_insert(self.table, plan.session, union, plan.args,
+ self.inputmap)
+ restore_clauses(union, self.keepgroup, clauses)
+
+ def mytest_repr(self):
+ """return a representation of this step suitable for test"""
+ clauses = remove_clauses(self.union, self.keepgroup)
+ try:
+ inputmap = varmap_test_repr(self.inputmap, self.plan.tablesinorder)
+ outputmap = varmap_test_repr(self.outputmap, self.plan.tablesinorder)
+ except AttributeError:
+ inputmap = self.inputmap
+ outputmap = self.outputmap
+ try:
+ return (self.__class__.__name__,
+ sorted((r.as_string(kwargs=self.plan.args), r.solutions)
+ for r in self.union.children),
+ sorted(self.sources), inputmap, outputmap)
+ finally:
+ restore_clauses(self.union, self.keepgroup, clauses)
+
+
+class AggrStep(LimitOffsetMixIn, Step):
+ """step consisting in making aggregat from temporary data in the system
+ source
+ """
+ def __init__(self, plan, selection, select, table, outputtable=None):
+ Step.__init__(self, plan)
+ # original selection
+ self.selection = selection
+ # original Select RQL tree
+ self.select = select
+ # table where are located temporary results
+ self.table = table
+ # optional table where to write results
+ self.outputtable = outputtable
+ if outputtable is not None:
+ plan.init_temp_table(outputtable, selection, select.solutions[0])
+
+ #self.inputmap = inputmap
+
+ def mytest_repr(self):
+ """return a representation of this step suitable for test"""
+ sel = self.select.selection
+ restr = self.select.where
+ self.select.selection = self.selection
+ self.select.where = None
+ rql = self.select.as_string(kwargs=self.plan.args)
+ self.select.selection = sel
+ self.select.where = restr
+ try:
+ # rely on a monkey patch (cf unittest_querier)
+ table = self.plan.tablesinorder[self.table]
+ outputtable = self.outputtable and self.plan.tablesinorder[self.outputtable]
+ except AttributeError:
+ # not monkey patched
+ table = self.table
+ outputtable = self.outputtable
+ return (self.__class__.__name__, rql, self.limit, self.offset, table,
+ outputtable)
+
+ def execute(self):
+ """execute this step"""
+ self.execute_children()
+ self.inputmap = inputmap = self.children[-1].outputmap
+ # get the select clause
+ clause = []
+ for i, term in enumerate(self.selection):
+ try:
+ var_name = inputmap[term.as_string()]
+ except KeyError:
+ var_name = 'C%s' % i
+ if isinstance(term, Function):
+                # we have to translate some aggregate functions
+                # (for instance COUNT -> SUM)
+ orig_name = term.name
+ try:
+ term.name = AGGR_TRANSFORMS[term.name]
+ # backup and reduce children
+ orig_children = term.children
+ term.children = [VariableRef(Variable(var_name))]
+ clause.append(term.accept(self))
+                    # restore the tree XXX necessary?
+ term.name = orig_name
+ term.children = orig_children
+ except KeyError:
+ clause.append(var_name)
+ else:
+ clause.append(var_name)
+ for vref in term.iget_nodes(VariableRef):
+ inputmap[vref.name] = var_name
+ # XXX handle distinct with non selected sort term
+ if self.select.distinct:
+ sql = ['SELECT DISTINCT %s' % ', '.join(clause)]
+ else:
+ sql = ['SELECT %s' % ', '.join(clause)]
+ sql.append("FROM %s" % self.table)
+ # get the group/having clauses
+ if self.select.groupby:
+ clause = [inputmap[var.name] for var in self.select.groupby]
+ grouped = set(var.name for var in self.select.groupby)
+ sql.append('GROUP BY %s' % ', '.join(clause))
+ else:
+ grouped = None
+ if self.select.having:
+ clause = [term.accept(self) for term in self.select.having]
+ sql.append('HAVING %s' % ', '.join(clause))
+ # get the orderby clause
+ if self.select.orderby:
+ clause = []
+ for sortterm in self.select.orderby:
+ sqlterm = sortterm.term.accept(self)
+ if sortterm.asc:
+ clause.append(sqlterm)
+ else:
+ clause.append('%s DESC' % sqlterm)
+ if grouped is not None:
+ for vref in sortterm.iget_nodes(VariableRef):
+ if not vref.name in grouped:
+ sql[-1] += ', ' + self.inputmap[vref.name]
+ grouped.add(vref.name)
+ sql.append('ORDER BY %s' % ', '.join(clause))
+ if self.limit:
+ sql.append('LIMIT %s' % self.limit)
+ if self.offset:
+ sql.append('OFFSET %s' % self.offset)
+ #print 'DATA', plan.sqlexec('SELECT * FROM %s' % self.table, None)
+ sql = ' '.join(sql)
+ if self.outputtable:
+ self.plan.create_temp_table(self.outputtable)
+ sql = 'INSERT INTO %s %s' % (self.outputtable, sql)
+ return self.plan.sqlexec(sql, self.plan.args)
+
+ def visit_function(self, function):
+ """generate SQL name for a function"""
+ return '%s(%s)' % (function.name,
+ ','.join(c.accept(self) for c in function.children))
+
+ def visit_variableref(self, variableref):
+ """get the sql name for a variable reference"""
+ try:
+ return self.inputmap[variableref.name]
+ except KeyError: # XXX duh? explain
+ return variableref.variable.name
+
+ def visit_constant(self, constant):
+ """generate SQL name for a constant"""
+ assert constant.type == 'Int'
+ return str(constant.value)
+
+
+class UnionStep(LimitOffsetMixIn, Step):
+ """union results of child in-memory steps (e.g. OneFetchStep / AggrStep)"""
+
+ def execute(self):
+ """execute this step"""
+ result = []
+ limit = olimit = self.limit
+ offset = self.offset
+ assert offset != 0
+ if offset is not None:
+ limit = limit + offset
+ for step in self.children:
+ if limit is not None:
+ if offset is None:
+ limit = olimit - len(result)
+ step.set_limit_offset(limit, None)
+ result_ = step.execute()
+ if offset is not None:
+ offset, result_ = offset_result(offset, result_)
+ result += result_
+ if limit is not None:
+ if len(result) >= olimit:
+ return result[:olimit]
+ return result
+
+ def mytest_repr(self):
+ """return a representation of this step suitable for test"""
+ return (self.__class__.__name__, self.limit, self.offset)
+
+
+class UnionFetchStep(Step):
+ """union results of child steps using temporary tables (e.g. FetchStep)"""
+
+ def execute(self):
+ """execute this step"""
+ self.execute_children()
+
+
+__all__ = ('FetchStep', 'AggrStep', 'UnionStep', 'UnionFetchStep')
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sources/extlite.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,247 @@
+"""provide an abstract class for external sources using a sqlite database helper
+
+:organization: Logilab
+:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+
+import time
+import threading
+from os.path import join, exists
+
+from cubicweb import server
+from cubicweb.server.sqlutils import sqlexec, SQLAdapterMixIn
+from cubicweb.server.sources import AbstractSource, native
+from cubicweb.server.sources.rql2sql import SQLGenerator
+
+def timeout_acquire(lock, timeout):
+ while not lock.acquire(False):
+ time.sleep(0.2)
+ timeout -= 0.2
+ if timeout <= 0:
+ raise RuntimeError("svn source is busy, can't acquire connection lock")
+
+class ConnectionWrapper(object):
+ def __init__(self, source=None):
+ self.source = source
+ self._cnx = None
+
+ @property
+ def cnx(self):
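+        # open the underlying sqlite connection lazily, acquiring the
+        # source's lock so that only one thread uses it at a time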
+ if self._cnx is None:
+ timeout_acquire(self.source._cnxlock, 5)
+ self._cnx = self.source._sqlcnx
+ return self._cnx
+
+ def commit(self):
+ if self._cnx is not None:
+ self._cnx.commit()
+
+ def rollback(self):
+ if self._cnx is not None:
+ self._cnx.rollback()
+
+ def cursor(self):
+ return self.cnx.cursor()
+
+
+class SQLiteAbstractSource(AbstractSource):
+ """an abstract class for external sources using a sqlite database helper
+ """
+ sqlgen_class = SQLGenerator
+ @classmethod
+ def set_nonsystem_types(cls):
+ # those entities are only in this source, we don't want them in the
+ # system source
+ for etype in cls.support_entities:
+ native.NONSYSTEM_ETYPES.add(etype)
+ for rtype in cls.support_relations:
+ native.NONSYSTEM_RELATIONS.add(rtype)
+
+ options = (
+ ('helper-db-path',
+ {'type' : 'string',
+ 'default': None,
+ 'help': 'path to the sqlite database file used to do queries on the \
+repository.',
+ 'inputlevel': 2,
+ }),
+ )
+
+ def __init__(self, repo, appschema, source_config, *args, **kwargs):
+        # the helper db is used to ease querying and will store everything
+        # but the actual file content
+ dbpath = source_config.get('helper-db-path')
+ if dbpath is None:
+ dbpath = join(repo.config.appdatahome,
+ '%(uri)s.sqlite' % source_config)
+ self.dbpath = dbpath
+ self.sqladapter = SQLAdapterMixIn({'db-driver': 'sqlite',
+ 'db-name': dbpath})
+ # those attributes have to be initialized before ancestor's __init__
+ # which will call set_schema
+ self._need_sql_create = not exists(dbpath)
+ self._need_full_import = self._need_sql_create
+ AbstractSource.__init__(self, repo, appschema, source_config,
+ *args, **kwargs)
+ # sql database can only be accessed by one connection at a time, and a
+ # connection can only be used by the thread which created it so:
+ # * create the connection when needed
+ # * use a lock to be sure only one connection is used
+ self._cnxlock = threading.Lock()
+
+ @property
+ def _sqlcnx(self):
+ # XXX: sqlite connections can only be used in the same thread, so
+ # create a new one each time necessary. If it appears to be time
+ # consuming, find another way
+ return self.sqladapter.get_connection()
+
+ def _is_schema_complete(self):
+ for etype in self.support_entities:
+ if not etype in self.schema:
+ self.warning('not ready to generate %s database, %s support missing from schema',
+ self.uri, etype)
+ return False
+ for rtype in self.support_relations:
+ if not rtype in self.schema:
+ self.warning('not ready to generate %s database, %s support missing from schema',
+ self.uri, rtype)
+ return False
+ return True
+
+ def _create_database(self):
+ from yams.schema2sql import eschema2sql, rschema2sql
+ from cubicweb.toolsutils import restrict_perms_to_user
+        self.warning('initializing sqlite database for %s source', self.uri)
+ cnx = self._sqlcnx
+ cu = cnx.cursor()
+ schema = self.schema
+ for etype in self.support_entities:
+ eschema = schema.eschema(etype)
+ createsqls = eschema2sql(self.sqladapter.dbhelper, eschema,
+ skip_relations=('data',))
+ sqlexec(createsqls, cu, withpb=False)
+ for rtype in self.support_relations:
+ rschema = schema.rschema(rtype)
+ if not rschema.inlined:
+ sqlexec(rschema2sql(rschema), cu, withpb=False)
+ cnx.commit()
+ cnx.close()
+ self._need_sql_create = False
+ if self.repo.config['uid']:
+ from logilab.common.shellutils import chown
+ # database file must be owned by the uid of the server process
+ self.warning('set %s as owner of the database file',
+ self.repo.config['uid'])
+ chown(self.dbpath, self.repo.config['uid'])
+ restrict_perms_to_user(self.dbpath, self.info)
+
+ def set_schema(self, schema):
+ super(SQLiteAbstractSource, self).set_schema(schema)
+ if self._need_sql_create and self._is_schema_complete():
+ self._create_database()
+ self.rqlsqlgen = self.sqlgen_class(schema, self.sqladapter.dbhelper)
+
+ def get_connection(self):
+ return ConnectionWrapper(self)
+
+ def check_connection(self, cnx):
+ """check connection validity, return None if the connection is still valid
+ else a new connection (called when the pool using the given connection is
+ being attached to a session)
+
+ always return the connection to reset eventually cached cursor
+ """
+ return cnx
+
+ def pool_reset(self, cnx):
+ """the pool using the given connection is being reseted from its current
+ attached session: release the connection lock if the connection wrapper
+ has a connection set
+ """
+ if cnx._cnx is not None:
+ try:
+ cnx._cnx.close()
+ cnx._cnx = None
+ finally:
+ self._cnxlock.release()
+
+ def syntax_tree_search(self, session, union,
+ args=None, cachekey=None, varmap=None, debug=0):
+ """return result from this source for a rql query (actually from a rql
+ syntax tree and a solution dictionary mapping each used variable to a
+ possible type). If cachekey is given, the query necessary to fetch the
+ results (but not the results themselves) may be cached using this key.
+ """
+ if self._need_sql_create:
+ return []
+ sql, query_args = self.rqlsqlgen.generate(union, args)
+ if server.DEBUG:
+ print self.uri, 'SOURCE RQL', union.as_string()
+ print 'GENERATED SQL', sql
+ args = self.sqladapter.merge_args(args, query_args)
+ cursor = session.pool[self.uri]
+ cursor.execute(sql, args)
+ return self.sqladapter.process_result(cursor)
+
+ def local_add_entity(self, session, entity):
+ """insert the entity in the local database.
+
+        This is not provided as the add_entity implementation since sources
+        usually don't want to simply do this; add_entity raises
+        NotImplementedError and source implementors may call this method when
+        appropriate
+ """
+ cu = session.pool[self.uri]
+ attrs = self.sqladapter.preprocess_entity(entity)
+ sql = self.sqladapter.sqlgen.insert(str(entity.e_schema), attrs)
+ cu.execute(sql, attrs)
+
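+    # A subclass that actually wants plain sql storage could hook this up in
+    # its own implementation; a minimal hypothetical sketch:
+    #
+    #   def add_entity(self, session, entity):
+    #       self.local_add_entity(session, entity)
+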
+ def add_entity(self, session, entity):
+ """add a new entity to the source"""
+ raise NotImplementedError()
+
+ def local_update_entity(self, session, entity):
+ """update an entity in the source
+
+ This is not provided as update_entity implementation since usually
+ source don't want to simply do this, so let raise NotImplementedError
+ and the source implementor may use this method if necessary
+ """
+ cu = session.pool[self.uri]
+ attrs = self.sqladapter.preprocess_entity(entity)
+ sql = self.sqladapter.sqlgen.update(str(entity.e_schema), attrs, ['eid'])
+ cu.execute(sql, attrs)
+
+ def update_entity(self, session, entity):
+ """update an entity in the source"""
+ raise NotImplementedError()
+
+ def delete_entity(self, session, etype, eid):
+ """delete an entity from the source
+
+        this is not deleting a file in the svn repository but deleting
+        entities from this source. Its main usage is to delete source content
+        when a Repository entity is deleted.
+ """
+ sqlcursor = session.pool[self.uri]
+ attrs = {'eid': eid}
+ sql = self.sqladapter.sqlgen.delete(etype, attrs)
+ sqlcursor.execute(sql, attrs)
+
+ def delete_relation(self, session, subject, rtype, object):
+ """delete a relation from the source"""
+ rschema = self.schema.rschema(rtype)
+ if rschema.inlined:
+ if subject in session.query_data('pendingeids', ()):
+ return
+ etype = session.describe(subject)[0]
+ sql = 'UPDATE %s SET %s=NULL WHERE eid=%%(eid)s' % (etype, rtype)
+ attrs = {'eid' : subject}
+ else:
+ attrs = {'eid_from': subject, 'eid_to': object}
+ sql = self.sqladapter.sqlgen.delete('%s_relation' % rtype, attrs)
+ sqlcursor = session.pool[self.uri]
+ sqlcursor.execute(sql, attrs)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sources/ldapuser.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,685 @@
+"""cubicweb ldap user source
+
+this source is for now limited to a read-only EUser source
+
+:organization: Logilab
+:copyright: 2003-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+
+
+Part of the code comes from Zope's LDAPUserFolder
+
+Copyright (c) 2004 Jens Vagelpohl.
+All Rights Reserved.
+
+This software is subject to the provisions of the Zope Public License,
+Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
+THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+FOR A PARTICULAR PURPOSE.
+"""
+
+from mx.DateTime import now, DateTimeDelta
+
+from logilab.common.textutils import get_csv
+from rql.nodes import Relation, VariableRef, Constant, Function
+
+import ldap
+from ldap.ldapobject import ReconnectLDAPObject
+from ldap.filter import filter_format, escape_filter_chars
+from ldapurl import LDAPUrl
+
+from cubicweb.common import AuthenticationError, UnknownEid, RepositoryError
+from cubicweb.server.sources import AbstractSource, TrFunc, GlobTrFunc, ConnectionWrapper
+from cubicweb.server.utils import cartesian_product
+
+# search scopes
+BASE = ldap.SCOPE_BASE
+ONELEVEL = ldap.SCOPE_ONELEVEL
+SUBTREE = ldap.SCOPE_SUBTREE
+
+# XXX only for edition ??
+## password encryption possibilities
+#ENCRYPTIONS = ('SHA', 'CRYPT', 'MD5', 'CLEAR') # , 'SSHA'
+
+# mode identifier : (port, protocol)
+MODES = {
+ 0: (389, 'ldap'),
+ 1: (636, 'ldaps'),
+ 2: (0, 'ldapi'),
+ }
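+# e.g. mode 0 connects using a ldap://host:389 url and mode 1 using
+# ldaps://host:636 (see _connect below); for ldapi the port is unused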
+
+class TimedCache(dict):
+ def __init__(self, ttlm, ttls=0):
+        # time to live: ttlm minutes, plus optional ttls seconds
+ self.ttl = DateTimeDelta(0, 0, ttlm, ttls)
+
+ def __setitem__(self, key, value):
+ dict.__setitem__(self, key, (now(), value))
+
+ def __getitem__(self, key):
+ return dict.__getitem__(self, key)[1]
+
+ def clear_expired(self):
+ now_ = now()
+ ttl = self.ttl
+ for key, (timestamp, value) in self.items():
+ if now_ - timestamp > ttl:
+ del self[key]
+
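+# TimedCache usage sketch (hypothetical, for illustration only):
+#
+#   cache = TimedCache(ttlm=2)        # entries live two minutes
+#   cache['key'] = 'value'            # stored internally as (now(), 'value')
+#   assert cache['key'] == 'value'    # the timestamp is transparently stripped
+#   cache.clear_expired()             # drops entries older than the ttl
+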
+class LDAPUserSource(AbstractSource):
+ """LDAP read-only EUser source"""
+ support_entities = {'EUser': False}
+
+ port = None
+
+ cnx_mode = 0
+ cnx_dn = ''
+ cnx_pwd = ''
+
+ options = (
+ ('host',
+ {'type' : 'string',
+ 'default': 'ldap',
+ 'help': 'ldap host',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+ ('user-base-dn',
+ {'type' : 'string',
+ 'default': 'ou=People,dc=logilab,dc=fr',
+ 'help': 'base DN to lookup for users',
+ 'group': 'ldap-source', 'inputlevel': 0,
+ }),
+ ('user-scope',
+ {'type' : 'choice',
+ 'default': 'ONELEVEL',
+ 'choices': ('BASE', 'ONELEVEL', 'SUBTREE'),
+ 'help': 'user search scope',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+ ('user-classes',
+ {'type' : 'csv',
+ 'default': ('top', 'posixAccount'),
+ 'help': 'classes of user',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+ ('user-login-attr',
+ {'type' : 'string',
+ 'default': 'uid',
+ 'help': 'attribute used as login on authentication',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+ ('user-default-group',
+ {'type' : 'csv',
+ 'default': ('users',),
+      'help': 'name of a group in which ldap users will be by default. \
+You can set multiple groups by separating them with a comma.',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+ ('user-attrs-map',
+ {'type' : 'named',
+ 'default': {'uid': 'login', 'gecos': 'email'},
+ 'help': 'map from ldap user attributes to cubicweb attributes',
+ 'group': 'ldap-source', 'inputlevel': 1,
+ }),
+
+ ('synchronization-interval',
+ {'type' : 'int',
+ 'default': 24*60*60,
+ 'help': 'interval between synchronization with the ldap \
+directory (defaults to once a day).',
+ 'group': 'ldap-source', 'inputlevel': 2,
+ }),
+ ('cache-life-time',
+ {'type' : 'int',
+ 'default': 2*60,
+      'help': 'lifetime of the query cache in minutes (defaults to two hours).',
+ 'group': 'ldap-source', 'inputlevel': 2,
+ }),
+
+ )
+
+ def __init__(self, repo, appschema, source_config, *args, **kwargs):
+ AbstractSource.__init__(self, repo, appschema, source_config,
+ *args, **kwargs)
+ self.host = source_config['host']
+ self.user_base_dn = source_config['user-base-dn']
+ self.user_base_scope = globals()[source_config['user-scope']]
+ self.user_classes = get_csv(source_config['user-classes'])
+ self.user_login_attr = source_config['user-login-attr']
+ self.user_default_groups = get_csv(source_config['user-default-group'])
+ self.user_attrs = dict(v.split(':', 1) for v in get_csv(source_config['user-attrs-map']))
+ self.user_rev_attrs = {'eid': 'dn'}
+ for ldapattr, cwattr in self.user_attrs.items():
+ self.user_rev_attrs[cwattr] = ldapattr
+ self.base_filters = [filter_format('(%s=%s)', ('objectClass', o))
+ for o in self.user_classes]
+ self._conn = None
+ self._cache = {}
+        ttlm = int(source_config.get('cache-life-time', 2*60))
+ self._query_cache = TimedCache(ttlm)
+ self._interval = int(source_config.get('synchronization-interval',
+ 24*60*60))
+
+ def reset_caches(self):
+ """method called during test to reset potential source caches"""
+ self._query_cache = TimedCache(2*60)
+
+ def init(self):
+ """method called by the repository once ready to handle request"""
+ self.repo.looping_task(self._interval, self.synchronize)
+ self.repo.looping_task(self._query_cache.ttl.seconds/10, self._query_cache.clear_expired)
+
+ def synchronize(self):
+ """synchronize content known by this repository with content in the
+ external repository
+ """
+ self.info('synchronizing ldap source %s', self.uri)
+ session = self.repo.internal_session()
+ try:
+ cursor = session.system_sql("SELECT eid, extid FROM entities WHERE "
+ "source='%s'" % self.uri)
+ for eid, extid in cursor.fetchall():
+                # if no result is found, _search automatically deletes the entity information
+ res = self._search(session, extid, BASE)
+ if res:
+ ldapemailaddr = res[0].get(self.user_rev_attrs['email'])
+ if ldapemailaddr:
+                    rset = session.execute('Any X,A WHERE X address A, '
+                                           'U use_email X, U eid %(u)s',
+ {'u': eid})
+ ldapemailaddr = unicode(ldapemailaddr)
+ for emaileid, emailaddr in rset:
+ if emailaddr == ldapemailaddr:
+ break
+ else:
+ self.info('updating email address of user %s to %s',
+ extid, ldapemailaddr)
+ if rset:
+ session.execute('SET X address %(addr)s WHERE '
+ 'U primary_email X, U eid %(u)s',
+ {'addr': ldapemailaddr, 'u': eid})
+ else:
+ # no email found, create it
+ _insert_email(session, ldapemailaddr, eid)
+ finally:
+ session.commit()
+ session.close()
+
+ def get_connection(self):
+ """open and return a connection to the source"""
+ if self._conn is None:
+ self._connect()
+ return ConnectionWrapper(self._conn)
+
+ def authenticate(self, session, login, password):
+ """return EUser eid for the given login/password if this account is
+ defined in this source, else raise `AuthenticationError`
+
+        two steps are needed since the directory stores passwords in a form we
+        cannot compare directly: first search for the user's entry to get its
+        dn, then check the password by binding with that dn
+ """
+ assert login, 'no login!'
+ searchfilter = [filter_format('(%s=%s)', (self.user_login_attr, login))]
+ searchfilter.extend([filter_format('(%s=%s)', ('objectClass', o))
+ for o in self.user_classes])
+ searchstr = '(&%s)' % ''.join(searchfilter)
+ # first search the user
+ try:
+ user = self._search(session, self.user_base_dn,
+ self.user_base_scope, searchstr)[0]
+ except IndexError:
+ # no such user
+ raise AuthenticationError()
+        # check the password by establishing an (unused) connection
+ try:
+ self._connect(user['dn'], password)
+        except ldap.LDAPError:
+            # something went wrong, most likely bad credentials
+ raise AuthenticationError()
+ return self.extid2eid(user['dn'], 'EUser', session)
+
+ def ldap_name(self, var):
+ if var.stinfo['relations']:
+ relname = iter(var.stinfo['relations']).next().r_type
+ return self.user_rev_attrs.get(relname)
+ return None
+
+ def prepare_columns(self, mainvars, rqlst):
+ """return two list describin how to build the final results
+ from the result of an ldap search (ie a list of dictionnary)
+ """
+ columns = []
+ global_transforms = []
+ for i, term in enumerate(rqlst.selection):
+ if isinstance(term, Constant):
+ columns.append(term)
+ continue
+ if isinstance(term, Function): # LOWER, UPPER, COUNT...
+ var = term.get_nodes(VariableRef)[0]
+ var = var.variable
+ try:
+ mainvar = var.stinfo['attrvar'].name
+ except AttributeError: # no attrvar set
+ mainvar = var.name
+ assert mainvar in mainvars
+ trname = term.name
+ ldapname = self.ldap_name(var)
+ if trname in ('COUNT', 'MIN', 'MAX', 'SUM'):
+ global_transforms.append(GlobTrFunc(trname, i, ldapname))
+ columns.append((mainvar, ldapname))
+ continue
+ if trname in ('LOWER', 'UPPER'):
+ columns.append((mainvar, TrFunc(trname, i, ldapname)))
+ continue
+ raise NotImplementedError('no support for %s function' % trname)
+ if term.name in mainvars:
+ columns.append((term.name, 'dn'))
+ continue
+ var = term.variable
+ mainvar = var.stinfo['attrvar'].name
+ columns.append((mainvar, self.ldap_name(var)))
+ #else:
+ # # probably a bug in rql splitting if we arrive here
+ # raise NotImplementedError
+ return columns, global_transforms
+
+ def syntax_tree_search(self, session, union,
+ args=None, cachekey=None, varmap=None, debug=0):
+ """return result from this source for a rql query (actually from a rql
+ syntax tree and a solution dictionary mapping each used variable to a
+ possible type). If cachekey is given, the query necessary to fetch the
+ results (but not the results themselves) may be cached using this key.
+ """
+        # XXX not handled: transform/aggregate functions, join on multiple users...
+ assert len(union.children) == 1, 'union not supported'
+ rqlst = union.children[0]
+ assert not rqlst.with_, 'subquery not supported'
+ rqlkey = rqlst.as_string(kwargs=args)
+ try:
+ results = self._query_cache[rqlkey]
+ except KeyError:
+ results = self.rqlst_search(session, rqlst, args)
+ self._query_cache[rqlkey] = results
+ return results
+
+ def rqlst_search(self, session, rqlst, args):
+ mainvars = []
+ for varname in rqlst.defined_vars:
+ for sol in rqlst.solutions:
+ if sol[varname] == 'EUser':
+ mainvars.append(varname)
+ break
+ assert mainvars
+ columns, globtransforms = self.prepare_columns(mainvars, rqlst)
+ eidfilters = []
+ allresults = []
+ generator = RQL2LDAPFilter(self, session, args, mainvars)
+ for mainvar in mainvars:
+ # handle restriction
+ try:
+ eidfilters_, ldapfilter = generator.generate(rqlst, mainvar)
+ except GotDN, ex:
+ assert ex.dn, 'no dn!'
+ try:
+ res = [self._cache[ex.dn]]
+ except KeyError:
+ res = self._search(session, ex.dn, BASE)
+ except UnknownEid, ex:
+ # raised when we are looking for the dn of an eid which is not
+ # coming from this source
+ res = []
+ else:
+ eidfilters += eidfilters_
+ res = self._search(session, self.user_base_dn,
+ self.user_base_scope, ldapfilter)
+ allresults.append(res)
+ # 1. get eid for each dn and filter according to that eid if necessary
+ for i, res in enumerate(allresults):
+ filteredres = []
+ for resdict in res:
+                # make sure the entity exists in the system table
+ eid = self.extid2eid(resdict['dn'], 'EUser', session)
+ for eidfilter in eidfilters:
+ if not eidfilter(eid):
+ break
+ else:
+ resdict['eid'] = eid
+ filteredres.append(resdict)
+ allresults[i] = filteredres
+ # 2. merge result for each "mainvar": cartesian product
+ allresults = cartesian_product(allresults)
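+        # e.g. [[u1, u2], [v1]] roughly yields [(u1, v1), (u2, v1)]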
+ # 3. build final result according to column definition
+ result = []
+ for rawline in allresults:
+ rawline = dict(zip(mainvars, rawline))
+ line = []
+ for varname, ldapname in columns:
+ if ldapname is None:
+ value = None # no mapping available
+ elif ldapname == 'dn':
+ value = rawline[varname]['eid']
+ elif isinstance(ldapname, Constant):
+ if ldapname.type == 'Substitute':
+ value = args[ldapname.value]
+ else:
+ value = ldapname.value
+ elif isinstance(ldapname, TrFunc):
+ value = ldapname.apply(rawline[varname])
+ else:
+ value = rawline[varname].get(ldapname)
+ line.append(value)
+ result.append(line)
+ for trfunc in globtransforms:
+ result = trfunc.apply(result)
+ #print '--> ldap result', result
+ return result
+
+ def _connect(self, userdn=None, userpwd=None):
+ port, protocol = MODES[self.cnx_mode]
+ if protocol == 'ldapi':
+ hostport = self.host
+ else:
+ hostport = '%s:%s' % (self.host, self.port or port)
+ self.info('connecting %s://%s as %s', protocol, hostport,
+ userdn or 'anonymous')
+ url = LDAPUrl(urlscheme=protocol, hostport=hostport)
+ conn = ReconnectLDAPObject(url.initializeUrl())
+ # Set the protocol version - version 3 is preferred
+ try:
+ conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
+ except ldap.LDAPError: # Invalid protocol version, fall back safely
+ conn.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION2)
+ # Deny auto-chasing of referrals to be safe, we handle them instead
+ #try:
+ # connection.set_option(ldap.OPT_REFERRALS, 0)
+ #except ldap.LDAPError: # Cannot set referrals, so do nothing
+ # pass
+ #conn.set_option(ldap.OPT_NETWORK_TIMEOUT, conn_timeout)
+ #conn.timeout = op_timeout
+ # Now bind with the credentials given. Let exceptions propagate out.
+ if userdn is None:
+ assert self._conn is None
+ self._conn = conn
+ userdn = self.cnx_dn
+ userpwd = self.cnx_pwd
+ conn.simple_bind_s(userdn, userpwd)
+ return conn
+
+ def _search(self, session, base, scope,
+ searchstr='(objectClass=*)', attrs=()):
+ """make an ldap query"""
+ cnx = session.pool.connection(self.uri).cnx
+ try:
+ res = cnx.search_s(base, scope, searchstr, attrs)
+ except ldap.PARTIAL_RESULTS:
+ res = cnx.result(all=0)[1]
+ except ldap.NO_SUCH_OBJECT:
+ eid = self.extid2eid(base, 'EUser', session, insert=False)
+ if eid:
+ self.warning('deleting ldap user with eid %s and dn %s',
+ eid, base)
+ self.repo.delete_info(session, eid)
+ self._cache.pop(base, None)
+ return []
+## except ldap.REFERRAL, e:
+## cnx = self.handle_referral(e)
+## try:
+## res = cnx.search_s(base, scope, searchstr, attrs)
+## except ldap.PARTIAL_RESULTS:
+## res_type, res = cnx.result(all=0)
+ result = []
+ for rec_dn, rec_dict in res:
+            # When used against Active Directory, "rec_dict" may not be
+            # a dictionary in some cases (instead, it can be a list)
+ # An example of a useless "res" entry that can be ignored
+ # from AD is
+ # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,DC=PORTAL,DC=LOCAL'])
+ # This appears to be some sort of internal referral, but
+ # we can't handle it, so we need to skip over it.
+ try:
+ items = rec_dict.items()
+ except AttributeError:
+ # 'items' not found on rec_dict, skip
+ continue
+ for key, value in items: # XXX syt: huuum ?
+ if not isinstance(value, str):
+ try:
+ for i in range(len(value)):
+ value[i] = unicode(value[i], 'utf8')
+                    except (TypeError, UnicodeDecodeError):
+                        # not a mutable sequence of utf-8 strings, keep as is
+                        pass
+ if isinstance(value, list) and len(value) == 1:
+ rec_dict[key] = value = value[0]
+ rec_dict['dn'] = rec_dn
+ self._cache[rec_dn] = rec_dict
+ result.append(rec_dict)
+ #print '--->', result
+ return result
+
+ def before_entity_insertion(self, session, lid, etype, eid):
+ """called by the repository when an eid has been attributed for an
+ entity stored here but the entity has not been inserted in the system
+ table yet.
+
+        This method must return an Entity instance representing this
+        entity.
+ """
+ entity = super(LDAPUserSource, self).before_entity_insertion(session, lid, etype, eid)
+ res = self._search(session, lid, BASE)[0]
+ for attr in entity.e_schema.indexable_attributes():
+ entity[attr] = res[self.user_rev_attrs[attr]]
+ return entity
+
+ def after_entity_insertion(self, session, dn, entity):
+ """called by the repository after an entity stored here has been
+ inserted in the system table.
+ """
+ super(LDAPUserSource, self).after_entity_insertion(session, dn, entity)
+ for group in self.user_default_groups:
+ session.execute('SET X in_group G WHERE X eid %(x)s, G name %(group)s',
+ {'x': entity.eid, 'group': group}, 'x')
+        # search for an existing email address first
+ try:
+ emailaddr = self._cache[dn][self.user_rev_attrs['email']]
+ except KeyError:
+ return
+ rset = session.execute('EmailAddress X WHERE X address %(addr)s',
+ {'addr': emailaddr})
+ if rset:
+ session.execute('SET U primary_email X WHERE U eid %(u)s, X eid %(x)s',
+ {'x': rset[0][0], 'u': entity.eid}, 'u')
+ else:
+ # not found, create it
+ _insert_email(session, emailaddr, entity.eid)
+
+ def update_entity(self, session, entity):
+ """replace an entity in the source"""
+ raise RepositoryError('this source is read only')
+
+ def delete_entity(self, session, etype, eid):
+ """delete an entity from the source"""
+ raise RepositoryError('this source is read only')
+
+def _insert_email(session, emailaddr, ueid):
+ session.execute('INSERT EmailAddress X: X address %(addr)s, U primary_email X '
+ 'WHERE U eid %(x)s', {'addr': emailaddr, 'x': ueid}, 'x')
+
+class GotDN(Exception):
+ """exception used when a dn localizing the searched user has been found"""
+ def __init__(self, dn):
+ self.dn = dn
+
+
+class RQL2LDAPFilter(object):
+ """generate an LDAP filter for a rql query"""
+ def __init__(self, source, session, args=None, mainvars=()):
+ self.source = source
+ self._ldap_attrs = source.user_rev_attrs
+ self._base_filters = source.base_filters
+ self._session = session
+ if args is None:
+ args = {}
+ self._args = args
+ self.mainvars = mainvars
+
+ def generate(self, selection, mainvarname):
+ self._filters = res = self._base_filters[:]
+ self._mainvarname = mainvarname
+ self._eidfilters = []
+ self._done_not = set()
+ restriction = selection.where
+ if isinstance(restriction, Relation):
+ # only a single relation, need to append result here (no AND/OR)
+ filter = restriction.accept(self)
+ if filter is not None:
+ res.append(filter)
+ elif restriction:
+ restriction.accept(self)
+ if len(res) > 1:
+ return self._eidfilters, '(&%s)' % ''.join(res)
+ return self._eidfilters, res[0]
+
+ def visit_and(self, et):
+ """generate filter for a AND subtree"""
+ for c in et.children:
+ part = c.accept(self)
+ if part:
+ self._filters.append(part)
+
+ def visit_or(self, ou):
+ """generate filter for a OR subtree"""
+ res = []
+ for c in ou.children:
+ part = c.accept(self)
+ if part:
+ res.append(part)
+ if res:
+ if len(res) > 1:
+ part = '(|%s)' % ''.join(res)
+ else:
+ part = res[0]
+ self._filters.append(part)
+
+ def visit_not(self, node):
+ """generate filter for a OR subtree"""
+ part = node.children[0].accept(self)
+ if part:
+ self._filters.append('(!(%s))'% part)
+
+ def visit_relation(self, relation):
+ """generate filter for a relation"""
+ rtype = relation.r_type
+        # ignore type constraint statements (i.e. the 'is' relation)
+ if rtype == 'is':
+ return ''
+ lhs, rhs = relation.get_parts()
+ # attribute relation
+ if self.source.schema.rschema(rtype).is_final():
+ # dunno what to do here, don't pretend anything else
+ if lhs.name != self._mainvarname:
+ if lhs.name in self.mainvars:
+ # XXX check we don't have variable as rhs
+ return
+ raise NotImplementedError
+ rhs_vars = rhs.get_nodes(VariableRef)
+ if rhs_vars:
+ if len(rhs_vars) > 1:
+ raise NotImplementedError
+ # selected variable, nothing to do here
+ return
+ # no variables in the RHS
+ if isinstance(rhs.children[0], Function):
+ res = rhs.children[0].accept(self)
+ elif rtype != 'has_text':
+ res = self._visit_attribute_relation(relation)
+ else:
+ raise NotImplementedError(relation)
+ # regular relation XXX todo: in_group
+ else:
+ raise NotImplementedError(relation)
+ return res
+
+ def _visit_attribute_relation(self, relation):
+ """generate filter for an attribute relation"""
+ lhs, rhs = relation.get_parts()
+ lhsvar = lhs.variable
+ if relation.r_type == 'eid':
+ # XXX hack
+ # skip comparison sign
+ eid = int(rhs.children[0].accept(self))
+ if relation.neged(strict=True):
+ self._done_not.add(relation.parent)
+                self._eidfilters.append(lambda x: x != eid)
+ return
+ if rhs.operator != '=':
+ filter = {'>': lambda x: x > eid,
+ '>=': lambda x: x >= eid,
+ '<': lambda x: x < eid,
+ '<=': lambda x: x <= eid,
+ }[rhs.operator]
+ self._eidfilters.append(filter)
+ return
+ dn = self.source.eid2extid(eid, self._session)
+ raise GotDN(dn)
+ try:
+ filter = '(%s%s)' % (self._ldap_attrs[relation.r_type],
+ rhs.accept(self))
+ except KeyError:
+ assert relation.r_type == 'password' # 2.38 migration
+ raise UnknownEid # trick to return no result
+ return filter
+
+ def visit_comparison(self, cmp):
+ """generate filter for a comparaison"""
+ return '%s%s'% (cmp.operator, cmp.children[0].accept(self))
+
+ def visit_mathexpression(self, mexpr):
+ """generate filter for a mathematic expression"""
+ raise NotImplementedError
+
+ def visit_function(self, function):
+ """generate filter name for a function"""
+ if function.name == 'IN':
+ return self.visit_in(function)
+ raise NotImplementedError
+
+ def visit_in(self, function):
+ grandpapa = function.parent.parent
+ ldapattr = self._ldap_attrs[grandpapa.r_type]
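+        # e.g. 'X login IN("jdoe", "jsmith")' roughly yields
+        # (|(uid=jdoe)(uid=jsmith)) with the default attributes map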
+ res = []
+ for c in function.children:
+ part = c.accept(self)
+ if part:
+ res.append(part)
+ if res:
+ if len(res) > 1:
+ part = '(|%s)' % ''.join('(%s=%s)' % (ldapattr, v) for v in res)
+ else:
+ part = '(%s=%s)' % (ldapattr, res[0])
+ return part
+
+ def visit_constant(self, constant):
+ """generate filter name for a constant"""
+ value = constant.value
+ if constant.type is None:
+ raise NotImplementedError
+ if constant.type == 'Date':
+ raise NotImplementedError
+ #value = self.keyword_map[value]()
+        elif constant.type == 'Substitute':
+            value = self._args[constant.value]
+ if isinstance(value, unicode):
+ value = value.encode('utf8')
+ else:
+ value = str(value)
+ return escape_filter_chars(value)
+
+ def visit_variableref(self, variableref):
+ """get the sql name for a variable reference"""
+ pass
+
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/server/sources/pyrorql.py Mon Dec 22 17:34:15 2008 +0100
@@ -0,0 +1,553 @@
+"""Source to query another RQL repository using pyro
+
+:organization: Logilab
+:copyright: 2007-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
+"""
+__docformat__ = "restructuredtext en"
+
+import threading
+from os.path import join
+
+from mx.DateTime import DateTimeFromTicks
+
+from Pyro.errors import PyroError, ConnectionClosedError
+
+from logilab.common.configuration import REQUIRED
+
+from rql.nodes import Constant
+from rql.utils import rqlvar_maker
+
+from cubicweb import dbapi, server
+from cubicweb import BadConnectionId, UnknownEid, ConnectionError
+from cubicweb.cwconfig import register_persistent_options
+from cubicweb.server.sources import AbstractSource, ConnectionWrapper
+
+class ReplaceByInOperator(Exception):
+    """raised to tell the relation visitor to use an IN operator with the
+    given eids (see RQL2RQL.eid2extid)
+    """
+ def __init__(self, eids):
+ self.eids = eids
+
+class PyroRQLSource(AbstractSource):
+ """External repository source, using Pyro connection"""
+
+ # boolean telling if modification hooks should be called when something is
+ # modified in this source
+ should_call_hooks = False
+ # boolean telling if the repository should connect to this source during
+ # migration
+ connect_for_migration = False
+
+ support_entities = None
+
+ options = (
+ # XXX pyro-ns host/port
+ ('pyro-ns-id',
+ {'type' : 'string',
+ 'default': REQUIRED,
+ 'help': 'identifier of the repository in the pyro name server',
+ 'group': 'pyro-source', 'inputlevel': 0,
+ }),
+ ('mapping-file',
+ {'type' : 'string',
+ 'default': REQUIRED,
+ 'help': 'path to a python file with the schema mapping definition',
+ 'group': 'pyro-source', 'inputlevel': 1,
+ }),
+ ('cubicweb-user',
+ {'type' : 'string',
+ 'default': REQUIRED,
+ 'help': 'user to use for connection on the distant repository',
+ 'group': 'pyro-source', 'inputlevel': 0,
+ }),
+ ('cubicweb-password',
+ {'type' : 'password',
+ 'default': '',
+      'help': 'password to use for connection on the distant repository',
+ 'group': 'pyro-source', 'inputlevel': 0,
+ }),
+ ('base-url',
+ {'type' : 'string',
+ 'default': '',
+ 'help': 'url of the web site for the distant repository, if you want '
+ 'to generate external link to entities from this repository',
+ 'group': 'pyro-source', 'inputlevel': 1,
+ }),
+ ('pyro-ns-host',
+ {'type' : 'string',
+ 'default': None,
+      'help': 'Pyro name server\'s host. If not set, defaults to the value \
+from all_in_one.conf.',
+ 'group': 'pyro-source', 'inputlevel': 1,
+ }),
+ ('pyro-ns-port',
+ {'type' : 'int',
+ 'default': None,
+      'help': 'Pyro name server\'s listening port. If not set, defaults to \
+the value from all_in_one.conf.',
+ 'group': 'pyro-source', 'inputlevel': 1,
+ }),
+ ('pyro-ns-group',
+ {'type' : 'string',
+ 'default': None,
+ 'help': 'Pyro name server\'s group where the repository will be \
+registered. If not set, defaults to the value from all_in_one.conf.',
+ 'group': 'pyro-source', 'inputlevel': 1,
+ }),
+ ('synchronization-interval',
+ {'type' : 'int',
+ 'default': 5*60,
+ 'help': 'interval between synchronization with the external \
+repository (defaults to 5 minutes).',
+ 'group': 'pyro-source', 'inputlevel': 2,
+ }),
+
+ )
+
+ PUBLIC_KEYS = AbstractSource.PUBLIC_KEYS + ('base-url',)
+ _conn = None
+
+ def __init__(self, repo, appschema, source_config, *args, **kwargs):
+ AbstractSource.__init__(self, repo, appschema, source_config,
+ *args, **kwargs)
+ mappingfile = source_config['mapping-file']
+        if mappingfile[0] != '/':
+ mappingfile = join(repo.config.apphome, mappingfile)
+ mapping = {}
+ execfile(mappingfile, mapping)
+ self.support_entities = mapping['support_entities']
+ self.support_relations = mapping.get('support_relations', {})
+ self.dont_cross_relations = mapping.get('dont_cross_relations', ())
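+        # the mapping file is plain python; a hypothetical example:
+        #
+        #   support_entities = {'Card': True, 'EUser': False}
+        #   support_relations = {'in_group': True}
+        #   dont_cross_relations = ('owned_by',)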
+ baseurl = source_config.get('base-url')
+ if baseurl and not baseurl.endswith('/'):
+ source_config['base-url'] += '/'
+ self.config = source_config
+ myoptions = (('%s.latest-update-time' % self.uri,
+ {'type' : 'int', 'sitewide': True,
+ 'default': 0,
+ 'help': _('timestamp of the latest source synchronization.'),
+ 'group': 'sources',
+ }),)
+ register_persistent_options(myoptions)
+
+ def last_update_time(self):
+ pkey = u'sources.%s.latest-update-time' % self.uri
+ rql = 'Any V WHERE X is EProperty, X value V, X pkey %(k)s'
+ session = self.repo.internal_session()
+ try:
+ rset = session.execute(rql, {'k': pkey})
+ if not rset:
+ # insert it
+ session.execute('INSERT EProperty X: X pkey %(k)s, X value %(v)s',
+ {'k': pkey, 'v': u'0'})
+ session.commit()
+ timestamp = 0
+ else:
+ assert len(rset) == 1
+ timestamp = int(rset[0][0])
+ return DateTimeFromTicks(timestamp)
+ finally:
+ session.close()
+
+ def init(self):
+ """method called by the repository once ready to handle request"""
+ interval = int(self.config.get('synchronization-interval', 5*60))
+ self.repo.looping_task(interval, self.synchronize)
+
+ def synchronize(self, mtime=None):
+ """synchronize content known by this repository with content in the
+ external repository
+ """
+ self.info('synchronizing pyro source %s', self.uri)
+ extrepo = self.get_connection()._repo
+ etypes = self.support_entities.keys()
+ if mtime is None:
+ mtime = self.last_update_time()
+ updatetime, modified, deleted = extrepo.entities_modified_since(etypes,
+ mtime)
+ repo = self.repo
+ session = repo.internal_session()
+ try:
+ for etype, extid in modified:
+ try:
+ eid = self.extid2eid(extid, etype, session)
+ rset = session.eid_rset(eid, etype)
+ entity = rset.get_entity(0, 0)
+ entity.complete(entity.e_schema.indexable_attributes())
+ repo.index_entity(session, entity)
+                except Exception:
+ self.exception('while updating %s with external id %s of source %s',
+ etype, extid, self.uri)
+ continue
+ for etype, extid in deleted:
+ try:
+ eid = self.extid2eid(extid, etype, session, insert=False)
+                    # eid is None if the entity deleted from the external
+                    # repository was never known here
+ if eid is not None:
+ repo.delete_info(session, eid)
+                except Exception:
+ self.exception('while updating %s with external id %s of source %s',
+ etype, extid, self.uri)
+ continue
+ session.execute('SET X value %(v)s WHERE X pkey %(k)s',
+ {'k': u'sources.%s.latest-update-time' % self.uri,
+ 'v': unicode(int(updatetime.ticks()))})
+ session.commit()
+ finally:
+ session.close()
+
+ def _get_connection(self):
+ """open and return a connection to the source"""
+ nshost = self.config.get('pyro-ns-host') or self.repo.config['pyro-ns-host']
+ nsport = self.config.get('pyro-ns-port') or self.repo.config['pyro-ns-port']
+ nsgroup = self.config.get('pyro-ns-group') or self.repo.config['pyro-ns-group']
+ #cnxprops = ConnectionProperties(cnxtype=self.config['cnx-type'])
+ return dbapi.connect(database=self.config['pyro-ns-id'],
+ user=self.config['cubicweb-user'],
+ password=self.config['cubicweb-password'],
+ host=nshost, port=nsport, group=nsgroup,
+ setvreg=False) #cnxprops=cnxprops)
+
+ def get_connection(self):
+ try:
+ return self._get_connection()
+ except (ConnectionError, PyroError):
+ self.critical("can't get connection to source %s", self.uri,
+ exc_info=1)
+ return ConnectionWrapper()
+
+ def check_connection(self, cnx):
+ """check connection validity, return None if the connection is still valid
+ else a new connection
+ """
+        # we have to transfer thread ownership manually. This can be done
+        # safely since the pool the connection belongs to is attached to a
+        # single session/thread and can't be used concurrently
+ try:
+ cnx._repo._transferThread(threading.currentThread())
+ except AttributeError:
+ # inmemory connection
+ pass
+ if not isinstance(cnx, ConnectionWrapper):
+ try:
+ cnx.check()
+ return # ok
+ except (BadConnectionId, ConnectionClosedError):
+ pass
+ # try to reconnect
+ return self.get_connection()
+
+ def syntax_tree_search(self, session, union, args=None, cachekey=None,
+ varmap=None):
+ """return result from this source for a rql query (actually from a rql
+ syntax tree and a solution dictionary mapping each used variable to a
+ possible type). If cachekey is given, the query necessary to fetch the
+ results (but not the results themselves) may be cached using this key.
+ """
+        if args is not None:
+ args = args.copy()
+ if server.DEBUG:
+ print 'RQL FOR PYRO SOURCE', self.uri
+ print union.as_string()
+ if args: print 'ARGS', args
+ print 'SOLUTIONS', ','.join(str(s.solutions) for s in union.children)
+ # get cached cursor anyway
+ cu = session.pool[self.uri]
+ if cu is None:
+ # this is a ConnectionWrapper instance
+ msg = session._("can't connect to source %s, some data may be missing")
+ session.set_shared_data('sources_error', msg % self.uri)
+ return []
+ try:
+ rql, cachekey = RQL2RQL(self).generate(session, union, args)
+ except UnknownEid, ex:
+ if server.DEBUG:
+ print 'unknown eid', ex, 'no results'
+ return []
+ if server.DEBUG:
+ print 'TRANSLATED RQL', rql
+ try:
+ rset = cu.execute(rql, args, cachekey)
+ except Exception, ex:
+ self.exception(str(ex))
+ msg = session._("error while querying source %s, some data may be missing")
+ session.set_shared_data('sources_error', msg % self.uri)
+ return []
+ descr = rset.description
+ if rset:
+ needtranslation = []
+ for i, etype in enumerate(descr[0]):
+ if (etype is None or not self.schema.eschema(etype).is_final() or
+ getattr(union.locate_subquery(i, etype, args).selection[i], 'uidtype', None)):
+ needtranslation.append(i)
+ if needtranslation:
+ for rowindex, row in enumerate(rset):
+ for colindex in needtranslation:
+ if row[colindex] is not None: # optional variable
+ etype = descr[rowindex][colindex]
+ eid = self.extid2eid(row[colindex], etype, session)
+ row[colindex] = eid
+ results = rset.rows
+ else:
+ results = []
+ if server.DEBUG:
+ if len(results)>10:
+ print '--------------->', results[:10], '...', len(results)
+ else:
+ print '--------------->', results
+ return results
+
+ def _entity_relations_and_kwargs(self, session, entity):
+ relations = []
+ kwargs = {'x': self.eid2extid(entity.eid, session)}
+ for key, val in entity.iteritems():
+ relations.append('X %s %%(%s)s' % (key, key))
+ kwargs[key] = val
+ return relations, kwargs
+
+ def add_entity(self, session, entity):
+ """add a new entity to the source"""
+ raise NotImplementedError()
+
+ def update_entity(self, session, entity):
+ """update an entity in the source"""
+ relations, kwargs = self._entity_relations_and_kwargs(session, entity)
+ cu = session.pool[self.uri]
+ cu.execute('SET %s WHERE X eid %%(x)s' % ','.join(relations),
+ kwargs, 'x')
+
+ def delete_entity(self, session, etype, eid):
+ """delete an entity from the source"""
+ cu = session.pool[self.uri]
+ cu.execute('DELETE %s X WHERE X eid %%(x)s' % etype,
+ {'x': self.eid2extid(eid, session)}, 'x')
+
+ def add_relation(self, session, subject, rtype, object):
+ """add a relation to the source"""
+ cu = session.pool[self.uri]
+ cu.execute('SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+ {'x': self.eid2extid(subject, session),
+ 'y': self.eid2extid(object, session)}, ('x', 'y'))
+
+ def delete_relation(self, session, subject, rtype, object):
+ """delete a relation from the source"""
+ cu = session.pool[self.uri]
+ cu.execute('DELETE X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype,
+ {'x': self.eid2extid(subject, session),
+ 'y': self.eid2extid(object, session)}, ('x', 'y'))
+
+
+class RQL2RQL(object):
+ """translate a local rql query to be executed on a distant repository"""
+ def __init__(self, source):
+ self.source = source
+
+ def _accept_children(self, node):
+ res = []
+ for child in node.children:
+ rql = child.accept(self)
+ if rql is not None:
+ res.append(rql)
+ return res
+
+ def generate(self, session, rqlst, args):
+ self._session = session
+ self.kwargs = args
+ self.cachekey = []
+ self.need_translation = False
+ return self.visit_union(rqlst), self.cachekey
+
+ def visit_union(self, node):
+ s = self._accept_children(node)
+ if len(s) > 1:
+ return ' UNION '.join('(%s)' % q for q in s)
+ return s[0]
+
+ def visit_select(self, node):
+ """return the tree as an encoded rql string"""
+ self._varmaker = rqlvar_maker(defined=node.defined_vars.copy())
+ self._const_var = {}
+ if node.distinct:
+ base = 'DISTINCT Any'
+ else:
+ base = 'Any'
+ s = ['%s %s' % (base, ','.join(v.accept(self) for v in node.selection))]
+ if node.groupby:
+ s.append('GROUPBY %s' % ', '.join(group.accept(self)
+ for group in node.groupby))
+ if node.orderby:
+ s.append('ORDERBY %s' % ', '.join(self.visit_sortterm(term)
+ for term in node.orderby))
+ if node.limit is not None:
+ s.append('LIMIT %s' % node.limit)
+ if node.offset:
+ s.append('OFFSET %s' % node.offset)
+ restrictions = []
+ if node.where is not None:
+ nr = node.where.accept(self)
+ if nr is not None:
+ restrictions.append(nr)
+ if restrictions:
+ s.append('WHERE %s' % ','.join(restrictions))
+
+ if node.having:
+ s.append('HAVING %s' % ', '.join(term.accept(self)
+ for term in node.having))
+ subqueries = []
+ for subquery in node.with_:
+ subqueries.append('%s BEING (%s)' % (','.join(ca.name for ca in subquery.aliases),
+ self.visit_union(subquery.query)))
+ if subqueries:
+ s.append('WITH %s' % (','.join(subqueries)))
+ return ' '.join(s)
+
+ def visit_and(self, node):
+ res = self._accept_children(node)
+ if res:
+ return ', '.join(res)
+ return
+
+ def visit_or(self, node):
+ res = self._accept_children(node)
+ if len(res) > 1:
+ return ' OR '.join('(%s)' % rql for rql in res)
+ elif res:
+ return res[0]
+ return
+
+ def visit_not(self, node):
+ rql = node.children[0].accept(self)
+ if rql:
+ return 'NOT (%s)' % rql
+ return
+
+ def visit_exists(self, node):
+ return 'EXISTS(%s)' % node.children[0].accept(self)
+
+ def visit_relation(self, node):
+ try:
+ if isinstance(node.children[0], Constant):
+ # simplified rqlst, reintroduce eid relation
+ restr, lhs = self.process_eid_const(node.children[0])
+ else:
+ lhs = node.children[0].accept(self)
+ restr = None
+ except UnknownEid:
+            # we can safely skip a negated relation on an unsupported eid
+ if node.neged(strict=True):
+ return
+ # XXX what about optional relation or outer NOT EXISTS()
+ raise
+ if node.optional in ('left', 'both'):
+ lhs += '?'
+ if node.r_type == 'eid' or not self.source.schema.rschema(node.r_type).is_final():
+ self.need_translation = True
+ self.current_operator = node.operator()
+ if isinstance(node.children[0], Constant):
+ self.current_etypes = (node.children[0].uidtype,)
+ else:
+ self.current_etypes = node.children[0].variable.stinfo['possibletypes']
+ try:
+ rhs = node.children[1].accept(self)
+ except UnknownEid:
+            # we can safely skip a negated relation on an unsupported eid
+ if node.neged(strict=True):
+ return
+ # XXX what about optional relation or outer NOT EXISTS()
+ raise
+ except ReplaceByInOperator, ex:
+ rhs = 'IN (%s)' % ','.join(str(eid) for eid in ex.eids)
+ self.need_translation = False
+ self.current_operator = None
+ if node.optional in ('right', 'both'):
+ rhs += '?'
+ if restr is not None:
+ return '%s %s %s, %s' % (lhs, node.r_type, rhs, restr)
+ return '%s %s %s' % (lhs, node.r_type, rhs)
+
+ def visit_comparison(self, node):
+ if node.operator in ('=', 'IS'):
+ return node.children[0].accept(self)
+ return '%s %s' % (node.operator.encode(),
+ node.children[0].accept(self))
+
+ def visit_mathexpression(self, node):
+ return '(%s %s %s)' % (node.children[0].accept(self),
+ node.operator.encode(),
+ node.children[1].accept(self))
+
+ def visit_function(self, node):
+ #if node.name == 'IN':
+ res = []
+ for child in node.children:
+ try:
+ rql = child.accept(self)
+ except UnknownEid, ex:
+ continue
+ res.append(rql)
+ if not res:
+ raise ex
+ return '%s(%s)' % (node.name, ', '.join(res))
+
+ def visit_constant(self, node):
+ if self.need_translation or node.uidtype:
+ if node.type == 'Int':
+ return str(self.eid2extid(node.value))
+ if node.type == 'Substitute':
+ key = node.value
+ # ensure we have not yet translated the value...
+ if not key in self._const_var:
+ self.kwargs[key] = self.eid2extid(self.kwargs[key])
+ self.cachekey.append(key)
+ self._const_var[key] = None
+ return node.as_string()
+
+ def visit_variableref(self, node):
+ """get the sql name for a variable reference"""
+ return node.name
+
+ def visit_sortterm(self, node):
+ if node.asc:
+ return node.term.accept(self)
+ return '%s DESC' % node.term.accept(self)
+
+ def process_eid_const(self, const):
+ value = const.eval(self.kwargs)
+ try:
+ return None, self._const_var[value]
+        except KeyError:
+ var = self._varmaker.next()
+ self.need_translation = True
+ restr = '%s eid %s' % (var, self.visit_constant(const))
+ self.need_translation = False
+ self._const_var[value] = var
+ return restr, var
+
+ def eid2extid(self, eid):
+ try:
+ return self.source.eid2extid(eid, self._session)
+ except UnknownEid:
+ operator = self.current_operator
+ if operator is not None and operator != '=':
+ # deal with query like X eid > 12
+ #
+ # The problem is
+ # that eid order in the external source may differ from the
+ # local source
+ #
+ # So search for all eids from this
+ # source matching the condition locally and then to replace the
+ # > 12 branch by IN (eids) (XXX we may have to insert a huge
+ # number of eids...)
+ # planner so that
+ sql = "SELECT extid FROM entities WHERE source='%s' AND type IN (%s) AND eid%s%s"
+ etypes = ','.join("'%s'" % etype for etype in self.current_etypes)
+ cu = self._session.system_sql(sql % (self.source.uri, etypes,
+ operator, eid))
+ # XXX buggy cu.rowcount which may be zero while there are some
+ # results
+ rows = cu.fetchall()
+ if rows:
+ raise ReplaceByInOperator((r[0] for r in rows))
+ raise
+