server/sources/rql2sql.py
branchstable
changeset 3815 50b87f759b5d
parent 3787 82bb2c7f083b
child 3852 03121ca1f85e
--- a/server/sources/rql2sql.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/sources/rql2sql.py	Tue Nov 10 18:06:47 2009 +0100
@@ -252,10 +252,10 @@
             self.actual_tables[-1].append(tsql)
         self.outer_tables = {}
         self.duplicate_switches = []
-        self.attr_vars = {}
         self.aliases = {}
         self.restrictions = []
         self._restr_stack = []
+        self.ignore_varmap = False
 
     def add_restriction(self, restr):
         if restr:
@@ -848,23 +848,23 @@
         nothing to do here.
         """
         contextrels = {}
-        attrvars = self._state.attr_vars
         for var in rhs_vars:
-            try:
-                contextrels[var.name] = attrvars[var.name]
-            except KeyError:
-                attrvars[var.name] = relation
             if var.name in self._varmap:
                 # ensure table is added
                 self._var_info(var.variable)
+            principal = var.variable.stinfo.get('principal')
+            if principal is not None and principal is not relation:
+                contextrels[var.name] = relation
         if not contextrels:
-            relation.children[1].accept(self, contextrels)
             return ''
-        # at least one variable is already in attr_vars, this means we have to
-        # generate unification expression
+        # we have to generate unification expression
         lhssql = self._inlined_var_sql(relation.children[0].variable,
                                        relation.r_type)
-        return '%s%s' % (lhssql, relation.children[1].accept(self, contextrels))
+        try:
+            self._state.ignore_varmap = True
+            return '%s%s' % (lhssql, relation.children[1].accept(self))
+        finally:
+            self._state.ignore_varmap = False
 
     def _visit_attribute_relation(self, rel):
         """generate SQL for an attribute relation"""
@@ -932,7 +932,7 @@
         return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args),
                                                     jointo, not_) + restriction
 
-    def visit_comparison(self, cmp, contextrels=None):
+    def visit_comparison(self, cmp):
         """generate SQL for a comparison"""
         if len(cmp.children) == 2:
             # XXX occurs ?
@@ -950,16 +950,15 @@
               and rhs.eval(self._args) is None):
             if lhs is None:
                 return ' IS NULL'
-            return '%s IS NULL' % lhs.accept(self, contextrels)
+            return '%s IS NULL' % lhs.accept(self)
         elif isinstance(rhs, Function) and rhs.name == 'IN':
             assert operator == '='
             operator = ' '
         if lhs is None:
-            return '%s%s'% (operator, rhs.accept(self, contextrels))
-        return '%s%s%s'% (lhs.accept(self, contextrels), operator,
-                          rhs.accept(self, contextrels))
+            return '%s%s'% (operator, rhs.accept(self))
+        return '%s%s%s'% (lhs.accept(self), operator, rhs.accept(self))
 
-    def visit_mathexpression(self, mexpr, contextrels=None):
+    def visit_mathexpression(self, mexpr):
         """generate SQL for a mathematic expression"""
         lhs, rhs = mexpr.get_parts()
         # check for string concatenation
@@ -969,17 +968,16 @@
                 operator = '||'
         except CoercionError:
             pass
-        return '(%s %s %s)'% (lhs.accept(self, contextrels), operator,
-                              rhs.accept(self, contextrels))
+        return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
 
-    def visit_function(self, func, contextrels=None):
+    def visit_function(self, func):
         """generate SQL name for a function"""
         # function_description will check function is supported by the backend
         sqlname = self.dbms_helper.func_sqlname(func.name)
-        return '%s(%s)' % (sqlname, ', '.join(c.accept(self, contextrels)
+        return '%s(%s)' % (sqlname, ', '.join(c.accept(self)
                                               for c in func.children))
 
-    def visit_constant(self, constant, contextrels=None):
+    def visit_constant(self, constant):
         """generate SQL name for a constant"""
         value = constant.value
         if constant.type is None:
@@ -1004,12 +1002,12 @@
             self._query_attrs[_id] = value
         return '%%(%s)s' % _id
 
-    def visit_variableref(self, variableref, contextrels=None):
+    def visit_variableref(self, variableref):
         """get the sql name for a variable reference"""
         # use accept, .variable may be a variable or a columnalias
-        return variableref.variable.accept(self, contextrels)
+        return variableref.variable.accept(self)
 
-    def visit_columnalias(self, colalias, contextrels=None):
+    def visit_columnalias(self, colalias):
         """get the sql name for a subquery column alias"""
         if colalias.name in self._varmap:
             sql = self._varmap[colalias.name]
@@ -1020,20 +1018,21 @@
             return sql
         return colalias._q_sql
 
-    def visit_variable(self, variable, contextrels=None):
+    def visit_variable(self, variable):
         """get the table name and sql string for a variable"""
-        if contextrels is None and variable.name in self._state.done:
+        #if contextrels is None and variable.name in self._state.done:
+        if variable.name in self._state.done:
             if self._in_wrapping_query:
                 return 'T1.%s' % self._state.aliases[variable.name]
             return variable._q_sql
         self._state.done.add(variable.name)
         vtablename = None
-        if contextrels is None and variable.name in self._varmap:
+        if not self._state.ignore_varmap and variable.name in self._varmap:
             sql, vtablename = self._var_info(variable)
         elif variable.stinfo['attrvar']:
             # attribute variable (systematically used in rhs of final
             # relation(s)), get table name and sql from any rhs relation
-            sql = self._linked_var_sql(variable, contextrels)
+            sql = self._linked_var_sql(variable)
         elif variable._q_invariant:
             # since variable is invariant, we know we won't found final relation
             principal = variable.stinfo['principal']
@@ -1056,7 +1055,7 @@
                                  self.dbms_helper.fti_uid_attr)
             elif principal in variable.stinfo['rhsrelations']:
                 if self.schema.rschema(principal.r_type).inlined:
-                    sql = self._linked_var_sql(variable, contextrels)
+                    sql = self._linked_var_sql(variable)
                 else:
                     sql = '%s.eid_to' % self._relation_table(principal)
             else:
@@ -1076,8 +1075,14 @@
         # generate extra join
         try:
             if not var.stinfo['principal'] is relation:
-                # need a predicable result for tests
-                return '%s=%s' % tuple(sorted((sql, var.accept(self))))
+                op = relation.operator()
+                if op == '=':
+                    # need a predicable result for tests
+                    args = sorted( (sql, var.accept(self)) )
+                    args.insert(1, op)
+                else:
+                    args = (sql, op, var.accept(self))
+                return '%s%s%s' % tuple(args)
         except KeyError:
             # no principal defined, relation is necessarily the principal and
             # so nothing to return here
@@ -1123,14 +1128,13 @@
             #self._state.done.add(var.name)
         return sql
 
-    def _linked_var_sql(self, variable, contextrels=None):
-        if contextrels is None:
+    def _linked_var_sql(self, variable):
+        if not self._state.ignore_varmap:
             try:
                 return self._varmap[variable.name]
             except KeyError:
                 pass
-        rel = (contextrels and contextrels.get(variable.name) or
-               variable.stinfo.get('principal') or
+        rel = (variable.stinfo.get('principal') or
                iter(variable.stinfo['rhsrelations']).next())
         linkedvar = rel.children[0].variable
         if rel.r_type == 'eid':