server/sources/rql2sql.py
changeset 7580 328542c4fdc8
parent 7493 3c46b9390871
parent 7579 5a610b34d2d2
child 7642 64eee2a83bfa
--- a/server/sources/rql2sql.py	Wed Jun 29 18:27:01 2011 +0200
+++ b/server/sources/rql2sql.py	Wed Jun 29 18:28:36 2011 +0200
@@ -247,28 +247,32 @@
                                       table + '.eid_from')
     return switchedsql.replace('__eid_from__', table + '.eid_to')
 
-def sort_term_selection(sorts, selectedidx, rqlst, groups):
+def sort_term_selection(sorts, rqlst, groups):
     # XXX beurk
     if isinstance(rqlst, list):
         def append(term):
             rqlst.append(term)
+        selectionidx = set(str(term) for term in rqlst)
     else:
         def append(term):
             rqlst.selection.append(term.copy(rqlst))
+        selectionidx = set(str(term) for term in rqlst.selection)
+
     for sortterm in sorts:
         term = sortterm.term
-        if not isinstance(term, Constant) and not str(term) in selectedidx:
-            selectedidx.append(str(term))
+        if not isinstance(term, Constant) and not str(term) in selectionidx:
+            selectionidx.add(str(term))
             append(term)
             if groups:
                 for vref in term.iget_nodes(VariableRef):
                     if not vref in groups:
                         groups.append(vref)
 
-def fix_selection_and_group(rqlst, selectedidx, needwrap, selectsortterms,
+def fix_selection_and_group(rqlst, needwrap, selectsortterms,
                             sorts, groups, having):
     if selectsortterms and sorts:
-        sort_term_selection(sorts, selectedidx, rqlst, not needwrap and groups)
+        sort_term_selection(sorts, rqlst, not needwrap and groups)
+    groupvrefs = [vref for term in groups for vref in term.iget_nodes(VariableRef)]
     if sorts and groups:
         # when a query is grouped, ensure sort terms are grouped as well
         for sortterm in sorts:
@@ -277,19 +281,22 @@
                     (isinstance(term, Function) and
                      get_func_descr(term.name).aggregat)):
                 for vref in term.iget_nodes(VariableRef):
-                    if not vref in groups:
+                    if not vref in groupvrefs:
                         groups.append(vref)
-    if needwrap:
+                        groupvrefs.append(vref)
+    if needwrap and (groups or having):
+        selectedidx = set(vref.name for term in rqlst.selection
+                          for vref in term.get_nodes(VariableRef))
         if groups:
-            for vref in groups:
-                if not vref.name in selectedidx:
-                    selectedidx.append(vref.name)
+            for vref in groupvrefs:
+                if vref.name not in selectedidx:
+                    selectedidx.add(vref.name)
                     rqlst.selection.append(vref)
         if having:
             for term in having:
                 for vref in term.iget_nodes(VariableRef):
-                    if not vref.name in selectedidx:
-                        selectedidx.append(vref.name)
+                    if vref.name not in selectedidx:
+                        selectedidx.add(vref.name)
                         rqlst.selection.append(vref)
 
 def iter_mapped_var_sels(stmt, variable):
@@ -806,23 +813,16 @@
         # treat subqueries
         self._subqueries_sql(select, state)
         # generate sql for this select node
-        selectidx = [str(term) for term in select.selection]
         if needwrap:
             outerselection = origselection[:]
             if sorts and selectsortterms:
-                outerselectidx = [str(term) for term in outerselection]
                 if distinct:
-                    sort_term_selection(sorts, outerselectidx,
-                                        outerselection, groups)
-            else:
-                outerselectidx = selectidx[:]
-        fix_selection_and_group(select, selectidx, needwrap,
-                                selectsortterms, sorts, groups, having)
+                    sort_term_selection(sorts, outerselection, groups)
+        fix_selection_and_group(select, needwrap, selectsortterms,
+                                sorts, groups, having)
         if needwrap:
-            fselectidx = outerselectidx
             fneedwrap = len(outerselection) != len(origselection)
         else:
-            fselectidx = selectidx
             fneedwrap = len(select.selection) != len(origselection)
         if fneedwrap:
             needalias = True
@@ -854,8 +854,12 @@
             # sort
             if sorts:
                 sqlsortterms = []
+                if needwrap:
+                    selectidx = [str(term) for term in outerselection]
+                else:
+                    selectidx = [str(term) for term in select.selection]
                 for sortterm in sorts:
-                    _term = self._sortterm_sql(sortterm, fselectidx)
+                    _term = self._sortterm_sql(sortterm, selectidx)
                     if _term is not None:
                         sqlsortterms.append(_term)
                 if sqlsortterms: