rqlrewrite.py
branch3.5
changeset 3240 8604a15995d1
parent 1977 606923dff11b
child 3254 fe7ec595751c
equal deleted inserted replaced
3239:1ceac4cd4fb7 3240:8604a15995d1
       
     1 """RQL rewriting utilities : insert rql expression snippets into rql syntax
       
     2 tree.
       
     3 
       
     4 This is used for instance for read security checking in the repository.
       
     5 
       
     6 :organization: Logilab
       
     7 :copyright: 2007-2009 LOGILAB S.A. (Paris, FRANCE), license is LGPL v2.
       
     8 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     9 :license: GNU Lesser General Public License, v2.1 - http://www.gnu.org/licenses
       
    10 """
       
    11 __docformat__ = "restructuredtext en"
       
    12 
       
    13 from rql import nodes as n, stmts, TypeResolverException
       
    14 
       
    15 from logilab.common.compat import any
       
    16 
       
    17 from cubicweb import Unauthorized, server, typed_eid
       
    18 from cubicweb.server.ssplanner import add_types_restriction
       
    19 
       
    20 
       
    21 def remove_solutions(origsolutions, solutions, defined):
       
    22     """when a rqlst has been generated from another by introducing security
       
    23     assertions, this method returns solutions which are contained in orig
       
    24     solutions
       
    25     """
       
    26     newsolutions = []
       
    27     for origsol in origsolutions:
       
    28         for newsol in solutions[:]:
       
    29             for var, etype in origsol.items():
       
    30                 try:
       
    31                     if newsol[var] != etype:
       
    32                         try:
       
    33                             defined[var].stinfo['possibletypes'].remove(newsol[var])
       
    34                         except KeyError:
       
    35                             pass
       
    36                         break
       
    37                 except KeyError:
       
    38                     # variable has been rewritten
       
    39                     continue
       
    40             else:
       
    41                 newsolutions.append(newsol)
       
    42                 solutions.remove(newsol)
       
    43     return newsolutions
       
    44 
       
    45 
       
    46 class Unsupported(Exception): pass
       
    47 
       
    48 
       
    49 class RQLRewriter(object):
       
    50     """insert some rql snippets into another rql syntax tree
       
    51 
       
    52     this class *isn't thread safe*
       
    53     """
       
    54 
       
    55     def __init__(self, session):
       
    56         self.session = session
       
    57         vreg = session.vreg
       
    58         self.schema = vreg.schema
       
    59         self.annotate = vreg.rqlhelper.annotate
       
    60         self._compute_solutions = vreg.solutions
       
    61 
       
    62     def compute_solutions(self):
       
    63         self.annotate(self.select)
       
    64         try:
       
    65             self._compute_solutions(self.session, self.select, self.kwargs)
       
    66         except TypeResolverException:
       
    67             raise Unsupported(str(self.select))
       
    68         if len(self.select.solutions) < len(self.solutions):
       
    69             raise Unsupported()
       
    70 
       
    71     def rewrite(self, select, snippets, solutions, kwargs):
       
    72         """
       
    73         snippets: (varmap, list of rql expression)
       
    74                   with varmap a *tuple* (select var, snippet var)
       
    75         """
       
    76         if server.DEBUG:
       
    77             print '---- rewrite', select, snippets, solutions
       
    78         self.select = self.insert_scope = select
       
    79         self.solutions = solutions
       
    80         self.kwargs = kwargs
       
    81         self.u_varname = None
       
    82         self.removing_ambiguity = False
       
    83         self.exists_snippet = {}
       
    84         self.pending_keys = []
       
    85         # we have to annotate the rqlst before inserting snippets, even though
       
    86         # we'll have to redo it latter
       
    87         self.annotate(select)
       
    88         self.insert_snippets(snippets)
       
    89         if not self.exists_snippet and self.u_varname:
       
    90             # U has been inserted than cancelled, cleanup
       
    91             select.undefine_variable(select.defined_vars[self.u_varname])
       
    92         # clean solutions according to initial solutions
       
    93         newsolutions = remove_solutions(solutions, select.solutions,
       
    94                                         select.defined_vars)
       
    95         assert len(newsolutions) >= len(solutions), (
       
    96             'rewritten rql %s has lost some solutions, there is probably '
       
    97             'something wrong in your schema permission (for instance using a '
       
    98             'RQLExpression which insert a relation which doesn\'t exists in '
       
    99             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
       
   100             select, solutions, newsolutions))
       
   101         if len(newsolutions) > len(solutions):
       
   102             newsolutions = self.remove_ambiguities(snippets, newsolutions)
       
   103         select.solutions = newsolutions
       
   104         add_types_restriction(self.schema, select)
       
   105         if server.DEBUG:
       
   106             print '---- rewriten', select
       
   107 
       
   108     def insert_snippets(self, snippets, varexistsmap=None):
       
   109         self.rewritten = {}
       
   110         for varmap, rqlexprs in snippets:
       
   111             if varexistsmap is not None and not varmap in varexistsmap:
       
   112                 continue
       
   113             self.varmap = varmap
       
   114             selectvar, snippetvar = varmap
       
   115             assert snippetvar in 'SOX'
       
   116             self.revvarmap = {snippetvar: selectvar}
       
   117             self.varinfo = vi = {}
       
   118             try:
       
   119                 vi['const'] = typed_eid(selectvar) # XXX gae
       
   120                 vi['rhs_rels'] = vi['lhs_rels'] = {}
       
   121             except ValueError:
       
   122                 vi['stinfo'] = sti = self.select.defined_vars[selectvar].stinfo
       
   123                 if varexistsmap is None:
       
   124                     vi['rhs_rels'] = dict( (r.r_type, r) for r in sti['rhsrelations'])
       
   125                     vi['lhs_rels'] = dict( (r.r_type, r) for r in sti['relations']
       
   126                                            if not r in sti['rhsrelations'])
       
   127                 else:
       
   128                     vi['rhs_rels'] = vi['lhs_rels'] = {}
       
   129             parent = None
       
   130             inserted = False
       
   131             for rqlexpr in rqlexprs:
       
   132                 self.current_expr = rqlexpr
       
   133                 if varexistsmap is None:
       
   134                     try:
       
   135                         new = self.insert_snippet(varmap, rqlexpr.snippet_rqlst, parent)
       
   136                     except Unsupported:
       
   137                         import traceback
       
   138                         traceback.print_exc()
       
   139                         continue
       
   140                     inserted = True
       
   141                     if new is not None:
       
   142                         self.exists_snippet[rqlexpr] = new
       
   143                     parent = parent or new
       
   144                 else:
       
   145                     # called to reintroduce snippet due to ambiguity creation,
       
   146                     # so skip snippets which are not introducing this ambiguity
       
   147                     exists = varexistsmap[varmap]
       
   148                     if self.exists_snippet[rqlexpr] is exists:
       
   149                         self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
       
   150             if varexistsmap is None and not inserted:
       
   151                 # no rql expression found matching rql solutions. User has no access right
       
   152                 raise Unauthorized(str((varmap, str(self.select), [expr.expression for expr in rqlexprs])))
       
   153 
       
   154     def insert_snippet(self, varmap, snippetrqlst, parent=None):
       
   155         new = snippetrqlst.where.accept(self)
       
   156         if new is not None:
       
   157             if self.varinfo.get('stinfo', {}).get('optrelations'):
       
   158                 assert parent is None
       
   159                 self.insert_scope = self.snippet_subquery(varmap, new)
       
   160                 self.insert_pending()
       
   161                 self.insert_scope = self.select
       
   162                 return
       
   163             new = n.Exists(new)
       
   164             if parent is None:
       
   165                 self.insert_scope.add_restriction(new)
       
   166             else:
       
   167                 grandpa = parent.parent
       
   168                 or_ = n.Or(parent, new)
       
   169                 grandpa.replace(parent, or_)
       
   170             if not self.removing_ambiguity:
       
   171                 try:
       
   172                     self.compute_solutions()
       
   173                 except Unsupported:
       
   174                     # some solutions have been lost, can't apply this rql expr
       
   175                     if parent is None:
       
   176                         self.select.remove_node(new, undefine=True)
       
   177                     else:
       
   178                         parent.parent.replace(or_, or_.children[0])
       
   179                         self._cleanup_inserted(new)
       
   180                     raise
       
   181                 else:
       
   182                     self.insert_scope = new
       
   183                     self.insert_pending()
       
   184                     self.insert_scope = self.select
       
   185             return new
       
   186         self.insert_pending()
       
   187 
       
   188     def insert_pending(self):
       
   189         """pending_keys hold variable referenced by U has_<action>_permission X
       
   190         relation.
       
   191 
       
   192         Once the snippet introducing this has been inserted and solutions
       
   193         recomputed, we have to insert snippet defined for <action> of entity
       
   194         types taken by X
       
   195         """
       
   196         while self.pending_keys:
       
   197             key, action = self.pending_keys.pop()
       
   198             try:
       
   199                 varname = self.rewritten[key]
       
   200             except KeyError:
       
   201                 try:
       
   202                     varname = self.revvarmap[key[-1]]
       
   203                 except KeyError:
       
   204                     # variable isn't used anywhere else, we can't insert security
       
   205                     raise Unauthorized()
       
   206             ptypes = self.select.defined_vars[varname].stinfo['possibletypes']
       
   207             if len(ptypes) > 1:
       
   208                 # XXX dunno how to handle this
       
   209                 self.session.error(
       
   210                     'cant check security of %s, ambigous type for %s in %s',
       
   211                     self.select, varname, key[0]) # key[0] == the rql expression
       
   212                 raise Unauthorized()
       
   213             etype = iter(ptypes).next()
       
   214             eschema = self.schema.eschema(etype)
       
   215             if not eschema.has_perm(self.session, action):
       
   216                 rqlexprs = eschema.get_rqlexprs(action)
       
   217                 if not rqlexprs:
       
   218                     raise Unauthorised()
       
   219                 self.insert_snippets([((varname, 'X'), rqlexprs)])
       
   220 
       
   221     def snippet_subquery(self, varmap, transformedsnippet):
       
   222         """introduce the given snippet in a subquery"""
       
   223         subselect = stmts.Select()
       
   224         selectvar, snippetvar = varmap
       
   225         subselect.append_selected(n.VariableRef(
       
   226             subselect.get_variable(selectvar)))
       
   227         aliases = [selectvar]
       
   228         subselect.add_restriction(transformedsnippet.copy(subselect))
       
   229         stinfo = self.varinfo['stinfo']
       
   230         for rel in stinfo['relations']:
       
   231             rschema = self.schema.rschema(rel.r_type)
       
   232             if rschema.is_final() or (rschema.inlined and
       
   233                                       not rel in stinfo['rhsrelations']):
       
   234                 self.select.remove_node(rel)
       
   235                 rel.children[0].name = selectvar
       
   236                 subselect.add_restriction(rel.copy(subselect))
       
   237                 for vref in rel.children[1].iget_nodes(n.VariableRef):
       
   238                     subselect.append_selected(vref.copy(subselect))
       
   239                     aliases.append(vref.name)
       
   240         if self.u_varname:
       
   241             # generate an identifier for the substitution
       
   242             argname = subselect.allocate_varname()
       
   243             while argname in self.kwargs:
       
   244                 argname = subselect.allocate_varname()
       
   245             subselect.add_constant_restriction(subselect.get_variable(self.u_varname),
       
   246                                                'eid', unicode(argname), 'Substitute')
       
   247             self.kwargs[argname] = self.session.user.eid
       
   248         add_types_restriction(self.schema, subselect, subselect,
       
   249                               solutions=self.solutions)
       
   250         myunion = stmts.Union()
       
   251         myunion.append(subselect)
       
   252         aliases = [n.VariableRef(self.select.get_variable(name, i))
       
   253                    for i, name in enumerate(aliases)]
       
   254         self.select.add_subquery(n.SubQuery(aliases, myunion), check=False)
       
   255         self._cleanup_inserted(transformedsnippet)
       
   256         try:
       
   257             self.compute_solutions()
       
   258         except Unsupported:
       
   259             # some solutions have been lost, can't apply this rql expr
       
   260             self.select.remove_subquery(new, undefine=True)
       
   261             raise
       
   262         return subselect
       
   263 
       
   264     def remove_ambiguities(self, snippets, newsolutions):
       
   265         # the snippet has introduced some ambiguities, we have to resolve them
       
   266         # "manually"
       
   267         variantes = self.build_variantes(newsolutions)
       
   268         # insert "is" where necessary
       
   269         varexistsmap = {}
       
   270         self.removing_ambiguity = True
       
   271         for (erqlexpr, varmap, oldvarname), etype in variantes[0].iteritems():
       
   272             varname = self.rewritten[(erqlexpr, varmap, oldvarname)]
       
   273             var = self.select.defined_vars[varname]
       
   274             exists = var.references()[0].scope
       
   275             exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   276             varexistsmap[varmap] = exists
       
   277         # insert ORED exists where necessary
       
   278         for variante in variantes[1:]:
       
   279             self.insert_snippets(snippets, varexistsmap)
       
   280             for key, etype in variante.iteritems():
       
   281                 varname = self.rewritten[key]
       
   282                 try:
       
   283                     var = self.select.defined_vars[varname]
       
   284                 except KeyError:
       
   285                     # not a newly inserted variable
       
   286                     continue
       
   287                 exists = var.references()[0].scope
       
   288                 exists.add_constant_restriction(var, 'is', etype, 'etype')
       
   289         # recompute solutions
       
   290         #select.annotated = False # avoid assertion error
       
   291         self.compute_solutions()
       
   292         # clean solutions according to initial solutions
       
   293         return remove_solutions(self.solutions, self.select.solutions,
       
   294                                 self.select.defined_vars)
       
   295 
       
   296     def build_variantes(self, newsolutions):
       
   297         variantes = set()
       
   298         for sol in newsolutions:
       
   299             variante = []
       
   300             for key, newvar in self.rewritten.iteritems():
       
   301                 variante.append( (key, sol[newvar]) )
       
   302             variantes.add(tuple(variante))
       
   303         # rebuild variantes as dict
       
   304         variantes = [dict(variante) for variante in variantes]
       
   305         # remove variable which have always the same type
       
   306         for key in self.rewritten:
       
   307             it = iter(variantes)
       
   308             etype = it.next()[key]
       
   309             for variante in it:
       
   310                 if variante[key] != etype:
       
   311                     break
       
   312             else:
       
   313                 for variante in variantes:
       
   314                     del variante[key]
       
   315         return variantes
       
   316 
       
   317     def _cleanup_inserted(self, node):
       
   318         # cleanup inserted variable references
       
   319         for vref in node.iget_nodes(n.VariableRef):
       
   320             vref.unregister_reference()
       
   321             if not vref.variable.stinfo['references']:
       
   322                 # no more references, undefine the variable
       
   323                 del self.select.defined_vars[vref.name]
       
   324 
       
   325     def _may_be_shared(self, relation, target, searchedvarname):
       
   326         """return True if the snippet relation can be skipped to use a relation
       
   327         from the original query
       
   328         """
       
   329         # if cardinality is in '?1', we can ignore the relation and use variable
       
   330         # from the original query
       
   331         rschema = self.schema.rschema(relation.r_type)
       
   332         if target == 'object':
       
   333             cardindex = 0
       
   334             ttypes_func = rschema.objects
       
   335             rprop = rschema.rproperty
       
   336         else: # target == 'subject':
       
   337             cardindex = 1
       
   338             ttypes_func = rschema.subjects
       
   339             rprop = lambda x, y, z: rschema.rproperty(y, x, z)
       
   340         for etype in self.varinfo['stinfo']['possibletypes']:
       
   341             for ttype in ttypes_func(etype):
       
   342                 if rprop(etype, ttype, 'cardinality')[cardindex] in '+*':
       
   343                     return False
       
   344         return True
       
   345 
       
   346     def _use_outer_term(self, snippet_varname, term):
       
   347         key = (self.current_expr, self.varmap, snippet_varname)
       
   348         if key in self.rewritten:
       
   349             insertedvar = self.select.defined_vars.pop(self.rewritten[key])
       
   350             for inserted_vref in insertedvar.references():
       
   351                 inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
       
   352         self.rewritten[key] = term.name
       
   353 
       
   354     def _get_varname_or_term(self, vname):
       
   355         if vname == 'U':
       
   356             if self.u_varname is None:
       
   357                 select = self.select
       
   358                 self.u_varname = select.allocate_varname()
       
   359                 # generate an identifier for the substitution
       
   360                 argname = select.allocate_varname()
       
   361                 while argname in self.kwargs:
       
   362                     argname = select.allocate_varname()
       
   363                 # insert "U eid %(u)s"
       
   364                 var = select.get_variable(self.u_varname)
       
   365                 select.add_constant_restriction(select.get_variable(self.u_varname),
       
   366                                                 'eid', unicode(argname), 'Substitute')
       
   367                 self.kwargs[argname] = self.session.user.eid
       
   368             return self.u_varname
       
   369         key = (self.current_expr, self.varmap, vname)
       
   370         try:
       
   371             return self.rewritten[key]
       
   372         except KeyError:
       
   373             self.rewritten[key] = newvname = self.select.allocate_varname()
       
   374             return newvname
       
   375 
       
   376     # visitor methods ##########################################################
       
   377 
       
   378     def _visit_binary(self, node, cls):
       
   379         newnode = cls()
       
   380         for c in node.children:
       
   381             new = c.accept(self)
       
   382             if new is None:
       
   383                 continue
       
   384             newnode.append(new)
       
   385         if len(newnode.children) == 0:
       
   386             return None
       
   387         if len(newnode.children) == 1:
       
   388             return newnode.children[0]
       
   389         return newnode
       
   390 
       
   391     def _visit_unary(self, node, cls):
       
   392         newc = node.children[0].accept(self)
       
   393         if newc is None:
       
   394             return None
       
   395         newnode = cls()
       
   396         newnode.append(newc)
       
   397         return newnode
       
   398 
       
   399     def visit_and(self, node):
       
   400         return self._visit_binary(node, n.And)
       
   401 
       
   402     def visit_or(self, node):
       
   403         return self._visit_binary(node, n.Or)
       
   404 
       
   405     def visit_not(self, node):
       
   406         return self._visit_unary(node, n.Not)
       
   407 
       
   408     def visit_exists(self, node):
       
   409         return self._visit_unary(node, n.Exists)
       
   410 
       
   411     def visit_relation(self, node):
       
   412         lhs, rhs = node.get_variable_parts()
       
   413         if node.r_type in ('has_add_permission', 'has_update_permission',
       
   414                            'has_delete_permission', 'has_read_permission'):
       
   415             assert lhs.name == 'U'
       
   416             action = node.r_type.split('_')[1]
       
   417             key = (self.current_expr, self.varmap, rhs.name)
       
   418             self.pending_keys.append( (key, action) )
       
   419             return
       
   420         if lhs.name in self.revvarmap:
       
   421             # on lhs
       
   422             # see if we can reuse this relation
       
   423             rels = self.varinfo['lhs_rels']
       
   424             if (node.r_type in rels and isinstance(rhs, n.VariableRef)
       
   425                 and rhs.name != 'U' and not rels[node.r_type].neged(strict=True)
       
   426                 and self._may_be_shared(node, 'object', lhs.name)):
       
   427                 # ok, can share variable
       
   428                 term = rels[node.r_type].children[1].children[0]
       
   429                 self._use_outer_term(rhs.name, term)
       
   430                 return
       
   431         elif isinstance(rhs, n.VariableRef) and rhs.name in self.revvarmap and lhs.name != 'U':
       
   432             # on rhs
       
   433             # see if we can reuse this relation
       
   434             rels = self.varinfo['rhs_rels']
       
   435             if (node.r_type in rels and not rels[node.r_type].neged(strict=True)
       
   436                 and self._may_be_shared(node, 'subject', rhs.name)):
       
   437                 # ok, can share variable
       
   438                 term = rels[node.r_type].children[0]
       
   439                 self._use_outer_term(lhs.name, term)
       
   440                 return
       
   441         rel = n.Relation(node.r_type, node.optional)
       
   442         for c in node.children:
       
   443             rel.append(c.accept(self))
       
   444         return rel
       
   445 
       
   446     def visit_comparison(self, node):
       
   447         cmp_ = n.Comparison(node.operator)
       
   448         for c in node.children:
       
   449             cmp_.append(c.accept(self))
       
   450         return cmp_
       
   451 
       
   452     def visit_mathexpression(self, node):
       
   453         cmp_ = n.MathExpression(node.operator)
       
   454         for c in cmp.children:
       
   455             cmp_.append(c.accept(self))
       
   456         return cmp_
       
   457 
       
   458     def visit_function(self, node):
       
   459         """generate filter name for a function"""
       
   460         function_ = n.Function(node.name)
       
   461         for c in node.children:
       
   462             function_.append(c.accept(self))
       
   463         return function_
       
   464 
       
   465     def visit_constant(self, node):
       
   466         """generate filter name for a constant"""
       
   467         return n.Constant(node.value, node.type)
       
   468 
       
   469     def visit_variableref(self, node):
       
   470         """get the sql name for a variable reference"""
       
   471         if node.name in self.revvarmap:
       
   472             if self.varinfo.get('const') is not None:
       
   473                 return n.Constant(self.varinfo['const'], 'Int') # XXX gae
       
   474             return n.VariableRef(self.select.get_variable(
       
   475                 self.revvarmap[node.name]))
       
   476         vname_or_term = self._get_varname_or_term(node.name)
       
   477         if isinstance(vname_or_term, basestring):
       
   478             return n.VariableRef(self.select.get_variable(vname_or_term))
       
   479         # shared term
       
   480         return vname_or_term.copy(self.select)