server/sources/rql2sql.py
changeset 5811 e77cea9721e7
parent 5768 1e73a466aa69
parent 5793 1faff41593df
child 5821 656c974961c4
--- a/server/sources/rql2sql.py	Fri Jun 18 18:31:22 2010 +0200
+++ b/server/sources/rql2sql.py	Mon Jun 21 13:23:11 2010 +0200
@@ -45,9 +45,8 @@
 and Informix.
 
 .. _Comparison of different SQL implementations: http://www.troels.arvin.dk/db/rdbms
-
+"""
 
-"""
 __docformat__ = "restructuredtext en"
 
 import threading
@@ -56,8 +55,8 @@
 
 from rql import BadRQLQuery, CoercionError
 from rql.stmts import Union, Select
-from rql.nodes import (SortTerm, VariableRef, Constant, Function, Not,
-                       Variable, ColumnAlias, Relation, SubQuery, Exists)
+from rql.nodes import (SortTerm, VariableRef, Constant, Function, Variable, Or,
+                       Not, Comparison, ColumnAlias, Relation, SubQuery, Exists)
 
 from cubicweb import QueryError
 from cubicweb.server.sqlutils import SQL_PREFIX
@@ -397,6 +396,49 @@
         self.restrictions = self._restr_stack.pop()
         return restrictions, self.actual_tables.pop()
 
+def extract_fake_having_terms(having):
+    """RQL's HAVING may be used to contains stuff that should go in the WHERE
+    clause of the SQL query, due to RQL grammar limitation. Split them...
+
+    Return a list nodes that can be ANDed with query's WHERE clause. Having
+    subtrees updated in place.
+    """
+    fakehaving = []
+    for subtree in having:
+        ors, tocheck = set(), []
+        for compnode in subtree.get_nodes(Comparison):
+            for fnode in compnode.get_nodes(Function):
+                if fnode.descr().aggregat:
+                    p = compnode.parent
+                    oor = None
+                    while not isinstance(p, Select):
+                        if isinstance(p, Or):
+                            oor = p
+                        p = p.parent
+                    if oor is not None:
+                        ors.add(oor)
+                    break
+            else:
+                tocheck.append(compnode)
+        # tocheck hold a set of comparison not implying an aggregat function
+        # put them in fakehaving if the don't share an Or node as ancestor
+        # with another comparison containing an aggregat function
+        for compnode in tocheck:
+            parents = set()
+            p = compnode.parent
+            oor = None
+            while not isinstance(p, Select):
+                if p in ors:
+                    break
+                if isinstance(p, Or):
+                    oor = p
+                p = p.parent
+            else:
+                node = oor or compnode
+                if not node in fakehaving:
+                    fakehaving.append(node)
+                    compnode.parent.remove(node)
+    return fakehaving
 
 class SQLGenerator(object):
     """
@@ -494,6 +536,7 @@
         sorts = select.orderby
         groups = select.groupby
         having = select.having
+        morerestr = extract_fake_having_terms(having)
         # remember selection, it may be changed and have to be restored
         origselection = select.selection[:]
         # check if the query will have union subquery, if it need sort term
@@ -545,7 +588,8 @@
         self._in_wrapping_query = False
         self._state = state
         try:
-            sql = self._solutions_sql(select, sols, distinct, needalias or needwrap)
+            sql = self._solutions_sql(select, morerestr, sols, distinct,
+                                      needalias or needwrap)
             # generate groups / having before wrapping query selection to
             # get correct column aliases
             self._in_wrapping_query = needwrap
@@ -610,13 +654,15 @@
                 except KeyError:
                     continue
 
-    def _solutions_sql(self, select, solutions, distinct, needalias):
+    def _solutions_sql(self, select, morerestr, solutions, distinct, needalias):
         sqls = []
         for solution in solutions:
             self._state.reset(solution)
             # visit restriction subtree
             if select.where is not None:
                 self._state.add_restriction(select.where.accept(self))
+            for restriction in morerestr:
+                self._state.add_restriction(restriction.accept(self))
             sql = [self._selection_sql(select.selection, distinct, needalias)]
             if self._state.restrictions:
                 sql.append('WHERE %s' % ' AND '.join(self._state.restrictions))
@@ -1055,7 +1101,8 @@
         operator = mexpr.operator
         try:
             if mexpr.operator == '+' and mexpr.get_type(self._state.solution, self._args) == 'String':
-                operator = '||'
+                return '(%s)' % self.dbhelper.sql_concat_string(lhs.accept(self),
+                                                                rhs.accept(self))
         except CoercionError:
             pass
         return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))