server/rqlrewrite.py
branch tls-sprint
changeset 1802 d628defebc17
parent 1138 22f634977c95
child 1977 606923dff11b
equal deleted inserted replaced
1801:672acc730ce5 1802:d628defebc17
    32                 newsolutions.append(newsol)
    32                 newsolutions.append(newsol)
    33                 solutions.remove(newsol)
    33                 solutions.remove(newsol)
    34     return newsolutions
    34     return newsolutions
    35 
    35 
    36 class Unsupported(Exception): pass
    36 class Unsupported(Exception): pass
    37         
    37 
    38 class RQLRewriter(object):
    38 class RQLRewriter(object):
    39     """insert some rql snippets into another rql syntax tree"""
    39     """insert some rql snippets into another rql syntax tree"""
    40     def __init__(self, querier, session):
    40     def __init__(self, querier, session):
    41         self.session = session
    41         self.session = session
    42         self.annotate = querier._rqlhelper.annotate
    42         self.annotate = querier._rqlhelper.annotate
    49             self._compute_solutions(self.session, self.select, self.kwargs)
    49             self._compute_solutions(self.session, self.select, self.kwargs)
    50         except TypeResolverException:
    50         except TypeResolverException:
    51             raise Unsupported()
    51             raise Unsupported()
    52         if len(self.select.solutions) < len(self.solutions):
    52         if len(self.select.solutions) < len(self.solutions):
    53             raise Unsupported()
    53             raise Unsupported()
    54         
    54 
    55     def rewrite(self, select, snippets, solutions, kwargs):
    55     def rewrite(self, select, snippets, solutions, kwargs):
    56         if server.DEBUG:
    56         if server.DEBUG:
    57             print '---- rewrite', select, snippets, solutions
    57             print '---- rewrite', select, snippets, solutions
    58         self.select = select
    58         self.select = select
    59         self.solutions = solutions
    59         self.solutions = solutions
   110                                             select.defined_vars)
   110                                             select.defined_vars)
   111         select.solutions = newsolutions
   111         select.solutions = newsolutions
   112         add_types_restriction(self.schema, select)
   112         add_types_restriction(self.schema, select)
   113         if server.DEBUG:
   113         if server.DEBUG:
   114             print '---- rewriten', select
   114             print '---- rewriten', select
   115             
   115 
   116     def build_variantes(self, newsolutions):
   116     def build_variantes(self, newsolutions):
   117         variantes = set()
   117         variantes = set()
   118         for sol in newsolutions:
   118         for sol in newsolutions:
   119             variante = []
   119             variante = []
   120             for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
   120             for (erqlexpr, mainvar, oldvar), newvar in self.rewritten.iteritems():
   131                     break
   131                     break
   132             else:
   132             else:
   133                 for variante in variantes:
   133                 for variante in variantes:
   134                     del variante[(erqlexpr, mainvar, oldvar)]
   134                     del variante[(erqlexpr, mainvar, oldvar)]
   135         return variantes
   135         return variantes
   136     
   136 
   137     def insert_snippets(self, snippets, varexistsmap=None):
   137     def insert_snippets(self, snippets, varexistsmap=None):
   138         self.rewritten = {}
   138         self.rewritten = {}
   139         for varname, erqlexprs in snippets:
   139         for varname, erqlexprs in snippets:
   140             if varexistsmap is not None and not varname in varexistsmap:
   140             if varexistsmap is not None and not varname in varexistsmap:
   141                 continue
   141                 continue
   173                     if self.exists_snippet[erqlexpr] is exists:
   173                     if self.exists_snippet[erqlexpr] is exists:
   174                         self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
   174                         self.insert_snippet(varname, erqlexpr.snippet_rqlst, exists)
   175             if varexistsmap is None and not inserted:
   175             if varexistsmap is None and not inserted:
   176                 # no rql expression found matching rql solutions. User has no access right
   176                 # no rql expression found matching rql solutions. User has no access right
   177                 raise Unauthorized()
   177                 raise Unauthorized()
   178             
   178 
   179     def insert_snippet(self, varname, snippetrqlst, parent=None):
   179     def insert_snippet(self, varname, snippetrqlst, parent=None):
   180         new = snippetrqlst.where.accept(self)
   180         new = snippetrqlst.where.accept(self)
   181         if new is not None:
   181         if new is not None:
   182             try:
   182             try:
   183                 var = self.select.defined_vars[varname]
   183                 var = self.select.defined_vars[varname]
   238                     if parent is None:
   238                     if parent is None:
   239                         self.select.remove_node(new, undefine=True)
   239                         self.select.remove_node(new, undefine=True)
   240                     else:
   240                     else:
   241                         parent.parent.replace(or_, or_.children[0])
   241                         parent.parent.replace(or_, or_.children[0])
   242                         self._cleanup_inserted(new)
   242                         self._cleanup_inserted(new)
   243                     raise 
   243                     raise
   244             return new
   244             return new
   245 
   245 
   246     def _cleanup_inserted(self, node):
   246     def _cleanup_inserted(self, node):
   247         # cleanup inserted variable references
   247         # cleanup inserted variable references
   248         for vref in node.iget_nodes(nodes.VariableRef):
   248         for vref in node.iget_nodes(nodes.VariableRef):
   249             vref.unregister_reference()
   249             vref.unregister_reference()
   250             if not vref.variable.stinfo['references']:
   250             if not vref.variable.stinfo['references']:
   251                 # no more references, undefine the variable
   251                 # no more references, undefine the variable
   252                 del self.select.defined_vars[vref.name]
   252                 del self.select.defined_vars[vref.name]
   253         
   253 
   254     def _visit_binary(self, node, cls):
   254     def _visit_binary(self, node, cls):
   255         newnode = cls()
   255         newnode = cls()
   256         for c in node.children:
   256         for c in node.children:
   257             new = c.accept(self)
   257             new = c.accept(self)
   258             if new is None:
   258             if new is None:
   268         newc = node.children[0].accept(self)
   268         newc = node.children[0].accept(self)
   269         if newc is None:
   269         if newc is None:
   270             return None
   270             return None
   271         newnode = cls()
   271         newnode = cls()
   272         newnode.append(newc)
   272         newnode.append(newc)
   273         return newnode 
   273         return newnode
   274         
   274 
   275     def visit_and(self, et):
   275     def visit_and(self, et):
   276         return self._visit_binary(et, nodes.And)
   276         return self._visit_binary(et, nodes.And)
   277 
   277 
   278     def visit_or(self, ou):
   278     def visit_or(self, ou):
   279         return self._visit_binary(ou, nodes.Or)
   279         return self._visit_binary(ou, nodes.Or)
   280         
   280 
   281     def visit_not(self, node):
   281     def visit_not(self, node):
   282         return self._visit_unary(node, nodes.Not)
   282         return self._visit_unary(node, nodes.Not)
   283 
   283 
   284     def visit_exists(self, node):
   284     def visit_exists(self, node):
   285         return self._visit_unary(node, nodes.Exists)
   285         return self._visit_unary(node, nodes.Exists)
   286    
   286 
   287     def visit_relation(self, relation):
   287     def visit_relation(self, relation):
   288         lhs, rhs = relation.get_variable_parts()
   288         lhs, rhs = relation.get_variable_parts()
   289         if lhs.name == 'X':
   289         if lhs.name == 'X':
   290             # on lhs
   290             # on lhs
   291             # see if we can reuse this relation
   291             # see if we can reuse this relation
   299             # on rhs
   299             # on rhs
   300             # see if we can reuse this relation
   300             # see if we can reuse this relation
   301             if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
   301             if relation.r_type in self.rhs_rels and self._may_be_shared(relation, 'subject'):
   302                 # ok, can share variable
   302                 # ok, can share variable
   303                 term = self.rhs_rels[relation.r_type].children[0]
   303                 term = self.rhs_rels[relation.r_type].children[0]
   304                 self._use_outer_term(lhs.name, term)            
   304                 self._use_outer_term(lhs.name, term)
   305                 return
   305                 return
   306         rel = nodes.Relation(relation.r_type, relation.optional)
   306         rel = nodes.Relation(relation.r_type, relation.optional)
   307         for c in relation.children:
   307         for c in relation.children:
   308             rel.append(c.accept(self))
   308             rel.append(c.accept(self))
   309         return rel
   309         return rel
   317     def visit_mathexpression(self, mexpr):
   317     def visit_mathexpression(self, mexpr):
   318         cmp_ = nodes.MathExpression(mexpr.operator)
   318         cmp_ = nodes.MathExpression(mexpr.operator)
   319         for c in cmp.children:
   319         for c in cmp.children:
   320             cmp_.append(c.accept(self))
   320             cmp_.append(c.accept(self))
   321         return cmp_
   321         return cmp_
   322         
   322 
   323     def visit_function(self, function):
   323     def visit_function(self, function):
   324         """generate filter name for a function"""
   324         """generate filter name for a function"""
   325         function_ = nodes.Function(function.name)
   325         function_ = nodes.Function(function.name)
   326         for c in function.children:
   326         for c in function.children:
   327             function_.append(c.accept(self))
   327             function_.append(c.accept(self))
   369         if key in self.rewritten:
   369         if key in self.rewritten:
   370             insertedvar = self.select.defined_vars.pop(self.rewritten[key])
   370             insertedvar = self.select.defined_vars.pop(self.rewritten[key])
   371             for inserted_vref in insertedvar.references():
   371             for inserted_vref in insertedvar.references():
   372                 inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
   372                 inserted_vref.parent.replace(inserted_vref, term.copy(self.select))
   373         self.rewritten[key] = term
   373         self.rewritten[key] = term
   374         
   374 
   375     def _get_varname_or_term(self, vname):
   375     def _get_varname_or_term(self, vname):
   376         if vname == 'U':
   376         if vname == 'U':
   377             if self.u_varname is None:
   377             if self.u_varname is None:
   378                 select = self.select
   378                 select = self.select
   379                 self.u_varname = select.allocate_varname()
   379                 self.u_varname = select.allocate_varname()