--- a/server/sources/rql2sql.py Wed Jun 29 18:27:01 2011 +0200
+++ b/server/sources/rql2sql.py Wed Jun 29 18:28:36 2011 +0200
@@ -247,28 +247,32 @@
table + '.eid_from')
return switchedsql.replace('__eid_from__', table + '.eid_to')
-def sort_term_selection(sorts, selectedidx, rqlst, groups):
+def sort_term_selection(sorts, rqlst, groups):
# XXX beurk
if isinstance(rqlst, list):
def append(term):
rqlst.append(term)
+ selectionidx = set(str(term) for term in rqlst)
else:
def append(term):
rqlst.selection.append(term.copy(rqlst))
+ selectionidx = set(str(term) for term in rqlst.selection)
+
for sortterm in sorts:
term = sortterm.term
- if not isinstance(term, Constant) and not str(term) in selectedidx:
- selectedidx.append(str(term))
+ if not isinstance(term, Constant) and not str(term) in selectionidx:
+ selectionidx.add(str(term))
append(term)
if groups:
for vref in term.iget_nodes(VariableRef):
if not vref in groups:
groups.append(vref)
-def fix_selection_and_group(rqlst, selectedidx, needwrap, selectsortterms,
+def fix_selection_and_group(rqlst, needwrap, selectsortterms,
sorts, groups, having):
if selectsortterms and sorts:
- sort_term_selection(sorts, selectedidx, rqlst, not needwrap and groups)
+ sort_term_selection(sorts, rqlst, not needwrap and groups)
+ groupvrefs = [vref for term in groups for vref in term.iget_nodes(VariableRef)]
if sorts and groups:
# when a query is grouped, ensure sort terms are grouped as well
for sortterm in sorts:
@@ -277,19 +281,22 @@
(isinstance(term, Function) and
get_func_descr(term.name).aggregat)):
for vref in term.iget_nodes(VariableRef):
- if not vref in groups:
+ if not vref in groupvrefs:
groups.append(vref)
- if needwrap:
+ groupvrefs.append(vref)
+ if needwrap and (groups or having):
+ selectedidx = set(vref.name for term in rqlst.selection
+ for vref in term.get_nodes(VariableRef))
if groups:
- for vref in groups:
- if not vref.name in selectedidx:
- selectedidx.append(vref.name)
+ for vref in groupvrefs:
+ if vref.name not in selectedidx:
+ selectedidx.add(vref.name)
rqlst.selection.append(vref)
if having:
for term in having:
for vref in term.iget_nodes(VariableRef):
- if not vref.name in selectedidx:
- selectedidx.append(vref.name)
+ if vref.name not in selectedidx:
+ selectedidx.add(vref.name)
rqlst.selection.append(vref)
def iter_mapped_var_sels(stmt, variable):
@@ -806,23 +813,16 @@
# treat subqueries
self._subqueries_sql(select, state)
# generate sql for this select node
- selectidx = [str(term) for term in select.selection]
if needwrap:
outerselection = origselection[:]
if sorts and selectsortterms:
- outerselectidx = [str(term) for term in outerselection]
if distinct:
- sort_term_selection(sorts, outerselectidx,
- outerselection, groups)
- else:
- outerselectidx = selectidx[:]
- fix_selection_and_group(select, selectidx, needwrap,
- selectsortterms, sorts, groups, having)
+ sort_term_selection(sorts, outerselection, groups)
+ fix_selection_and_group(select, needwrap, selectsortterms,
+ sorts, groups, having)
if needwrap:
- fselectidx = outerselectidx
fneedwrap = len(outerselection) != len(origselection)
else:
- fselectidx = selectidx
fneedwrap = len(select.selection) != len(origselection)
if fneedwrap:
needalias = True
@@ -854,8 +854,12 @@
# sort
if sorts:
sqlsortterms = []
+ if needwrap:
+ selectidx = [str(term) for term in outerselection]
+ else:
+ selectidx = [str(term) for term in select.selection]
for sortterm in sorts:
- _term = self._sortterm_sql(sortterm, fselectidx)
+ _term = self._sortterm_sql(sortterm, selectidx)
if _term is not None:
sqlsortterms.append(_term)
if sqlsortterms: