cubicweb/rqlrewrite.py
branch3.25
changeset 12087 9f668acfa6c3
parent 12086 39c9e548f0ce
child 12089 54b518367617
equal deleted inserted replaced
12086:39c9e548f0ce 12087:9f668acfa6c3
    19 tree.
    19 tree.
    20 
    20 
    21 This is used for instance for read security checking in the repository.
    21 This is used for instance for read security checking in the repository.
    22 """
    22 """
    23 
    23 
    24 
       
    25 from six import text_type, string_types
    24 from six import text_type, string_types
    26 
    25 
    27 from rql import nodes as n, stmts, TypeResolverException
    26 from rql import nodes as n, stmts, TypeResolverException
    28 from rql.utils import common_parent
    27 from rql.utils import common_parent
    29 
    28 
    31 
    30 
    32 from logilab.common import tempattr
    31 from logilab.common import tempattr
    33 from logilab.common.graph import has_path
    32 from logilab.common.graph import has_path
    34 
    33 
    35 from cubicweb import Unauthorized
    34 from cubicweb import Unauthorized
    36 from cubicweb.schema import RRQLExpression
    35 
    37 
    36 
    38 def cleanup_solutions(rqlst, solutions):
    37 def cleanup_solutions(rqlst, solutions):
    39     for sol in solutions:
    38     for sol in solutions:
    40         for vname in list(sol):
    39         for vname in list(sol):
    41             if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
    40             if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
    64     # XXX could be factorized with add_etypes_restriction from rql 0.31
    63     # XXX could be factorized with add_etypes_restriction from rql 0.31
    65     for varname in sorted(allpossibletypes):
    64     for varname in sorted(allpossibletypes):
    66         var = newroot.defined_vars[varname]
    65         var = newroot.defined_vars[varname]
    67         stinfo = var.stinfo
    66         stinfo = var.stinfo
    68         if stinfo.get('uidrel') is not None:
    67         if stinfo.get('uidrel') is not None:
    69             continue # eid specified, no need for additional type specification
    68             continue  # eid specified, no need for additional type specification
    70         try:
    69         try:
    71             typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
    70             typerel = rqlst.defined_vars[varname].stinfo.get('typerel')
    72         except KeyError:
    71         except KeyError:
    73             assert varname in rqlst.aliases
    72             assert varname in rqlst.aliases
    74             continue
    73             continue
   101             else:
   100             else:
   102                 # variable has already some strict types restriction. new
   101                 # variable has already some strict types restriction. new
   103                 # possible types can only be a subset of existing ones, so only
   102                 # possible types can only be a subset of existing ones, so only
   104                 # remove no more possible types
   103                 # remove no more possible types
   105                 for cst in mytyperel.get_nodes(n.Constant):
   104                 for cst in mytyperel.get_nodes(n.Constant):
   106                     if not cst.value in possibletypes:
   105                     if cst.value not in possibletypes:
   107                         cst.parent.remove(cst)
   106                         cst.parent.remove(cst)
   108         else:
   107         else:
   109             # we have to add types restriction
   108             # we have to add types restriction
   110             if stinfo.get('scope') is not None:
   109             if stinfo.get('scope') is not None:
   111                 rel = var.scope.add_type_restriction(var, possibletypes)
   110                 rel = var.scope.add_type_restriction(var, possibletypes)
   157 
   156 
   158 
   157 
   159 def _expand_selection(terms, selected, aliases, select, newselect):
   158 def _expand_selection(terms, selected, aliases, select, newselect):
   160     for term in terms:
   159     for term in terms:
   161         for vref in term.iget_nodes(n.VariableRef):
   160         for vref in term.iget_nodes(n.VariableRef):
   162             if not vref.name in selected:
   161             if vref.name not in selected:
   163                 select.append_selected(vref)
   162                 select.append_selected(vref)
   164                 colalias = newselect.get_variable(vref.name, len(aliases))
   163                 colalias = newselect.get_variable(vref.name, len(aliases))
   165                 aliases.append(n.VariableRef(colalias))
   164                 aliases.append(n.VariableRef(colalias))
   166                 selected.add(vref.name)
   165                 selected.add(vref.name)
   167 
   166 
   173     for etype in etypes:
   172     for etype in etypes:
   174         for ttype in ttypes_func(etype):
   173         for ttype in ttypes_func(etype):
   175             if rdef(etype, ttype).cardinality[cardindex] in '+*':
   174             if rdef(etype, ttype).cardinality[cardindex] in '+*':
   176                 return True
   175                 return True
   177     return False
   176     return False
       
   177 
   178 
   178 
   179 def _compatible_relation(relations, stmt, sniprel):
   179 def _compatible_relation(relations, stmt, sniprel):
   180     """Search among given rql relation nodes if there is one 'compatible' with the
   180     """Search among given rql relation nodes if there is one 'compatible' with the
   181     snippet relation, and return it if any, else None.
   181     snippet relation, and return it if any, else None.
   182 
   182 
   207 
   207 
   208 class Unsupported(Exception):
   208 class Unsupported(Exception):
   209     """raised when an rql expression can't be inserted in some rql query
   209     """raised when an rql expression can't be inserted in some rql query
   210     because it create an unresolvable query (eg no solutions found)
   210     because it create an unresolvable query (eg no solutions found)
   211     """
   211     """
       
   212 
   212 
   213 
   213 class VariableFromSubQuery(Exception):
   214 class VariableFromSubQuery(Exception):
   214     """flow control exception to indicate that a variable is coming from a
   215     """flow control exception to indicate that a variable is coming from a
   215     subquery, and let parent act accordingly
   216     subquery, and let parent act accordingly
   216     """
   217     """
   300                             newvar = newselect.get_variable(var.name)
   301                             newvar = newselect.get_variable(var.name)
   301                             newvar.stinfo.setdefault('ftirels', set()).add(rel)
   302                             newvar.stinfo.setdefault('ftirels', set()).add(rel)
   302                             newvar.stinfo.setdefault('relations', set()).add(rel)
   303                             newvar.stinfo.setdefault('relations', set()).add(rel)
   303                 newselect.set_orderby(sortterms)
   304                 newselect.set_orderby(sortterms)
   304                 _expand_selection(select.orderby, selected, aliases, select, newselect)
   305                 _expand_selection(select.orderby, selected, aliases, select, newselect)
   305                 select.orderby = () # XXX dereference?
   306                 select.orderby = ()  # XXX dereference?
   306             if select.groupby:
   307             if select.groupby:
   307                 newselect.set_groupby([g.copy(newselect) for g in select.groupby])
   308                 newselect.set_groupby([g.copy(newselect) for g in select.groupby])
   308                 _expand_selection(select.groupby, selected, aliases, select, newselect)
   309                 _expand_selection(select.groupby, selected, aliases, select, newselect)
   309                 select.groupby = () # XXX dereference?
   310                 select.groupby = ()  # XXX dereference?
   310             if select.having:
   311             if select.having:
   311                 newselect.set_having([g.copy(newselect) for g in select.having])
   312                 newselect.set_having([g.copy(newselect) for g in select.having])
   312                 _expand_selection(select.having, selected, aliases, select, newselect)
   313                 _expand_selection(select.having, selected, aliases, select, newselect)
   313                 select.having = () # XXX dereference?
   314                 select.having = ()  # XXX dereference?
   314             if select.limit:
   315             if select.limit:
   315                 newselect.limit = select.limit
   316                 newselect.limit = select.limit
   316                 select.limit = None
   317                 select.limit = None
   317             if select.offset:
   318             if select.offset:
   318                 newselect.offset = select.offset
   319                 newselect.offset = select.offset
   374         assert len(newsolutions) >= len(solutions), (
   375         assert len(newsolutions) >= len(solutions), (
   375             'rewritten rql %s has lost some solutions, there is probably '
   376             'rewritten rql %s has lost some solutions, there is probably '
   376             'something wrong in your schema permission (for instance using a '
   377             'something wrong in your schema permission (for instance using a '
   377             'RQLExpression which inserts a relation which doesn\'t exist in '
   378             'RQLExpression which inserts a relation which doesn\'t exist in '
   378             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
   379             'the schema)\nOrig solutions: %s\nnew solutions: %s' % (
   379             select, solutions, newsolutions))
   380                 select, solutions, newsolutions))
   380         if len(newsolutions) > len(solutions):
   381         if len(newsolutions) > len(solutions):
   381             newsolutions = self.remove_ambiguities(snippets, newsolutions)
   382             newsolutions = self.remove_ambiguities(snippets, newsolutions)
   382             assert newsolutions
   383             assert newsolutions
   383         select.solutions = newsolutions
   384         select.solutions = newsolutions
   384         add_types_restriction(self.schema, select)
   385         add_types_restriction(self.schema, select)
   388         for varmap, rqlexprs in snippets:
   389         for varmap, rqlexprs in snippets:
   389             if isinstance(varmap, dict):
   390             if isinstance(varmap, dict):
   390                 varmap = tuple(sorted(varmap.items()))
   391                 varmap = tuple(sorted(varmap.items()))
   391             else:
   392             else:
   392                 assert isinstance(varmap, tuple), varmap
   393                 assert isinstance(varmap, tuple), varmap
   393             if varexistsmap is not None and not varmap in varexistsmap:
   394             if varexistsmap is not None and varmap not in varexistsmap:
   394                 continue
   395                 continue
   395             self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
   396             self.insert_varmap_snippets(varmap, rqlexprs, varexistsmap)
   396 
   397 
   397     def init_from_varmap(self, varmap, varexistsmap=None):
   398     def init_from_varmap(self, varmap, varexistsmap=None):
   398         self.varmap = varmap
   399         self.varmap = varmap
   416                     vi['rhs_rels'] = {}
   417                     vi['rhs_rels'] = {}
   417                     for rel in sti.get('rhsrelations', []):
   418                     for rel in sti.get('rhsrelations', []):
   418                         vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
   419                         vi['rhs_rels'].setdefault(rel.r_type, []).append(rel)
   419                     vi['lhs_rels'] = {}
   420                     vi['lhs_rels'] = {}
   420                     for rel in sti.get('relations', []):
   421                     for rel in sti.get('relations', []):
   421                         if not rel in sti.get('rhsrelations', []):
   422                         if rel not in sti.get('rhsrelations', []):
   422                             vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
   423                             vi['lhs_rels'].setdefault(rel.r_type, []).append(rel)
   423                 else:
   424                 else:
   424                     vi['rhs_rels'] = vi['lhs_rels'] = {}
   425                     vi['rhs_rels'] = vi['lhs_rels'] = {}
   425 
   426 
   426     def _subquery_variable(self, selectvar):
   427     def _subquery_variable(self, selectvar):
   458                 exists = varexistsmap[varmap]
   459                 exists = varexistsmap[varmap]
   459                 if self.exists_snippet.get(rqlexpr) is exists:
   460                 if self.exists_snippet.get(rqlexpr) is exists:
   460                     self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
   461                     self.insert_snippet(varmap, rqlexpr.snippet_rqlst, exists)
   461         if varexistsmap is None and not inserted:
   462         if varexistsmap is None and not inserted:
   462             # no rql expression found matching rql solutions. User has no access right
   463             # no rql expression found matching rql solutions. User has no access right
   463             raise Unauthorized() # XXX may also be because of bad constraints in schema definition
   464             raise Unauthorized()  # XXX may also be because of bad constraints in schema definition
   464 
   465 
   465     def insert_snippet(self, varmap, snippetrqlst, previous=None):
   466     def insert_snippet(self, varmap, snippetrqlst, previous=None):
   466         new = snippetrqlst.where.accept(self)
   467         new = snippetrqlst.where.accept(self)
   467         existing = self.existingvars
   468         existing = self.existingvars
   468         self.existingvars = None
   469         self.existingvars = None
   496             if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
   497             if self._insert_scope is None and any(vi.get('stinfo', {}).get('optrelations')
   497                                                   for vi in self.varinfos):
   498                                                   for vi in self.varinfos):
   498                 assert previous is None
   499                 assert previous is None
   499                 self._insert_scope, new = self.snippet_subquery(varmap, new)
   500                 self._insert_scope, new = self.snippet_subquery(varmap, new)
   500                 self.insert_pending()
   501                 self.insert_pending()
   501                 #self._insert_scope = None
       
   502                 return new
   502                 return new
   503             new = self._inserted_root(new)
   503             new = self._inserted_root(new)
   504             if previous is None:
   504             if previous is None:
   505                 insert_scope.add_restriction(new)
   505                 insert_scope.add_restriction(new)
   506             else:
   506             else:
   546             ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
   546             ptypes = stmt.defined_vars[varname].stinfo['possibletypes']
   547             if len(ptypes) > 1:
   547             if len(ptypes) > 1:
   548                 # XXX dunno how to handle this
   548                 # XXX dunno how to handle this
   549                 self.session.error(
   549                 self.session.error(
   550                     'cant check security of %s, ambigous type for %s in %s',
   550                     'cant check security of %s, ambigous type for %s in %s',
   551                     stmt, varname, key[0]) # key[0] == the rql expression
   551                     stmt, varname, key[0])  # key[0] == the rql expression
   552                 raise Unauthorized()
   552                 raise Unauthorized()
   553             etype = next(iter(ptypes))
   553             etype = next(iter(ptypes))
   554             eschema = self.schema.eschema(etype)
   554             eschema = self.schema.eschema(etype)
   555             if not eschema.has_perm(self.session, action):
   555             if not eschema.has_perm(self.session, action):
   556                 rqlexprs = eschema.get_rqlexprs(action)
   556                 rqlexprs = eschema.get_rqlexprs(action)
   579                         continue
   579                         continue
   580                     done.add(rel)
   580                     done.add(rel)
   581                     rschema = get_rschema(rel.r_type)
   581                     rschema = get_rschema(rel.r_type)
   582                     if rschema.final or rschema.inlined:
   582                     if rschema.final or rschema.inlined:
   583                         subselect_vrefs = []
   583                         subselect_vrefs = []
   584                         rel.children[0].name = varname # XXX explain why
   584                         rel.children[0].name = varname  # XXX explain why
   585                         subselect.add_restriction(rel.copy(subselect))
   585                         subselect.add_restriction(rel.copy(subselect))
   586                         for vref in rel.children[1].iget_nodes(n.VariableRef):
   586                         for vref in rel.children[1].iget_nodes(n.VariableRef):
   587                             if isinstance(vref.variable, n.ColumnAlias):
   587                             if isinstance(vref.variable, n.ColumnAlias):
   588                                 # XXX could probably be handled by generating the
   588                                 # XXX could probably be handled by generating the
   589                                 # subquery into the detected subquery
   589                                 # subquery into the detected subquery
   609                                 # we can use vref here define in above for loop
   609                                 # we can use vref here define in above for loop
   610                                 ostinfo = vref.variable.stinfo
   610                                 ostinfo = vref.variable.stinfo
   611                                 for orel in iter_relations(ostinfo):
   611                                 for orel in iter_relations(ostinfo):
   612                                     orschema = get_rschema(orel.r_type)
   612                                     orschema = get_rschema(orel.r_type)
   613                                     if orschema.final or orschema.inlined:
   613                                     if orschema.final or orschema.inlined:
   614                                         todo.append( (vref.name, ostinfo) )
   614                                         todo.append((vref.name, ostinfo))
   615                                         break
   615                                         break
   616             if need_null_test:
   616             if need_null_test:
   617                 snippetrqlst = n.Or(
   617                 snippetrqlst = n.Or(
   618                     n.make_relation(subselect.get_variable(selectvar), 'is',
   618                     n.make_relation(subselect.get_variable(selectvar), 'is',
   619                                     (None, None), n.Constant,
   619                                     (None, None), n.Constant,
   678     def build_variantes(self, newsolutions):
   678     def build_variantes(self, newsolutions):
   679         variantes = set()
   679         variantes = set()
   680         for sol in newsolutions:
   680         for sol in newsolutions:
   681             variante = []
   681             variante = []
   682             for key, newvar in self.rewritten.items():
   682             for key, newvar in self.rewritten.items():
   683                 variante.append( (key, sol[newvar]) )
   683                 variante.append((key, sol[newvar]))
   684             variantes.add(tuple(variante))
   684             variantes.add(tuple(variante))
   685         # rebuild variantes as dict
   685         # rebuild variantes as dict
   686         variantes = [dict(variante) for variante in variantes]
   686         variantes = [dict(variante) for variante in variantes]
   687         # remove variable which have always the same type
   687         # remove variable which have always the same type
   688         for key in self.rewritten:
   688         for key in self.rewritten:
   707                 removed.add(vref.name)
   707                 removed.add(vref.name)
   708         for key, newvar in list(self.rewritten.items()):
   708         for key, newvar in list(self.rewritten.items()):
   709             if newvar in removed:
   709             if newvar in removed:
   710                 del self.rewritten[key]
   710                 del self.rewritten[key]
   711 
   711 
   712 
       
   713     def _may_be_shared_with(self, sniprel, target):
   712     def _may_be_shared_with(self, sniprel, target):
   714         """if the snippet relation can be skipped to use a relation from the
   713         """if the snippet relation can be skipped to use a relation from the
   715         original query, return that relation node
   714         original query, return that relation node
   716         """
   715         """
   717         if sniprel.neged(strict=True):
   716         if sniprel.neged(strict=True):
   718             return None # no way
   717             return None  # no way
   719         rschema = self.schema.rschema(sniprel.r_type)
   718         rschema = self.schema.rschema(sniprel.r_type)
   720         stmt = self.current_statement()
   719         stmt = self.current_statement()
   721         for vi in self.varinfos:
   720         for vi in self.varinfos:
   722             try:
   721             try:
   723                 if target == 'object':
   722                 if target == 'object':
   724                     orels = vi['lhs_rels'][sniprel.r_type]
   723                     orels = vi['lhs_rels'][sniprel.r_type]
   725                     cardindex = 0
   724                     cardindex = 0
   726                     ttypes_func = rschema.objects
   725                     ttypes_func = rschema.objects
   727                     rdef = rschema.rdef
   726                     rdef = rschema.rdef
   728                 else: # target == 'subject':
   727                 else:  # target == 'subject':
   729                     orels = vi['rhs_rels'][sniprel.r_type]
   728                     orels = vi['rhs_rels'][sniprel.r_type]
   730                     cardindex = 1
   729                     cardindex = 1
   731                     ttypes_func = rschema.subjects
   730                     ttypes_func = rschema.subjects
   732                     rdef = lambda x, y: rschema.rdef(y, x)
   731 
       
   732                     def rdef(x, y):
       
   733                         return rschema.rdef(y, x)
   733             except KeyError:
   734             except KeyError:
   734                 # may be raised by vi['xhs_rels'][sniprel.r_type]
   735                 # may be raised by vi['xhs_rels'][sniprel.r_type]
   735                 continue
   736                 continue
   736             # if cardinality isn't in '?1', we can't ignore the snippet relation
   737             # if cardinality isn't in '?1', we can't ignore the snippet relation
   737             # and use variable from the original query
   738             # and use variable from the original query
   815             return varname in self.existingvars
   816             return varname in self.existingvars
   816         if varname == 'U':
   817         if varname == 'U':
   817             return True
   818             return True
   818         vargraph = self.current_expr.vargraph
   819         vargraph = self.current_expr.vargraph
   819         for existingvar in self.existingvars:
   820         for existingvar in self.existingvars:
   820             #path = has_path(vargraph, varname, existingvar)
   821             if varname not in vargraph or has_path(vargraph, varname, existingvar):
   821             if not varname in vargraph or has_path(vargraph, varname, existingvar):
       
   822                 return True
   822                 return True
   823         # no path from this variable to an existing variable
   823         # no path from this variable to an existing variable
   824         return False
   824         return False
   825 
   825 
   826     def visit_relation(self, node):
   826     def visit_relation(self, node):
   833         if node.r_type in ('has_add_permission', 'has_update_permission',
   833         if node.r_type in ('has_add_permission', 'has_update_permission',
   834                            'has_delete_permission', 'has_read_permission'):
   834                            'has_delete_permission', 'has_read_permission'):
   835             assert lhs.name == 'U'
   835             assert lhs.name == 'U'
   836             action = node.r_type.split('_')[1]
   836             action = node.r_type.split('_')[1]
   837             key = (self.current_expr, self.varmap, rhs.name)
   837             key = (self.current_expr, self.varmap, rhs.name)
   838             self.pending_keys.append( (key, action) )
   838             self.pending_keys.append((key, action))
   839             return
   839             return
   840         if isinstance(rhs, n.VariableRef):
   840         if isinstance(rhs, n.VariableRef):
   841             if self.existingvars and not self.keep_var(rhs.name):
   841             if self.existingvars and not self.keep_var(rhs.name):
   842                 return
   842                 return
   843             if lhs.name in self.revvarmap and rhs.name != 'U':
   843             if lhs.name in self.revvarmap and rhs.name != 'U':
   913         self.pending_keys = None
   913         self.pending_keys = None
   914         rules = self.schema.rules_rqlexpr_mapping
   914         rules = self.schema.rules_rqlexpr_mapping
   915         for relation in union.iget_nodes(n.Relation):
   915         for relation in union.iget_nodes(n.Relation):
   916             if relation.r_type in rules:
   916             if relation.r_type in rules:
   917                 self.select = relation.stmt
   917                 self.select = relation.stmt
   918                 self.solutions = solutions = self.select.solutions[:]
   918                 self.solutions = self.select.solutions[:]
   919                 self.current_expr = rules[relation.r_type]
   919                 self.current_expr = rules[relation.r_type]
   920                 self._insert_scope = relation.scope
   920                 self._insert_scope = relation.scope
   921                 self.rewritten = {}
   921                 self.rewritten = {}
   922                 lhs, rhs = relation.get_variable_parts()
   922                 lhs, rhs = relation.get_variable_parts()
   923                 varmap = {lhs.name: 'S', rhs.name: 'O'}
   923                 varmap = {lhs.name: 'S', rhs.name: 'O'}