rqlrewrite.py
branchstable
changeset 9167 c05652b108ce
parent 8748 f5027f8d2478
child 9169 544b22a3485b
equal deleted inserted replaced
9166:e47e192ea0d9 9167:c05652b108ce
     1 # copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
     1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
     3 #
     3 #
     4 # This file is part of CubicWeb.
     4 # This file is part of CubicWeb.
     5 #
     5 #
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
    29 
    29 
    30 from logilab.common import tempattr
    30 from logilab.common import tempattr
    31 from logilab.common.graph import has_path
    31 from logilab.common.graph import has_path
    32 
    32 
    33 from cubicweb import Unauthorized
    33 from cubicweb import Unauthorized
       
    34 
       
    35 
       
    36 def cleanup_solutions(rqlst, solutions):
       
    37     for sol in solutions:
       
    38         for vname in list(sol):
       
    39             if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
       
    40                 del sol[vname]
    34 
    41 
    35 
    42 
    36 def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    43 def add_types_restriction(schema, rqlst, newroot=None, solutions=None):
    37     if newroot is None:
    44     if newroot is None:
    38         assert solutions is None
    45         assert solutions is None
   130                 newsolutions.append(newsol)
   137                 newsolutions.append(newsol)
   131                 solutions.remove(newsol)
   138                 solutions.remove(newsol)
   132     return newsolutions
   139     return newsolutions
   133 
   140 
   134 
   141 
       
   142 def _add_noinvariant(noinvariant, restricted, select, nbtrees):
       
   143     # a variable can actually be invariant if it has not been restricted for
       
   144     # security reason or if security assertion hasn't modified the possible
       
   145     # solutions for the query
       
   146     for vname in restricted:
       
   147         try:
       
   148             var = select.defined_vars[vname]
       
   149         except KeyError:
       
   150             # this is an alias
       
   151             continue
       
   152         if nbtrees != 1 or len(var.stinfo['possibletypes']) != 1:
       
   153             noinvariant.add(var)
       
   154 
       
   155 
       
   156 def _expand_selection(terms, selected, aliases, select, newselect):
       
   157     for term in terms:
       
   158         for vref in term.iget_nodes(n.VariableRef):
       
   159             if not vref.name in selected:
       
   160                 select.append_selected(vref)
       
   161                 colalias = newselect.get_variable(vref.name, len(aliases))
       
   162                 aliases.append(n.VariableRef(colalias))
       
   163                 selected.add(vref.name)
       
   164 
       
   165 
   135 def iter_relations(stinfo):
   166 def iter_relations(stinfo):
   136     # this is a function so that test may return relation in a predictable order
   167     # this is a function so that test may return relation in a predictable order
   137     return stinfo['relations'] - stinfo['rhsrelations']
   168     return stinfo['relations'] - stinfo['rhsrelations']
       
   169 
   138 
   170 
   139 class Unsupported(Exception):
   171 class Unsupported(Exception):
   140     """raised when an rql expression can't be inserted in some rql query
   172     """raised when an rql expression can't be inserted in some rql query
   141     because it create an unresolvable query (eg no solutions found)
   173     because it create an unresolvable query (eg no solutions found)
   142     """
   174     """
   161             self._compute_solutions(self.session, self.select, self.kwargs)
   193             self._compute_solutions(self.session, self.select, self.kwargs)
   162         except TypeResolverException:
   194         except TypeResolverException:
   163             raise Unsupported(str(self.select))
   195             raise Unsupported(str(self.select))
   164         if len(self.select.solutions) < len(self.solutions):
   196         if len(self.select.solutions) < len(self.solutions):
   165             raise Unsupported()
   197             raise Unsupported()
       
   198 
       
    def insert_local_checks(self, select, kwargs,
                            localchecks, restricted, noinvariant):
        """Insert local security checks (rql expressions) into `select`.

        select: the rql syntax tree Select node
        kwargs: query arguments

        localchecks: {(('Var name', (rqlexpr1, rqlexpr2)),
                       ('Var name1', (rqlexpr1, rqlexpr23))): [solution]}

              (see querier._check_permissions docstring for more information)
              The empty tuple key ``()`` holds the solutions that need no
              rewriting.

        restricted: set of variable names to which an rql expression has to be
              applied

        noinvariant: set filled by this method with the variables that can't
              be considered as invariant due to security reasons. Note it
              receives variable objects (from the ``defined_vars`` of the
              processed select nodes), not variable names.

        One rewritten copy of `select` is produced per non-empty key of
        `localchecks`. When there are several keys and `select` uses
        ordering, grouping, having, aggregates, distinct, limit or offset,
        `select` is replaced in its parent union by a new outer Select that
        carries those clauses and selects from a subquery union holding the
        copies. Otherwise the copies are appended to the parent union
        directly. The parent union is re-annotated at the end.
        """
        nbtrees = len(localchecks)
        # `union` is the original parent union; `myunion` is where rewritten
        # copies are appended (replaced by a subquery union below when the
        # original select has to be wrapped)
        myunion = union = select.parent
        # transform in subquery when len(localchecks)>1 and groups
        if nbtrees > 1 and (select.orderby or select.groupby or
                            select.having or select.has_aggregat or
                            select.distinct or
                            select.limit or select.offset):
            # those clauses are moved to an outer select reading from a
            # subquery, so that they apply to the union of all rewritten
            # trees rather than to each tree separately
            newselect = stmts.Select()
            # only select variables in subqueries
            origselection = select.selection
            select.select_only_variables()
            select.has_aggregat = False
            # create subquery first so correct node are used on copy
            # (eg ColumnAlias instead of Variable)
            aliases = [n.VariableRef(newselect.get_variable(vref.name, i))
                       for i, vref in enumerate(select.selection)]
            # names of the variables already exported by the subquery
            selected = set(vref.name for vref in aliases)
            # now copy original selection and groups
            for term in origselection:
                newselect.append_selected(term.copy(newselect))
            if select.orderby:
                sortterms = []
                for sortterm in select.orderby:
                    sortterms.append(sortterm.copy(newselect))
                    for fnode in sortterm.get_nodes(n.Function):
                        if fnode.name == 'FTIRANK':
                            # we've to fetch the has_text relation as well
                            var = fnode.children[0].variable
                            # NOTE(review): `.next()` is Python 2 only;
                            # next(iter(...)) is the portable spelling
                            rel = iter(var.stinfo['ftirels']).next()
                            # NOTE(review): assert is stripped under -O
                            assert not rel.ored(), 'unsupported'
                            newselect.add_restriction(rel.copy(newselect))
                            # remove relation from the orig select and
                            # cleanup variable stinfo
                            rel.parent.remove(rel)
                            var.stinfo['ftirels'].remove(rel)
                            var.stinfo['relations'].remove(rel)
                            # XXX not properly re-annotated after security insertion?
                            newvar = newselect.get_variable(var.name)
                            newvar.stinfo.setdefault('ftirels', set()).add(rel)
                            newvar.stinfo.setdefault('relations', set()).add(rel)
                newselect.set_orderby(sortterms)
                # variables used for sorting must be exported by the subquery
                _expand_selection(select.orderby, selected, aliases, select, newselect)
                select.orderby = () # XXX dereference?
            if select.groupby:
                newselect.set_groupby([g.copy(newselect) for g in select.groupby])
                _expand_selection(select.groupby, selected, aliases, select, newselect)
                select.groupby = () # XXX dereference?
            if select.having:
                newselect.set_having([g.copy(newselect) for g in select.having])
                _expand_selection(select.having, selected, aliases, select, newselect)
                select.having = () # XXX dereference?
            # limit / offset move to the outer select and are reset on the
            # inner one
            if select.limit:
                newselect.limit = select.limit
                select.limit = None
            if select.offset:
                newselect.offset = select.offset
                select.offset = 0
            myunion = stmts.Union()
            newselect.set_with([n.SubQuery(aliases, myunion)], check=False)
            newselect.distinct = select.distinct
            # copy the solutions so that the cleanup below doesn't alter the
            # solutions of the original select
            solutions = [sol.copy() for sol in select.solutions]
            cleanup_solutions(newselect, solutions)
            newselect.set_possible_types(solutions)
            # if some solutions doesn't need rewriting, insert original
            # select as first union subquery
            if () in localchecks:
                myunion.append(select)
            # we're done, replace original select by the new select with
            # subqueries (more added in the loop below)
            union.replace(select, newselect)
        elif not () in localchecks:
            # every solution needs a check: the unchecked original select is
            # dropped from the union, only rewritten copies of it (appended
            # below) are kept
            union.remove(select)
        for lcheckdef, lchecksolutions in localchecks.iteritems():
            if not lcheckdef:
                # solutions without checks are handled after this loop
                continue
            myrqlst = select.copy(solutions=lchecksolutions)
            myunion.append(myrqlst)
            # in-place rewrite + annotation / simplification
            # each select variable is mapped to 'X', the snippet variable
            # used by the rql expressions
            lcheckdef = [({var: 'X'}, rqlexprs) for var, rqlexprs in lcheckdef]
            self.rewrite(myrqlst, lcheckdef, lchecksolutions, kwargs)
            _add_noinvariant(noinvariant, restricted, myrqlst, nbtrees)
        if () in localchecks:
            # the original select is kept for the solutions needing no check:
            # restrict it to those solutions only
            select.set_possible_types(localchecks[()])
            add_types_restriction(self.schema, select)
            _add_noinvariant(noinvariant, restricted, select, nbtrees)
        # the union tree has been modified: annotate it again
        self.annotate(union)
   166 
   302 
   167     def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
   303     def rewrite(self, select, snippets, solutions, kwargs, existingvars=None):
   168         """
   304         """
   169         snippets: (varmap, list of rql expression)
   305         snippets: (varmap, list of rql expression)
   170                   with varmap a *tuple* (select var, snippet var)
   306                   with varmap a *tuple* (select var, snippet var)