goa/rqlinterpreter.py
branchtls-sprint
changeset 1802 d628defebc17
parent 1133 8a409ea0c9ec
child 1977 606923dff11b
--- a/goa/rqlinterpreter.py	Thu May 14 10:24:56 2009 +0200
+++ b/goa/rqlinterpreter.py	Thu May 14 11:38:40 2009 +0200
@@ -23,7 +23,7 @@
     return Key(key).kind()
 
 def poss_var_types(myvar, ovar, kind, solutions):
-    return frozenset(etypes[myvar] for etypes in solutions 
+    return frozenset(etypes[myvar] for etypes in solutions
                      if etypes[ovar] == kind)
 
 def expand_result(results, result, myvar, values, dsget=None):
@@ -84,7 +84,7 @@
             string.append('%s: %s' % (k, v))
     return '{%s}' % ', '.join(string)
 
-                         
+
 class EidMismatch(Exception):
     def __init__(self, varname, value):
         self.varname = varname
@@ -101,45 +101,45 @@
         self.operator = operator
         self.rtype = rel.r_type
         self.var = rel.children[0]
-        
+
     def __repr__(self):
         return '<%s for %s>' % (self.__class__.__name__, self.rel)
-    
+
     @property
     def rhs(self):
         return self.rel.children[1].children[0]
 
-        
+
 class MultipleRestriction(object):
     def __init__(self, restrictions):
         self.restrictions = restrictions
-        
+
     def resolve(self, solutions, fixed):
         return _resolve(self.restrictions, solutions, fixed)
 
-    
+
 class VariableSelection(Restriction):
     def __init__(self, rel, dsget, prefix='s'):
         Restriction.__init__(self, rel)
         self._dsget = dsget
         self._not = self.rel.neged(strict=True)
         self._prefix = prefix + '_'
-        
+
     def __repr__(self):
         return '<%s%s for %s>' % (self._prefix[0], self.__class__.__name__, self.rel)
-        
+
     @property
     def searched_var(self):
         if self._prefix == 's_':
             return self.var.name
         return self.rhs.name
-        
+
     @property
     def constraint_var(self):
         if self._prefix == 's_':
             return self.rhs.name
         return self.var.name
-        
+
     def _possible_values(self, myvar, ovar, entity, solutions, dsprefix):
         if self.rtype == 'identity':
             return (entity.key(),)
@@ -150,7 +150,7 @@
             value = [value]
         vartypes = poss_var_types(myvar, ovar, entity.kind(), solutions)
         return (v for v in value if v.kind() in vartypes)
-        
+
     def complete_and_filter(self, solutions, results):
         myvar = self.rhs.name
         ovar = self.var.name
@@ -173,8 +173,8 @@
                     expand_result(results, result, myvar, values, self._dsget)
         else:
             assert self.rhs.name in results[0]
-            self.object_complete_and_filter(solutions, results)           
-            
+            self.object_complete_and_filter(solutions, results)
+
     def filter(self, solutions, results):
         myvar = self.rhs.name
         ovar = self.var.name
@@ -187,10 +187,10 @@
                 newsols[key] = frozenset(v for v in values)
             if self._not:
                 if result[myvar].key() in newsols[key]:
-                    results.remove(result)                
+                    results.remove(result)
             elif not result[myvar].key() in newsols[key]:
                 results.remove(result)
-    
+
     def object_complete_and_filter(self, solutions, results):
         if self._not:
             raise NotImplementedError()
@@ -201,7 +201,7 @@
                                            solutions, 'o_')
             expand_result(results, result, myvar, values, self._dsget)
 
-    
+
 class EidRestriction(Restriction):
     def __init__(self, rel, dsget):
         Restriction.__init__(self, rel)
@@ -216,7 +216,7 @@
 
     def _get_value(self, fixed):
         return fixed[self.constraint_var].key()
-    
+
     def fill_query(self, fixed, query, operator=None):
         restr = '%s%s %s' % (self._prefix, self.rtype, operator or self.operator)
         query[restr] = self._get_value(fixed)
@@ -235,7 +235,7 @@
 
     def _get_value(self, fixed):
         return None
-    
+
     def resolve(self, solutions, fixed):
         if self.rtype == 'identity':
             raise NotImplementedError()
@@ -255,7 +255,7 @@
                 raise NotImplementedError('LIKE is only supported for prefix search')
             self.operator = '>'
             self.value = value[:-1]
-            
+
     def complete_and_filter(self, solutions, results):
         # check lhs var first in case this is a restriction
         assert self._not
@@ -263,7 +263,7 @@
         for result in results[:]:
             if result[myvar].get('s_'+rtype) == value:
                 results.remove(result)
-            
+
     def _get_value(self, fixed):
         return self.value
 
@@ -294,7 +294,7 @@
     @property
     def operator(self):
         return 'in'
-            
+
 
 class TypeRestriction(AttributeRestriction):
     def __init__(self, var):
@@ -302,7 +302,7 @@
 
     def __repr__(self):
         return '<%s for %s>' % (self.__class__.__name__, self.var)
-    
+
     def resolve(self, solutions, fixed):
         objs = []
         for etype in frozenset(etypes[self.var.name] for etypes in solutions):
@@ -330,7 +330,7 @@
         self.args = args
         self.term = term
         self._solution = self.term.stmt.solutions[0]
-        
+
     def compute(self, result):
         """return (entity type, value) to which self.term is evaluated according
         to the given result dictionnary and to query arguments (self.args)
@@ -341,7 +341,7 @@
         args = tuple(n.accept(self, result)[1] for n in node.children)
         value = self.functions[node.name](*args)
         return node.get_type(self._solution, self.args), value
-    
+
     def visit_variableref(self, node, result):
         value = result[node.name]
         try:
@@ -350,11 +350,11 @@
         except AttributeError:
             etype = self._solution[node.name]
         return etype, value
-    
+
     def visit_constant(self, node, result):
         return node.get_type(kwargs=self.args), node.eval(self.args)
-    
-        
+
+
 class RQLInterpreter(object):
     """algorithm:
     1. visit the restriction clauses and collect restriction for each subject
