--- a/server/sources/rql2sql.py Tue Nov 10 15:46:34 2009 +0100
+++ b/server/sources/rql2sql.py Tue Nov 10 18:06:47 2009 +0100
@@ -252,10 +252,10 @@
self.actual_tables[-1].append(tsql)
self.outer_tables = {}
self.duplicate_switches = []
- self.attr_vars = {}
self.aliases = {}
self.restrictions = []
self._restr_stack = []
+ self.ignore_varmap = False
def add_restriction(self, restr):
if restr:
@@ -848,23 +848,23 @@
nothing to do here.
"""
contextrels = {}
- attrvars = self._state.attr_vars
for var in rhs_vars:
- try:
- contextrels[var.name] = attrvars[var.name]
- except KeyError:
- attrvars[var.name] = relation
if var.name in self._varmap:
# ensure table is added
self._var_info(var.variable)
+ principal = var.variable.stinfo.get('principal')
+ if principal is not None and principal is not relation:
+ contextrels[var.name] = relation
if not contextrels:
- relation.children[1].accept(self, contextrels)
return ''
- # at least one variable is already in attr_vars, this means we have to
- # generate unification expression
+ # we have to generate unification expression
lhssql = self._inlined_var_sql(relation.children[0].variable,
relation.r_type)
- return '%s%s' % (lhssql, relation.children[1].accept(self, contextrels))
+ try:
+ self._state.ignore_varmap = True
+ return '%s%s' % (lhssql, relation.children[1].accept(self))
+ finally:
+ self._state.ignore_varmap = False
def _visit_attribute_relation(self, rel):
"""generate SQL for an attribute relation"""
@@ -932,7 +932,7 @@
return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args),
jointo, not_) + restriction
- def visit_comparison(self, cmp, contextrels=None):
+ def visit_comparison(self, cmp):
"""generate SQL for a comparison"""
if len(cmp.children) == 2:
# XXX occurs ?
@@ -950,16 +950,15 @@
and rhs.eval(self._args) is None):
if lhs is None:
return ' IS NULL'
- return '%s IS NULL' % lhs.accept(self, contextrels)
+ return '%s IS NULL' % lhs.accept(self)
elif isinstance(rhs, Function) and rhs.name == 'IN':
assert operator == '='
operator = ' '
if lhs is None:
- return '%s%s'% (operator, rhs.accept(self, contextrels))
- return '%s%s%s'% (lhs.accept(self, contextrels), operator,
- rhs.accept(self, contextrels))
+ return '%s%s'% (operator, rhs.accept(self))
+ return '%s%s%s'% (lhs.accept(self), operator, rhs.accept(self))
- def visit_mathexpression(self, mexpr, contextrels=None):
+ def visit_mathexpression(self, mexpr):
"""generate SQL for a mathematic expression"""
lhs, rhs = mexpr.get_parts()
# check for string concatenation
@@ -969,17 +968,16 @@
operator = '||'
except CoercionError:
pass
- return '(%s %s %s)'% (lhs.accept(self, contextrels), operator,
- rhs.accept(self, contextrels))
+ return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
- def visit_function(self, func, contextrels=None):
+ def visit_function(self, func):
"""generate SQL name for a function"""
# function_description will check function is supported by the backend
sqlname = self.dbms_helper.func_sqlname(func.name)
- return '%s(%s)' % (sqlname, ', '.join(c.accept(self, contextrels)
+ return '%s(%s)' % (sqlname, ', '.join(c.accept(self)
for c in func.children))
- def visit_constant(self, constant, contextrels=None):
+ def visit_constant(self, constant):
"""generate SQL name for a constant"""
value = constant.value
if constant.type is None:
@@ -1004,12 +1002,12 @@
self._query_attrs[_id] = value
return '%%(%s)s' % _id
- def visit_variableref(self, variableref, contextrels=None):
+ def visit_variableref(self, variableref):
"""get the sql name for a variable reference"""
# use accept, .variable may be a variable or a columnalias
- return variableref.variable.accept(self, contextrels)
+ return variableref.variable.accept(self)
- def visit_columnalias(self, colalias, contextrels=None):
+ def visit_columnalias(self, colalias):
"""get the sql name for a subquery column alias"""
if colalias.name in self._varmap:
sql = self._varmap[colalias.name]
@@ -1020,20 +1018,21 @@
return sql
return colalias._q_sql
- def visit_variable(self, variable, contextrels=None):
+ def visit_variable(self, variable):
"""get the table name and sql string for a variable"""
- if contextrels is None and variable.name in self._state.done:
+ #if contextrels is None and variable.name in self._state.done:
+ if variable.name in self._state.done:
if self._in_wrapping_query:
return 'T1.%s' % self._state.aliases[variable.name]
return variable._q_sql
self._state.done.add(variable.name)
vtablename = None
- if contextrels is None and variable.name in self._varmap:
+ if not self._state.ignore_varmap and variable.name in self._varmap:
sql, vtablename = self._var_info(variable)
elif variable.stinfo['attrvar']:
# attribute variable (systematically used in rhs of final
# relation(s)), get table name and sql from any rhs relation
- sql = self._linked_var_sql(variable, contextrels)
+ sql = self._linked_var_sql(variable)
elif variable._q_invariant:
# since variable is invariant, we know we won't found final relation
principal = variable.stinfo['principal']
@@ -1056,7 +1055,7 @@
self.dbms_helper.fti_uid_attr)
elif principal in variable.stinfo['rhsrelations']:
if self.schema.rschema(principal.r_type).inlined:
- sql = self._linked_var_sql(variable, contextrels)
+ sql = self._linked_var_sql(variable)
else:
sql = '%s.eid_to' % self._relation_table(principal)
else:
@@ -1076,8 +1075,14 @@
# generate extra join
try:
if not var.stinfo['principal'] is relation:
- # need a predicable result for tests
- return '%s=%s' % tuple(sorted((sql, var.accept(self))))
+ op = relation.operator()
+ if op == '=':
+ # need a predicable result for tests
+ args = sorted( (sql, var.accept(self)) )
+ args.insert(1, op)
+ else:
+ args = (sql, op, var.accept(self))
+ return '%s%s%s' % tuple(args)
except KeyError:
# no principal defined, relation is necessarily the principal and
# so nothing to return here
@@ -1123,14 +1128,13 @@
#self._state.done.add(var.name)
return sql
- def _linked_var_sql(self, variable, contextrels=None):
- if contextrels is None:
+ def _linked_var_sql(self, variable):
+ if not self._state.ignore_varmap:
try:
return self._varmap[variable.name]
except KeyError:
pass
- rel = (contextrels and contextrels.get(variable.name) or
- variable.stinfo.get('principal') or
+ rel = (variable.stinfo.get('principal') or
iter(variable.stinfo['rhsrelations']).next())
linkedvar = rel.children[0].variable
if rel.r_type == 'eid':