server/sources/rql2sql.py
branchstable
changeset 5582 3e133b29a1a4
parent 5426 0d4853a6e5ee
child 5593 f6c55bec9326
--- a/server/sources/rql2sql.py	Tue May 25 12:21:17 2010 +0200
+++ b/server/sources/rql2sql.py	Wed May 26 10:28:48 2010 +0200
@@ -44,7 +44,7 @@
 by Troels Arvin. Features SQL ISO Standard, PG, mysql, Oracle, MS SQL, DB2
 and Informix.
 
-.. _Comparison of different SQL implementations: http://www.troels.arvin.dk/db/rdbms 
+.. _Comparison of different SQL implementations: http://www.troels.arvin.dk/db/rdbms
 
 
 """
@@ -112,7 +112,7 @@
         unstable.remove(varname)
         torewrite.add(var)
         newselect = Select()
-        newselect.need_distinct = newselect.need_intersect = False
+        newselect.need_distinct = False
         myunion = Union()
         myunion.append(newselect)
         # extract aliases / selection
@@ -316,13 +316,15 @@
 # IGenerator implementation for RQL->SQL #######################################
 
 class StateInfo(object):
-    def __init__(self, existssols, unstablevars):
+    def __init__(self, select, existssols, unstablevars):
         self.existssols = existssols
         self.unstablevars = unstablevars
         self.subtables = {}
         self.needs_source_cb = None
         self.subquery_source_cb = None
         self.source_cb_funcs = set()
+        self.scopes = {select: 0}
+        self.scope_nodes = []
 
     def reset(self, solution):
         """reset some visit variables"""
@@ -381,12 +383,16 @@
         self.solution = origsol
         self.tables = origtables
 
-    def push_scope(self):
+    def push_scope(self, scope_node):
+        self.scope_nodes.append(scope_node)
+        self.scopes[scope_node] = len(self.actual_tables)
         self.actual_tables.append([])
         self._restr_stack.append(self.restrictions)
         self.restrictions = []
 
     def pop_scope(self):
+        del self.scopes[self.scope_nodes[-1]]
+        self.scope_nodes.pop()
         restrictions = self.restrictions
         self.restrictions = self._restr_stack.pop()
         return restrictions, self.actual_tables.pop()
@@ -442,7 +448,7 @@
         self._varmap = varmap
         self._query_attrs = {}
         self._state = None
-        self._not_scope_offset = 0
+        # self._not_scope_offset = 0
         try:
             # union query for each rqlst / solution
             sql = self.union_sql(union)
@@ -509,7 +515,7 @@
                     needwrap = True
         else:
             existssols, unstable = {}, ()
-        state = StateInfo(existssols, unstable)
+        state = StateInfo(select, existssols, unstable)
         if self._state is not None:
             # state from a previous unioned select
             state.merge_source_cbs(self._state.needs_source_cb)
@@ -622,12 +628,7 @@
             elif self._state.restrictions and self.dbhelper.needs_from_clause:
                 sql.insert(1, 'FROM (SELECT 1) AS _T')
             sqls.append('\n'.join(sql))
-        if select.need_intersect:
-            #if distinct or not self.dbhelper.intersect_all_support:
-            return '\nINTERSECT\n'.join(sqls)
-            #else:
-            #    return '\nINTERSECT ALL\n'.join(sqls)
-        elif distinct:
+        if distinct:
             return '\nUNION\n'.join(sqls)
         else:
             return '\nUNION ALL\n'.join(sqls)
@@ -682,32 +683,11 @@
         return ''
 
     def visit_not(self, node):
-        self._state.push_scope()
-        if isinstance(node.children[0], Relation):
-            self._not_scope_offset += 1
         csql = node.children[0].accept(self)
-        if isinstance(node.children[0], Relation):
-            self._not_scope_offset -= 1
-        sqls, tables = self._state.pop_scope()
         if node in self._state.done or not csql:
             # already processed or no sql generated by children
-            self._state.actual_tables[-1] += tables
-            self._state.restrictions += sqls
             return csql
-        if isinstance(node.children[0], Exists):
-            assert not sqls, (sqls, str(node.stmt))
-            assert not tables, (tables, str(node.stmt))
-            return 'NOT %s' % csql
-        sqls.append(csql)
-        if tables:
-            select = 'SELECT 1 FROM %s' % ','.join(tables)
-        else:
-            select = 'SELECT 1'
-        if sqls:
-            sql = 'NOT EXISTS(%s WHERE %s)' % (select, ' AND '.join(sqls))
-        else:
-            sql = 'NOT EXISTS(%s)' % select
-        return sql
+        return 'NOT (%s)' % csql
 
     def visit_exists(self, exists):
         """generate SQL name for a exists subquery"""
@@ -721,7 +701,7 @@
         return 'EXISTS(%s)' % ' UNION '.join(sqls)
 
     def _visit_exists(self, exists):
-        self._state.push_scope()
+        self._state.push_scope(exists)
         restriction = exists.children[0].accept(self)
         restrictions, tables = self._state.pop_scope()
         if restriction:
@@ -762,9 +742,6 @@
                 else:
                     # no variables in the RHS
                     sql = self._visit_attribute_relation(relation)
-                if relation.neged(strict=True):
-                    self._state.done.add(relation.parent)
-                    sql = 'NOT (%s)' % sql
         else:
             if rtype == 'is' and rhs.operator == 'IS':
                 # special case "C is NULL"
@@ -833,9 +810,6 @@
         if relation.r_type == 'identity':
             # special case "X identity Y"
             lhs, rhs = relation.get_parts()
-            if isinstance(relation.parent, Not):
-                self._state.done.add(relation.parent)
-                return 'NOT %s%s' % (lhs.accept(self), rhs.accept(self))
             return '%s%s' % (lhs.accept(self), rhs.accept(self))
         lhsvar, lhsconst, rhsvar, rhsconst = relation_info(relation)
         rid = self._relation_table(relation)
@@ -1041,7 +1015,7 @@
         else:
             not_ = False
         return self.dbhelper.fti_restriction_sql(alias, const.eval(self._args),
-                                                    jointo, not_) + restriction
+                                                 jointo, not_) + restriction
 
     def visit_comparison(self, cmp):
         """generate SQL for a comparison"""
@@ -1204,22 +1178,10 @@
         return ''
 
     def _var_info(self, var):
-        # if current var or one of its attribute is selected , it *must*
-        # appear in the toplevel's FROM even if we're currently visiting
-        # a EXISTS node
-        if var.sqlscope is var.stmt:
-            scope = 0
-        # don't consider not_scope_offset if the variable is only used in one
-        # relation
-        elif len(var.stinfo['relations']) > 1:
-            scope = -1 - self._not_scope_offset
-        else:
-            scope = -1
+        scope = self._state.scopes[var.scope]
         try:
             sql = self._varmap[var.name]
             tablealias = sql.split('.', 1)[0]
-            if scope < 0:
-                scope = self._varmap_table_scope(var.stmt, tablealias)
             self.add_table(tablealias, scope=scope)
         except KeyError:
             etype = self._state.solution[var.name]
@@ -1235,7 +1197,7 @@
     def _inlined_var_sql(self, var, rtype):
         try:
             sql = self._varmap['%s.%s' % (var.name, rtype)]
-            scope = var.sqlscope is var.stmt and 0 or -1
+            scope = self._state.scopes[var.scope]
             self.add_table(sql.split('.', 1)[0], scope=scope)
         except KeyError:
             sql = '%s.%s%s' % (self._var_table(var), SQL_PREFIX, rtype)
@@ -1358,7 +1320,7 @@
                 break
             # XXX may have a principal without being invariant for this generation,
             #     not sure this is a pb or not
-            if var.stinfo.get('principal') is relation and var.sqlscope is var.stmt:
+            if var.stinfo.get('principal') is relation and var.scope is var.stmt:
                 scope = 0
                 break
         else:
@@ -1379,15 +1341,3 @@
         alias = self.alias_and_add_table(self.dbhelper.fti_table)
         relation._q_sqltable = alias
         return alias
-
-    def _varmap_table_scope(self, select, table):
-        """since a varmap table may be used for multiple variable, its scope is
-        the most outer scope of each variables
-        """
-        scope = -1
-        for varname, alias in self._varmap.iteritems():
-            # check '.' in varname since there are 'X.attribute' keys in varmap
-            if not '.' in varname and alias.split('.', 1)[0] == table:
-                if select.defined_vars[varname].sqlscope is select:
-                    return 0
-        return scope