rqlrewrite.py
changeset 11057 0b59724cb3f2
parent 11052 058bb3dc685f
child 11058 23eb30449fe5
equal deleted inserted replaced
11052:058bb3dc685f 11057:0b59724cb3f2
     1 # copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
    18 """RQL rewriting utilities : insert rql expression snippets into rql syntax
       
    19 tree.
       
    20 
       
    21 This is used for instance for read security checking in the repository.
       
    22 """
       
    23 __docformat__ = "restructuredtext en"
       
    24 
       
    25 from six import text_type, string_types
       
    26 
       
    27 from rql import nodes as n, stmts, TypeResolverException
       
    28 from rql.utils import common_parent
       
    29 
       
    30 from yams import BadSchemaDefinition
       
    31 
       
    32 from logilab.common import tempattr
       
    33 from logilab.common.graph import has_path
       
    34 
       
    35 from cubicweb import Unauthorized
       
    36 from cubicweb.schema import RRQLExpression
       
    37 
       
    38 def cleanup_solutions(rqlst, solutions):
       
    39     for sol in solutions:
       
    40         for vname in list(sol):
       
    41             if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
       
    42                 del sol[vname]
       
    43 
       
    44 
       
def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    """Make the type restrictions of the syntax tree match its solutions.

    For each variable defined in `newroot` which has non-final possible types
    in `solutions`, an existing type restriction relation is narrowed to those
    types, or a new one is added when none is found. The set of possible types
    is finally stored in the variable's ``stinfo['possibletypes']``.

    :param schema: object whose `eschema(etype)` gives an entity schema
      (only its `final` attribute is read here)
    :param rqlst: the statement to process when `newroot` is None, else a node
      whose `.stmt` attribute is the statement
    :param newroot: node holding the variables to restrict; defaults to `rqlst`
    :param solutions: list of {variable name: entity type} dicts; must be None
      when `newroot` is None (`rqlst.solutions` is used), and given otherwise
    """
    if newroot is None:
        assert solutions is None
        # the marker attribute set below makes a second call on the same
        # statement a no-op
        if hasattr(rqlst, '_types_restr_added'):
            return
        solutions = rqlst.solutions
        newroot = rqlst
        rqlst._types_restr_added = True
    else:
        assert solutions is not None
        rqlst = rqlst.stmt
    eschema = schema.eschema
    # {variable name: set of non-final entity types found in the solutions}
    allpossibletypes = {}
    for solution in solutions:
        for varname, etype in solution.items():
            # XXX not considering aliases by design, right ?
            if varname not in newroot.defined_vars or eschema(etype).final:
                continue
            allpossibletypes.setdefault(varname, set()).add(etype)
    # XXX could be factorized with add_etypes_restriction from rql 0.31
    for varname in sorted(allpossibletypes):
        var = newroot.defined_vars[varname]
        stinfo = var.stinfo
        if stinfo.get('uidrel') is not None:
            continue # eid specified, no need for additional type specification
        try:
            typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
        except KeyError:
            # not a variable of the statement itself: expected to be an alias
            assert varname in rqlst.aliases
            continue
        if newroot is rqlst and typerel is not None:
            mytyperel = typerel
        else:
            # look for a type restriction among the relations referencing the
            # variable; the `else` clause runs when none is found
            for vref in var.references():
                rel = vref.relation()
                if rel and rel.is_types_restriction():
                    mytyperel = rel
                    break
            else:
                mytyperel = None
        possibletypes = allpossibletypes[varname]
        if mytyperel is not None:
            if mytyperel.r_type == 'is_instance_of':
                # turn is_instance_of relation into a is relation since we've
                # all possible solutions and don't want to bother with
                # potential is_instance_of incompatibility
                mytyperel.r_type = 'is'
                if len(possibletypes) > 1:
                    # several types: IN(...) function of sorted type constants
                    node = n.Function('IN')
                    for etype in sorted(possibletypes):
                        node.append(n.Constant(etype, 'etype'))
                else:
                    etype = next(iter(possibletypes))
                    node = n.Constant(etype, 'etype')
                # swap the first child of the relation's second child for the
                # freshly built node
                comp = mytyperel.children[1]
                comp.replace(comp.children[0], node)
            else:
                # variable has already some strict types restriction. new
                # possible types can only be a subset of existing ones, so only
                # remove no more possible types
                for cst in mytyperel.get_nodes(n.Constant):
                    if not cst.value in possibletypes:
                        cst.parent.remove(cst)
        else:
            # we have to add types restriction
            if stinfo.get('scope') is not None:
                rel = var.scope.add_type_restriction(var, possibletypes)
            else:
                # tree is not annotated yet, no scope set so add the restriction
                # to the root
                rel = newroot.add_type_restriction(var, possibletypes)
            stinfo['typerel'] = rel
        stinfo['possibletypes'] = possibletypes
       
   118 
       
   119 
       
   120 def remove_solutions(origsolutions, solutions, defined):
       
   121     """when a rqlst has been generated from another by introducing security
       
   122     assertions, this method returns solutions which are contained in orig
       
   123     solutions
       
   124     """
       
   125     newsolutions = []
       
   126     for origsol in origsolutions:
       
   127         for newsol in solutions[:]:
       
   128             for var, etype in origsol.items():
       
   129                 try:
       
   130                     if newsol[var] != etype:
       
   131                         try:
       
   132                             defined[var].stinfo['possibletypes'].remove(newsol[var])
       
   133                         except KeyError:
       
   134                             pass
       
   135                         break
       
   136                 except KeyError:
       
   137                     # variable has been rewritten
       
   138                     continue
       
   139             else:
       
   140                 newsolutions.append(newsol)
       
   141                 solutions.remove(newsol)
       
   142     return newsolutions
       
   143 
       
   144 
       
   145 def _add_noinvariant(noinvariant, restricted, select, nbtrees):
       
   146     # a variable can actually be invariant if it has not been restricted for
       
   147     # security reason or if security assertion hasn't modified the possible
       
   148     # solutions for the query
       
   149     for vname in restricted:
       
   150         try:
       
   151             var = select.defined_vars[vname]
       
   152         except KeyError:
       
   153             # this is an alias
       
   154             continue
       
   155         if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1:
       
   156             noinvariant.add(var)
       
   157 
       
   158 
       
   159 def _expand_selection(terms, selected, aliases, select, newselect):
       
   160     for term in terms:
       
   161         for vref in term.iget_nodes(n.VariableRef):
       
   162             if not vref.name in selected:
       
   163                 select.append_selected(vref)
       
   164                 colalias = newselect.get_variable(vref.name, len(aliases))
       
   165                 aliases.append(n.VariableRef(colalias))
       
   166                 selected.add(vref.name)
       
   167 
       
   168 def _has_multiple_cardinality(etypes, rdef, ttypes_func, cardindex):
       
   169     """return True if relation definitions from entity types (`etypes`) to
       
   170     target types returned by the `ttypes_func` function all have single (1 or ?)
       
   171     cardinality.
       
   172     """
       
   173     for etype in etypes:
       
   174         for ttype in ttypes_func(etype):
       
   175             if rdef(etype, ttype).cardinality[cardindex] in '+*':
       
   176                 return True
       
   177     return False
       
   178 
       
   179 def _compatible_relation(relations, stmt, sniprel):
       
   180     """Search among given rql relation nodes if there is one 'compatible' with the
       
   181     snippet relation, and return it if any, else None.
       
   182 
       
   183     A relation is compatible if it:
       
   184     * belongs to the currently processed statement,
       
   185     * isn't negged (i.e. direct parent is a NOT node)
       
   186     * isn't optional (outer join) or similarly as the snippet relation
       
   187     """
       
   188     for rel in relations:
       
   189         # don't share if relation's scope is not the current statement
       
   190         if rel.scope is not stmt:
       
   191             continue
       
   192         # don't share neged relation
       
   193         if rel.neged(strict=True):
       
   194             continue
       
   195         # don't share optional relation, unless the snippet relation is
       
   196         # similarly optional
       
   197         if rel.optional and rel.optional != sniprel.optional:
       
   198             continue
       
   199         return rel
       
   200     return None
       
   201 
       
   202 
       
   203 def iter_relations(stinfo):
       
   204     # this is a function so that test may return relation in a predictable order
       
   205     return stinfo['relations'] - stinfo['rhsrelations']
       
   206 
       
   207 
       
   208 class Unsupported(Exception):
       
   209     """raised when an rql expression can't be inserted in some rql query
       
   210     because it create an unresolvable query (eg no solutions found)
       
   211     """
       
   212 
       
   213 class VariableFromSubQuery(Exception):
       
   214     """flow control exception to indicate that a variable is coming from a
       
   215     subquery, and let parent act accordingly
       
   216     """
       
   217     def __init__(self, variable):
       
   218         self.variable = variable
       
   219 
       
   220 
       
   221 class RQLRewriter(object):
       
   222     """Insert some rql snippets into another rql syntax tree, for security /
       
   223     relation vocabulary. This implies that it should only restrict results of
       
   224     the original query, not generate new ones. Hence, inserted snippets are
       
   225     inserted under an EXISTS node.
       
   226 
       
   227     This class *isn't thread safe*.
       
   228     """
       
   229 
       
   230     def __init__(self, session):
       
   231         self.session = session
       
   232         vreg = session.vreg
       
   233         self.schema = vreg.schema
       
   234         self.annotate = vreg.rqlhelper.annotate
       
   235         self._compute_solutions = vreg.solutions
       
   236 
       
   237     def compute_solutions(self):
       
   238         self.annotate(self.select)
       
   239         try:
       
   240             self._compute_solutions(self.session, self.select, self.kwargs)
       
   241         except TypeResolverException:
       
   242             raise Unsupported(str(self.select))
       
   243         if len(self.select.solutions) < len(self.solutions):
       
   244             raise Unsupported()
       
   245 
       
    def insert_local_checks(self, select, kwargs,
                            localchecks, restricted, noinvariant):
        """Insert local security checks into the tree `select` belongs to.

        One rewritten copy of `select` is appended to a union for each non
        empty key of `localchecks`. When there are several keys and `select`
        uses sort / group / having terms, an aggregate, DISTINCT, LIMIT or
        OFFSET, `select` is first replaced in its parent by a new select node
        which takes over those clauses and reads from that union as a
        subquery.

        select: the rql syntax tree Select node
        kwargs: query arguments

        localchecks: {(('Var name', (rqlexpr1, rqlexpr2)),
                       ('Var name1', (rqlexpr1, rqlexpr23))): [solution]}

              (see querier._check_permissions docstring for more information)

              The empty tuple key, when present, holds the solutions for
              which the original select is kept without rewriting.

        restricted: set of variable names to which an rql expression has to be
              applied

        noinvariant: set of variables that can't be considered as invariant
              for security reasons (will be filled by this method)
        """
        nbtrees = len(localchecks)
        # `myunion` is the union receiving the rewritten copies: the parent of
        # `select`, unless a subquery is introduced below
        myunion = union = select.parent
        # transform in subquery when len(localchecks)>1 and groups
        if nbtrees > 1 and (select.orderby or select.groupby or
                            select.having or select.has_aggregat or
                            select.distinct or
                            select.limit or select.offset):
            newselect = stmts.Select()
            # only select variables in subqueries
            origselection = select.selection
            select.select_only_variables()
            select.has_aggregat = False
            # create subquery first so correct node are used on copy
            # (eg ColumnAlias instead of Variable)
            aliases = [n.VariableRef(newselect.get_variable(vref.name, i))
                       for i, vref in enumerate(select.selection)]
            selected = set(vref.name for vref in aliases)
            # now copy original selection and groups
            for term in origselection:
                newselect.append_selected(term.copy(newselect))
            if select.orderby:
                sortterms = []
                for sortterm in select.orderby:
                    sortterms.append(sortterm.copy(newselect))
                    for fnode in sortterm.get_nodes(n.Function):
                        if fnode.name == 'FTIRANK':
                            # we've to fetch the has_text relation as well
                            var = fnode.children[0].variable
                            rel = next(iter(var.stinfo['ftirels']))
                            assert not rel.ored(), 'unsupported'
                            newselect.add_restriction(rel.copy(newselect))
                            # remove relation from the orig select and
                            # cleanup variable stinfo
                            rel.parent.remove(rel)
                            var.stinfo['ftirels'].remove(rel)
                            var.stinfo['relations'].remove(rel)
                            # XXX not properly re-annotated after security insertion?
                            newvar = newselect.get_variable(var.name)
                            newvar.stinfo.setdefault('ftirels', set()).add(rel)
                            newvar.stinfo.setdefault('relations', set()).add(rel)
                newselect.set_orderby(sortterms)
                # variables used by the moved clause must be selected by the
                # subquery so that the outer select can refer to them
                _expand_selection(select.orderby, selected, aliases, select, newselect)
                select.orderby = () # XXX dereference?
            if select.groupby:
                newselect.set_groupby([g.copy(newselect) for g in select.groupby])
                _expand_selection(select.groupby, selected, aliases, select, newselect)
                select.groupby = () # XXX dereference?
            if select.having:
                newselect.set_having([g.copy(newselect) for g in select.having])
                _expand_selection(select.having, selected, aliases, select, newselect)
                select.having = () # XXX dereference?
            # limit / offset are moved from the original select to the new one
            if select.limit:
                newselect.limit = select.limit
                select.limit = None
            if select.offset:
                newselect.offset = select.offset
                select.offset = 0
            myunion = stmts.Union()
            newselect.set_with([n.SubQuery(aliases, myunion)], check=False)
            newselect.distinct = select.distinct
            # the new select gets copies of the original solutions, limited to
            # the variables / aliases it actually knows
            solutions = [sol.copy() for sol in select.solutions]
            cleanup_solutions(newselect, solutions)
            newselect.set_possible_types(solutions)
            # if some solutions doesn't need rewriting, insert original
            # select as first union subquery
            if () in localchecks:
                myunion.append(select)
            # we're done, replace original select by the new select with
            # subqueries (more added in the loop below)
            union.replace(select, newselect)
        elif not () in localchecks:
            # every solution needs rewriting: the original select is dropped,
            # only its rewritten copies (added below) remain in the union
            union.remove(select)
        for lcheckdef, lchecksolutions in localchecks.items():
            if not lcheckdef:
                # empty key: solutions needing no rewrite, handled after the loop
                continue
            myrqlst = select.copy(solutions=lchecksolutions)
            myunion.append(myrqlst)
            # in-place rewrite + annotation / simplification
            lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
            self.rewrite(myrqlst, lcheckdef, kwargs)
            _add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
        if () in localchecks:
            select.set_possible_types(localchecks[()])
            add_types_restriction(self.schema, select)
            _add_noinvariant(noinvariant, restricted, select, nbtrees)
        self.annotate(union)
       
   349 
       
   350     def rewrite(self, select, snippets, kwargs, existingvars=None):
       
   351         """
       
   352         snippets: (varmap, list of rql expression)
       
   353                   with varmap a *dict* {select var: snippet var}
       
   354         """
       
   355         self.select = select
       
   356         # remove_solutions used below require a copy
       
   357         self.solutions = solutions = select.solutions[:]
       
   358         self.kwargs = kwargs
       
   359         self.u_varname = None
       
   360         self.removing_ambiguity = False
       
   361         self.exists_snippet = {}
       
   362         self.pending_keys = []
       
   363         self.existingvars = existingvars
       
   364         # we have to annotate the rqlst before inserting snippets, even though
       
   365         # we'll have to redo it later
       
   366         self.annotate(select)
       
   367         self.insert_snippets(snippets)
       
   368         if not self.exists_snippet and self.u_varname:
       
   369             # U has been inserted than cancelled, cleanup
       
   370             select.undefine_variable(select.defined_vars[self.u_varname])
       
   371         # clean solutions according to initial solutions
       
   372         newsolutions = remove_solutions(solutions, select.solutions,
       
   373                                         select.defined_vars)
       
   374         assert len(newsolutions) >= len(solutions), (
       
   375             'rewritten rql %s has lost some solutions, there is probably '
       
   376             'something wrong in your schema permission (for instance using a '
       
   377             'RQLExpression which inserts a relation which doesn\'t exist in '
       
   378             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
       
   379             select, solutions, newsolutions))
       
   380         if len(newsolutions) > len(solutions):
       
   381             newsolutions = self.remove_ambiguities(snippets, newsolutions)
       
   382             assert newsolutions
       
   383         select.solutions = newsolutions
       
   384         add_types_restriction(self.schema, select)
       
   385 
       
   386     def insert_snippets(self, snippets, varexistsmap=None):
       
   387         self.rewritten = {}
       
   388         for varmap, rqlexprs in snippets:
       
   389             if isinstance(varmap, dict):
       
   390                 varmap = tuple(sorted(varmap.items()))
       
   391             else:
       
   392                 assert isinstance(varmap, tuple), varmap
       
   393             if varexistsmap is not None and not varmap in varexistsmap:
       
   394                 continue
       
   395             self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
       
   396 
       
   397     def init_from_varmap(self, varmap, varexistsmap=None):
       
   398         self.varmap = varmap
       
   399         self.revvarmap = {}
       
   400         self.varinfos = []
       
   401         for i, (selectvar, snippetvar) in enumerate(varmap):
       
   402             assert snippetvar in 'SOX'
       
   403             self.revvarmap[snippetvar] = (selectvar, i)
       
   404             vi = {}
       
   405             self.varinfos.append(vi)
       
   406             try:
       
   407                 vi['const'] = int(selectvar)
       
   408                 vi['rhs_rels'] = vi['lhs_rels'] = {}
       
   409             except ValueError:
       
   410                 try:
       
   411                     vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
       
   412                 except KeyError:
       
   413                     vi['stinfo'] = sti = self._subquery_variable(selectvar)
       
   414                 if varexistsmap is None:
       
   415                     # build an index for quick access to relations
       
   416                     vi['rhs_rels'] = {}
       
   417                     for rel in sti.get('rhsrelations', []):
       
   418                         vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
       
   419                     vi['lhs_rels'] = {}
       
   420                     for rel in sti.get('relations', []):
       
   421                         if not rel in sti.get('rhsrelations', []):
       
   422                             vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
       
   423                 else:
       
   424                     vi['rhs_rels'] = vi['lhs_rels'] = {}
       
   425 
       
   426     def _subquery_variable(self, selectvar):
       
   427         raise VariableFromSubQuery(selectvar)
       
   428 
       
   429     def insert_varmap_snippets(self, varmap, rqlexprs, varexistsmap):
       
   430         try:
       
   431             self.init_from_varmap(varmap, varexistsmap)
       
   432         except VariableFromSubQuery as ex:
       
   433             # variable may have been moved to a newly inserted subquery
       
   434             # we should insert snippet in that subquery
       
   435             subquery = self.select.aliases[ex.variable].query
       
   436             assert len(subquery.children) == 1, subquery
       
   437             subselect = subquery.children[0]
       
   438             RQLRewriter(self.session).rewrite(subselect, [(varmap, rqlexprs)],
       
   439                                               self.kwargs)
       
   440             return
       
   441         self._insert_scope = None
       
   442         previous = None
       
   443         inserted = False
       
   444         for rqlexpr in rqlexprs:
       
   445             self.current_expr = rqlexpr
       
   446             if varexistsmap is None:
       
   447                 try:
       
   448                     new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, previous)
       
   449                 except Unsupported:
       
   450                     continue
       
   451                 inserted = True
       
   452                 if new is not None and self._insert_scope is None:
       
   453                     self.exists_snippet[rqlexpr] = new
       
   454                 previous = previous or new
       
   455             else:
       
   456                 # called to reintroduce snippet due to ambiguity creation,
       
   457                 # so skip snippets which are not introducing this ambiguity
       
   458                 exists = varexistsmap[varmap]
       
   459                 if self.exists_snippet.get(rqlexpr) is exists:
       
   460                     self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
       
   461         if varexistsmap is None and not inserted:
       
   462             # no rql expression found matching rql solutions. User has no access right
       
   463             raise Unauthorized() # XXX may also be because of bad constraints in schema definition
       
   464 
       
   465     def insert_snippet(self, varmap, snippetrqlst, previous=None):
       
   466         new = snippetrqlst.where.accept(self)
       
   467         existing = self.existingvars
       
   468         self.existingvars = None
       
   469         try:
       
   470             return self._insert_snippet(varmap, previous, new)
       
   471         finally:
       
   472             self.existingvars = existing
       
   473 
       
   474     def _inserted_root(self, new):
       
   475         if not isinstance(new, (n.Exists, n.Not)):
       
   476             new = n.Exists(new)
       
   477         return new
       
   478 
       
    def _insert_snippet(self, varmap, previous, new):
        """insert `new` snippet into the syntax tree, which have been rewritten
        using `varmap`. In cases where an action is protected by several rql
        expresssion, `previous` will be the first rql expression which has been
        inserted, and so should be ORed with the following expressions.

        Return the inserted node, or None when `new` is None (in which case
        only pending snippets are inserted).

        Unless `self.removing_ambiguity` is set, solutions are recomputed once
        the snippet is inserted; if this raises `Unsupported`, the insertion
        is undone and the exception propagates.
        """
        if new is not None:
            if self._insert_scope is None:
                # no forced scope: use the common parent of the scopes of all
                # variables of the varmap (the select node stands in for
                # entries without 'stinfo' or without a 'scope')
                insert_scope = None
                for vi in self.varinfos:
                    scope = vi.get('stinfo', {}).get('scope', self.select)
                    if insert_scope is None:
                        insert_scope = scope
                    else:
                        insert_scope = common_parent(scope, insert_scope)
            else:
                insert_scope = self._insert_scope
            # a variable with a truthy 'optrelations' entry: the snippet goes
            # through `snippet_subquery` instead of being added as a plain
            # restriction
            if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
                                                  for vi in self.varinfos):
                assert previous is None
                self._insert_scope, new = self.snippet_subquery(varmap, new)
                self.insert_pending()
                #self._insert_scope = None
                return new
            new = self._inserted_root(new)
            if previous is None:
                insert_scope.add_restriction(new)
            else:
                # OR the new snippet with the previously inserted one, in the
                # place `previous` had in the tree
                grandpa = previous.parent
                or_ = n.Or(previous, new)
                grandpa.replace(previous, or_)
            if not self.removing_ambiguity:
                try:
                    self.compute_solutions()
                except Unsupported:
                    # some solutions have been lost, can't apply this rql expr
                    if previous is None:
                        self.current_statement().remove_node(new, undefine=True)
                    else:
                        # put `previous` back where the Or node was inserted
                        grandpa.replace(or_, previous)
                        self._cleanup_inserted(new)
                    raise
                else:
                    # pending snippets are inserted with `_insert_scope`
                    # temporarily set to the node which has just been inserted
                    with tempattr(self, '_insert_scope', new):
                        self.insert_pending()
            return new
        self.insert_pending()
       
   526 
       
   527     def insert_pending(self):
       
   528         """pending_keys hold variable referenced by U has_<action>_permission X
       
   529         relation.
       
   530 
       
   531         Once the snippet introducing this has been inserted and solutions
       
   532         recomputed, we have to insert snippet defined for <action> of entity
       
   533         types taken by X
       
   534         """
       
   535         stmt = self.current_statement()
       
   536         while self.pending_keys:
       
   537             key, action = self.pending_keys.pop()
       
   538             try:
       
   539                 varname = self.rewritten[key]
       
   540             except KeyError:
       
   541                 try:
       
   542                     varname = self.revvarmap[key[-1]][0]
       
   543                 except KeyError:
       
   544                     # variable isn't used anywhere else, we can't insert security
       
   545                     raise Unauthorized()
       
   546             ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
       
   547             if len(ptypes) > 1:
       
   548                 # XXX dunno how to handle this
       
   549                 self.session.error(
       
   550                     'cant check security of %s, ambigous type for %s in %s',
       
   551                     stmt, varname, key[0]) # key[0] == the rql expression
       
   552                 raise Unauthorized()
       
   553             etype = next(iter(ptypes))
       
   554             eschema = self.schema.eschema(etype)
       
   555             if not eschema.has_perm(self.session, action):
       
   556                 rqlexprs = eschema.get_rqlexprs(action)
       
   557                 if not rqlexprs:
       
   558                     raise Unauthorized()
       
   559                 self.insert_snippets([({varname: 'X'}, rqlexprs)])
       
   560 
       
   561     def snippet_subquery(self, varmap, transformedsnippet):
       
   562         """introduce the given snippet in a subquery"""
       
   563         subselect = stmts.Select()
       
   564         snippetrqlst = n.Exists(transformedsnippet.copy(subselect))
       
   565         get_rschema = self.schema.rschema
       
   566         aliases = []
       
   567         done = set()
       
   568         for i, (selectvar, _) in enumerate(varmap):
       
   569             need_null_test = False
       
   570             subselectvar = subselect.get_variable(selectvar)
       
   571             subselect.append_selected(n.VariableRef(subselectvar))
       
   572             aliases.append(selectvar)
       
   573             todo = [(selectvar, self.varinfos[i]['stinfo'])]
       
   574             while todo:
       
   575                 varname, stinfo = todo.pop()
       
   576                 done.add(varname)
       
   577                 for rel in iter_relations(stinfo):
       
   578                     if rel in done:
       
   579                         continue
       
   580                     done.add(rel)
       
   581                     rschema = get_rschema(rel.r_type)
       
   582                     if rschema.final or rschema.inlined:
       
   583                         rel.children[0].name = varname # XXX explain why
       
   584                         subselect.add_restriction(rel.copy(subselect))
       
   585                         for vref in rel.children[1].iget_nodes(n.VariableRef):
       
   586                             if isinstance(vref.variable, n.ColumnAlias):
       
   587                                 # XXX could probably be handled by generating the
       
   588                                 # subquery into the detected subquery
       
   589                                 raise BadSchemaDefinition(
       
   590                                     "cant insert security because of usage two inlined "
       
   591                                     "relations in this query. You should probably at "
       
   592                                     "least uninline %s" % rel.r_type)
       
   593                             subselect.append_selected(vref.copy(subselect))
       
   594                             aliases.append(vref.name)
       
   595                         self.select.remove_node(rel)
       
   596                         # when some inlined relation has to be copied in the
       
   597                         # subquery and that relation is optional, we need to
       
   598                         # test that either value is NULL or that the snippet
       
   599                         # condition is satisfied
       
   600                         if varname == selectvar and rel.optional and rschema.inlined:
       
   601                             need_null_test = True
       
   602                         # also, if some attributes or inlined relation of the
       
   603                         # object variable are accessed, we need to get all those
       
   604                         # from the subquery as well
       
   605                         if vref.name not in done and rschema.inlined:
       
   606                             # we can use vref here define in above for loop
       
   607                             ostinfo = vref.variable.stinfo
       
   608                             for orel in iter_relations(ostinfo):
       
   609                                 orschema = get_rschema(orel.r_type)
       
   610                                 if orschema.final or orschema.inlined:
       
   611                                     todo.append( (vref.name, ostinfo) )
       
   612                                     break
       
   613             if need_null_test:
       
   614                 snippetrqlst = n.Or(
       
   615                     n.make_relation(subselect.get_variable(selectvar), 'is',
       
   616                                     (None, None), n.Constant,
       
   617                                     operator='='),
       
   618                     snippetrqlst)
       
   619         subselect.add_restriction(snippetrqlst)
       
   620         if self.u_varname:
       
   621             # generate an identifier for the substitution
       
   622             argname = subselect.allocate_varname()
       
   623             while argname in self.kwargs:
       
   624                 argname = subselect.allocate_varname()
       
   625             subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
       
   626                                                'eid', text_type(argname), 'Substitute')
       
   627             self.kwargs[argname] = self.session.user.eid
       
   628         add_types_restriction(self.schema, subselect, subselect,
       
   629                               solutions=self.solutions)
       
   630         myunion = stmts.Union()
       
   631         myunion.append(subselect)
       
   632         aliases = [n.VariableRef(self.select.get_variable(name, i))
       
   633                    for i, name in enumerate(aliases)]
       
   634         self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
       
   635         self._cleanup_inserted(transformedsnippet)
       
   636         try:
       
   637             self.compute_solutions()
       
   638         except Unsupported:
       
   639             # some solutions have been lost, can't apply this rql expr
       
   640             self.select.remove_subquery(self.select.with_[-1])
       
   641             raise
       
   642         return subselect, snippetrqlst
       
   643 
       
   644     def remove_ambiguities(self, snippets, newsolutions):
       
   645         # the snippet has introduced some ambiguities, we have to resolve them
       
   646         # "manually"
       
   647         variantes = self.build_variantes(newsolutions)
       
   648         # insert "is" where necessary
       
   649         varexistsmap = {}
       
   650         self.removing_ambiguity = True
       
   651         for (erqlexpr, varmap, oldvarname), etype in variantes[0].items():
       
   652             varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
       
   653             var = self.select.defined_vars[varname]
       
   654             exists = var.references()[0].scope
       
   655             exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   656             varexistsmap[varmap] = exists
       
   657         # insert ORED exists where necessary
       
   658         for variante in variantes[1:]:
       
   659             self.insert_snippets(snippets, varexistsmap)
       
   660             for key, etype in variante.items():
       
   661                 varname = self.rewritten[key]
       
   662                 try:
       
   663                     var = self.select.defined_vars[varname]
       
   664                 except KeyError:
       
   665                     # not a newly inserted variable
       
   666                     continue
       
   667                 exists = var.references()[0].scope
       
   668                 exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   669         # recompute solutions
       
   670         self.compute_solutions()
       
   671         # clean solutions according to initial solutions
       
   672         return remove_solutions(self.solutions, self.select.solutions,
       
   673                                 self.select.defined_vars)
       
   674 
       
   675     def build_variantes(self, newsolutions):
       
   676         variantes = set()
       
   677         for sol in newsolutions:
       
   678             variante = []
       
   679             for key, newvar in self.rewritten.items():
       
   680                 variante.append( (key, sol[newvar]) )
       
   681             variantes.add(tuple(variante))
       
   682         # rebuild variantes as dict
       
   683         variantes = [dict(variante) for variante in variantes]
       
   684         # remove variable which have always the same type
       
   685         for key in self.rewritten:
       
   686             it = iter(variantes)
       
   687             etype = next(it)[key]
       
   688             for variante in it:
       
   689                 if variante[key] != etype:
       
   690                     break
       
   691             else:
       
   692                 for variante in variantes:
       
   693                     del variante[key]
       
   694         return variantes
       
   695 
       
   696     def _cleanup_inserted(self, node):
       
   697         # cleanup inserted variable references
       
   698         removed = set()
       
   699         for vref in node.iget_nodes(n.VariableRef):
       
   700             vref.unregister_reference()
       
   701             if not vref.variable.stinfo['references']:
       
   702                 # no more references, undefine the variable
       
   703                 del self.select.defined_vars[vref.name]
       
   704                 removed.add(vref.name)
       
   705         for key, newvar in list(self.rewritten.items()):
       
   706             if newvar in removed:
       
   707                 del self.rewritten[key]
       
   708 
       
   709 
       
   710     def _may_be_shared_with(self, sniprel, target):
       
   711         """if the snippet relation can be skipped to use a relation from the
       
   712         original query, return that relation node
       
   713         """
       
   714         if sniprel.neged(strict=True):
       
   715             return None # no way
       
   716         rschema = self.schema.rschema(sniprel.r_type)
       
   717         stmt = self.current_statement()
       
   718         for vi in self.varinfos:
       
   719             try:
       
   720                 if target == 'object':
       
   721                     orels = vi['lhs_rels'][sniprel.r_type]
       
   722                     cardindex = 0
       
   723                     ttypes_func = rschema.objects
       
   724                     rdef = rschema.rdef
       
   725                 else: # target == 'subject':
       
   726                     orels = vi['rhs_rels'][sniprel.r_type]
       
   727                     cardindex = 1
       
   728                     ttypes_func = rschema.subjects
       
   729                     rdef = lambda x, y: rschema.rdef(y, x)
       
   730             except KeyError:
       
   731                 # may be raised by vi['xhs_rels'][sniprel.r_type]
       
   732                 continue
       
   733             # if cardinality isn't in '?1', we can't ignore the snippet relation
       
   734             # and use variable from the original query
       
   735             if _has_multiple_cardinality(vi['stinfo']['possibletypes'], rdef,
       
   736                                          ttypes_func, cardindex):
       
   737                 continue
       
   738             orel = _compatible_relation(orels, stmt, sniprel)
       
   739             if orel is not None:
       
   740                 return orel
       
   741         return None
       
   742 
       
   743     def _use_orig_term(self, snippet_varname, term):
       
   744         key = (self.current_expr, self.varmap, snippet_varname)
       
   745         if key in self.rewritten:
       
   746             stmt = self.current_statement()
       
   747             insertedvar = stmt.defined_vars.pop(self.rewritten[key])
       
   748             for inserted_vref in insertedvar.references():
       
   749                 inserted_vref.parent.replace(inserted_vref, term.copy(stmt))
       
   750         self.rewritten[key] = term.name
       
   751 
       
   752     def _get_varname_or_term(self, vname):
       
   753         stmt = self.current_statement()
       
   754         if vname == 'U':
       
   755             stmt = self.select
       
   756             if self.u_varname is None:
       
   757                 self.u_varname = stmt.allocate_varname()
       
   758                 # generate an identifier for the substitution
       
   759                 argname = stmt.allocate_varname()
       
   760                 while argname in self.kwargs:
       
   761                     argname = stmt.allocate_varname()
       
   762                 # insert "U eid %(u)s"
       
   763                 stmt.add_constant_restriction(
       
   764                     stmt.get_variable(self.u_varname),
       
   765                     'eid', text_type(argname), 'Substitute')
       
   766                 self.kwargs[argname] = self.session.user.eid
       
   767             return self.u_varname
       
   768         key = (self.current_expr, self.varmap, vname)
       
   769         try:
       
   770             return self.rewritten[key]
       
   771         except KeyError:
       
   772             self.rewritten[key] = newvname = stmt.allocate_varname()
       
   773             return newvname
       
   774 
       
   775     # visitor methods ##########################################################
       
   776 
       
   777     def _visit_binary(self, node, cls):
       
   778         newnode = cls()
       
   779         for c in node.children:
       
   780             new = c.accept(self)
       
   781             if new is None:
       
   782                 continue
       
   783             newnode.append(new)
       
   784         if len(newnode.children) == 0:
       
   785             return None
       
   786         if len(newnode.children) == 1:
       
   787             return newnode.children[0]
       
   788         return newnode
       
   789 
       
   790     def _visit_unary(self, node, cls):
       
   791         newc = node.children[0].accept(self)
       
   792         if newc is None:
       
   793             return None
       
   794         newnode = cls()
       
   795         newnode.append(newc)
       
   796         return newnode
       
   797 
       
   798     def visit_and(self, node):
       
   799         return self._visit_binary(node, n.And)
       
   800 
       
   801     def visit_or(self, node):
       
   802         return self._visit_binary(node, n.Or)
       
   803 
       
   804     def visit_not(self, node):
       
   805         return self._visit_unary(node, n.Not)
       
   806 
       
   807     def visit_exists(self, node):
       
   808         return self._visit_unary(node, n.Exists)
       
   809 
       
   810     def keep_var(self, varname):
       
   811         if varname in 'SO':
       
   812             return varname in self.existingvars
       
   813         if varname == 'U':
       
   814             return True
       
   815         vargraph = self.current_expr.vargraph
       
   816         for existingvar in self.existingvars:
       
   817             #path = has_path(vargraph, varname, existingvar)
       
   818             if not varname in vargraph or has_path(vargraph, varname, existingvar):
       
   819                 return True
       
   820         # no path from this variable to an existing variable
       
   821         return False
       
   822 
       
   823     def visit_relation(self, node):
       
   824         lhs, rhs = node.get_variable_parts()
       
   825         # remove relations where an unexistant variable and or a variable linked
       
   826         # to an unexistant variable is used.
       
   827         if self.existingvars:
       
   828             if not self.keep_var(lhs.name):
       
   829                 return
       
   830         if node.r_type in ('has_add_permission', 'has_update_permission',
       
   831                            'has_delete_permission', 'has_read_permission'):
       
   832             assert lhs.name == 'U'
       
   833             action = node.r_type.split('_')[1]
       
   834             key = (self.current_expr, self.varmap, rhs.name)
       
   835             self.pending_keys.append( (key, action) )
       
   836             return
       
   837         if isinstance(rhs, n.VariableRef):
       
   838             if self.existingvars and not self.keep_var(rhs.name):
       
   839                 return
       
   840             if lhs.name in self.revvarmap and rhs.name != 'U':
       
   841                 orel = self._may_be_shared_with(node, 'object')
       
   842                 if orel is not None:
       
   843                     self._use_orig_term(rhs.name, orel.children[1].children[0])
       
   844                     return
       
   845             elif rhs.name in self.revvarmap and lhs.name != 'U':
       
   846                 orel = self._may_be_shared_with(node, 'subject')
       
   847                 if orel is not None:
       
   848                     self._use_orig_term(lhs.name, orel.children[0])
       
   849                     return
       
   850         rel = n.Relation(node.r_type, node.optional)
       
   851         for c in node.children:
       
   852             rel.append(c.accept(self))
       
   853         return rel
       
   854 
       
   855     def visit_comparison(self, node):
       
   856         cmp_ = n.Comparison(node.operator)
       
   857         for c in node.children:
       
   858             cmp_.append(c.accept(self))
       
   859         return cmp_
       
   860 
       
   861     def visit_mathexpression(self, node):
       
   862         cmp_ = n.MathExpression(node.operator)
       
   863         for c in node.children:
       
   864             cmp_.append(c.accept(self))
       
   865         return cmp_
       
   866 
       
   867     def visit_function(self, node):
       
   868         """generate filter name for a function"""
       
   869         function_ = n.Function(node.name)
       
   870         for c in node.children:
       
   871             function_.append(c.accept(self))
       
   872         return function_
       
   873 
       
   874     def visit_constant(self, node):
       
   875         """generate filter name for a constant"""
       
   876         return n.Constant(node.value, node.type)
       
   877 
       
   878     def visit_variableref(self, node):
       
   879         """get the sql name for a variable reference"""
       
   880         stmt = self.current_statement()
       
   881         if node.name in self.revvarmap:
       
   882             selectvar, index = self.revvarmap[node.name]
       
   883             vi = self.varinfos[index]
       
   884             if vi.get('const') is not None:
       
   885                 return n.Constant(vi['const'], 'Int')
       
   886             return n.VariableRef(stmt.get_variable(selectvar))
       
   887         vname_or_term = self._get_varname_or_term(node.name)
       
   888         if isinstance(vname_or_term, string_types):
       
   889             return n.VariableRef(stmt.get_variable(vname_or_term))
       
   890         # shared term
       
   891         return vname_or_term.copy(stmt)
       
   892 
       
   893     def current_statement(self):
       
   894         if self._insert_scope is None:
       
   895             return self.select
       
   896         return self._insert_scope.stmt
       
   897 
       
   898 
       
   899 class RQLRelationRewriter(RQLRewriter):
       
   900     """Insert some rql snippets into another rql syntax tree, replacing computed
       
   901     relations by their associated rule.
       
   902 
       
   903     This class *isn't thread safe*.
       
   904     """
       
   905     def __init__(self, session):
       
   906         super(RQLRelationRewriter, self).__init__(session)
       
   907         self.rules = {}
       
   908         for rschema in self.schema.iter_computed_relations():
       
   909             self.rules[rschema.type] = RRQLExpression(rschema.rule)
       
   910 
       
   911     def rewrite(self, union, kwargs=None):
       
   912         self.kwargs = kwargs
       
   913         self.removing_ambiguity = False
       
   914         self.existingvars = None
       
   915         self.pending_keys = None
       
   916         for relation in union.iget_nodes(n.Relation):
       
   917             if relation.r_type in self.rules:
       
   918                 self.select = relation.stmt
       
   919                 self.solutions = solutions = self.select.solutions[:]
       
   920                 self.current_expr = self.rules[relation.r_type]
       
   921                 self._insert_scope = relation.scope
       
   922                 self.rewritten = {}
       
   923                 lhs, rhs = relation.get_variable_parts()
       
   924                 varmap = {lhs.name: 'S', rhs.name: 'O'}
       
   925                 self.init_from_varmap(tuple(sorted(varmap.items())))
       
   926                 self.insert_snippet(varmap, self.current_expr.snippet_rqlst)
       
   927                 self.select.remove_node(relation)
       
   928 
       
   929     def _subquery_variable(self, selectvar):
       
   930         return self.select.aliases[selectvar].stinfo
       
   931 
       
   932     def _inserted_root(self, new):
       
   933         return new