@@ -369,7 +369,7 @@
            for each solution in select'solutions:
                1. resolve variables which have attribute restriction
                2. resolve relation restriction
-               3. resolve selection and add to global results 
+               3. resolve selection and add to global results
     """
     def __init__(self, schema):
         self.schema = schema
@@ -379,15 +379,15 @@
                              'UPPER': lambda x: x.upper()}
         for cb in SQL_CONNECT_HOOKS.get('sqlite', []):
             cb(self)
-            
+
     # emulate sqlite connection interface so we can reuse stored procedures
     def create_function(self, name, nbargs, func):
         self._stored_proc[name] = func
-        
+
     def create_aggregate(self, name, nbargs, func):
         self._stored_proc[name] = func
 
-        
+
     def execute(self, operation, parameters=None, eid_key=None, build_descr=True):
         rqlst = self.rqlhelper.parse(operation, annotate=True)
         try:
@@ -397,7 +397,7 @@
         else:
             results, description = self.interpret(rqlst, parameters)
         return ResultSet(results, operation, parameters, description, rqlst=rqlst)
-        
+
     def interpret(self, node, kwargs, dsget=None):
         if dsget is None:
             self._dsget = Get
@@ -417,7 +417,7 @@
             results += pres
             description += pdescr
         return results, description
-    
+
     def visit_select(self, node, extra):
         constraints = {}
         if node.where is not None:
@@ -441,7 +441,7 @@
         for varname, restrictions in constraints.iteritems():
             for restr in restrictions[:]:
                 if isinstance(restr, EidRestriction):
-                    assert not varname in fixed    
+                    assert not varname in fixed
                     try:
                         value = restr.resolve(kwargs)
                         fixed[varname] = value
@@ -455,7 +455,7 @@
                 if isinstance(restr, AttributeRestriction):
                     toresolve.setdefault(varname, []).append(restr)
                 elif isinstance(restr, NotRelationRestriction) or (
-                    isinstance(restr, RelationRestriction) and 
+                    isinstance(restr, RelationRestriction) and
                     not restr.searched_var in fixed and restr.constraint_var in fixed):
                     toresolve.setdefault(varname, []).append(restr)
                 else:
@@ -495,7 +495,7 @@
                     partres.append({varname: value})
             elif not varname in partres[0]:
                 # cartesian product
-                for res in partres:                    
+                for res in partres:
                     res[varname] = values[0]
                 for res in partres[:]:
                     for value in values[1:]:
@@ -503,14 +503,14 @@
                         res[varname] = value
                         partres.append(res)
             else:
-                # union 
+                # union
                 for res in varpartres:
                     for value in values:
                         res = res.copy()
                         res[varname] = value
                         partres.append(res)
         #print 'partres', len(partres)
-        #print partres                        
+        #print partres
         # Note: don't check for empty partres since constant selection may still
         # produce result at this point
         # sort to get RelationRestriction before AttributeSelection
@@ -569,14 +569,14 @@
                     append_result(res, descr, i, j, value, etype)
         #print '--------->', res
         return res, descr
-    
-    def visit_and(self, node, constraints, extra): 
+
+    def visit_and(self, node, constraints, extra):
         for child in node.children:
             child.accept(self, constraints, extra)
     def visit_exists(self, node, constraints, extra):
         extra['has_exists'] = True
         self.visit_and(node, constraints, extra)
-    
+
     def visit_not(self, node, constraints, extra):
         for child in node.children:
             child.accept(self, constraints, extra)
@@ -584,7 +584,7 @@
             extra.pop(node)
         except KeyError:
             raise NotImplementedError()
-        
+
     def visit_relation(self, node, constraints, extra):
         if node.is_types_restriction():
             return
@@ -600,7 +600,7 @@
             self._visit_non_final_neged_relation(rschema, node, constraints)
         else:
             self._visit_non_final_relation(rschema, node, constraints)
-                
+
     def _visit_non_final_relation(self, rschema, node, constraints, not_=False):
         lhs, rhs = node.get_variable_parts()
         for v1, v2, prefix in ((lhs, rhs, 's'), (rhs, lhs, 'o')):
@@ -611,14 +611,14 @@
             if nbrels > 1:
                 constraints.setdefault(v1.name, []).append(
                     RelationRestriction(node, self._dsget, prefix))
-                # just init an empty list for v2 variable to avoid a 
+                # just init an empty list for v2 variable to avoid a
                 # TypeRestriction being added for it
                 constraints.setdefault(v2.name, [])
                 break
         else:
             constraints.setdefault(rhs.name, []).append(
                 VariableSelection(node, self._dsget, 's'))
-                
+
     def _visit_non_final_neged_relation(self, rschema, node, constraints):
         lhs, rhs = node.get_variable_parts()
         for v1, v2, prefix in ((lhs, rhs, 's'), (rhs, lhs, 'o')):
@@ -653,16 +653,16 @@
                     AttributeInRestriction(node, extra['kwargs']))
             else:
                 raise NotImplementedError()
-        
+
     def _not_implemented(self, *args, **kwargs):
         raise NotImplementedError()
-    
+
     visit_or = _not_implemented
     # shouldn't occurs
     visit_set = _not_implemented
     visit_insert = _not_implemented
     visit_delete = _not_implemented
-        
+
 
 from logging import getLogger
 from cubicweb import set_log_methods