test and fix http://www.logilab.org/ticket/499838 stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Tue, 10 Nov 2009 18:06:47 +0100
branchstable
changeset 3815 50b87f759b5d
parent 3814 a4659adf4eee
child 3817 9fcf048e14b7
test and fix http://www.logilab.org/ticket/499838 refactor nicely on the way
server/msplanner.py
server/rqlannotation.py
server/sources/rql2sql.py
server/test/unittest_msplanner.py
server/test/unittest_multisources.py
server/test/unittest_rql2sql.py
--- a/server/msplanner.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/msplanner.py	Tue Nov 10 18:06:47 2009 +0100
@@ -1470,11 +1470,19 @@
     def visit_constant(self, node, newroot, terms):
         return copy_node(newroot, node), node
 
+    def visit_comparison(self, node, newroot, terms):
+        subparts, node = self._visit_children(node, newroot, terms)
+        copy = copy_node(newroot, node, subparts)
+        # ignore comparison operator when fetching non final query
+        if not self.final and isinstance(node.children[0], VariableRef):
+            copy.operator = '='
+        return copy, node
+
     def visit_default(self, node, newroot, terms):
         subparts, node = self._visit_children(node, newroot, terms)
         return copy_node(newroot, node, subparts), node
 
-    visit_comparison = visit_mathexpression = visit_constant = visit_function = visit_default
+    visit_mathexpression = visit_constant = visit_function = visit_default
     visit_sort = visit_sortterm = visit_default
 
     def _visit_children(self, node, newroot, terms):
--- a/server/rqlannotation.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/rqlannotation.py	Tue Nov 10 18:06:47 2009 +0100
@@ -10,6 +10,7 @@
 
 from logilab.common.compat import any
 
+from rql import BadRQLQuery
 from rql.nodes import Relation, VariableRef, Constant, Variable, Or
 from rql.utils import common_parent
 
@@ -177,10 +178,20 @@
     """given a list of rqlst relations, select one which will be used as main
     relation for the rhs variable
     """
-    for rel in relations:
+    principal = None
+    # sort for test predictability
+    for rel in sorted(relations, key=lambda x: (x.children[0].name, x.r_type)):
+        # only equality relation with a variable as rhs may be principal
+        if rel.operator() not in ('=', 'IS') \
+               or not isinstance(rel.children[1].children[0], VariableRef):
+            continue
         if rel.sqlscope is rel.stmt:
             return rel
         principal = rel
+    if principal is None:
+        print iter(relations).next().root
+        raise BadRQLQuery('unable to find principal in %s' % ', '.join(
+            r.as_string() for r in relations))
     return principal
 
 
--- a/server/sources/rql2sql.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/sources/rql2sql.py	Tue Nov 10 18:06:47 2009 +0100
@@ -252,10 +252,10 @@
             self.actual_tables[-1].append(tsql)
         self.outer_tables = {}
         self.duplicate_switches = []
-        self.attr_vars = {}
         self.aliases = {}
         self.restrictions = []
         self._restr_stack = []
+        self.ignore_varmap = False
 
     def add_restriction(self, restr):
         if restr:
@@ -848,23 +848,23 @@
         nothing to do here.
         """
         contextrels = {}
-        attrvars = self._state.attr_vars
         for var in rhs_vars:
-            try:
-                contextrels[var.name] = attrvars[var.name]
-            except KeyError:
-                attrvars[var.name] = relation
             if var.name in self._varmap:
                 # ensure table is added
                 self._var_info(var.variable)
+            principal = var.variable.stinfo.get('principal')
+            if principal is not None and principal is not relation:
+                contextrels[var.name] = relation
         if not contextrels:
-            relation.children[1].accept(self, contextrels)
             return ''
-        # at least one variable is already in attr_vars, this means we have to
-        # generate unification expression
+        # we have to generate unification expression
         lhssql = self._inlined_var_sql(relation.children[0].variable,
                                        relation.r_type)
-        return '%s%s' % (lhssql, relation.children[1].accept(self, contextrels))
+        try:
+            self._state.ignore_varmap = True
+            return '%s%s' % (lhssql, relation.children[1].accept(self))
+        finally:
+            self._state.ignore_varmap = False
 
     def _visit_attribute_relation(self, rel):
         """generate SQL for an attribute relation"""
@@ -932,7 +932,7 @@
         return self.dbms_helper.fti_restriction_sql(alias, const.eval(self._args),
                                                     jointo, not_) + restriction
 
-    def visit_comparison(self, cmp, contextrels=None):
+    def visit_comparison(self, cmp):
         """generate SQL for a comparison"""
         if len(cmp.children) == 2:
             # XXX occurs ?
@@ -950,16 +950,15 @@
               and rhs.eval(self._args) is None):
             if lhs is None:
                 return ' IS NULL'
-            return '%s IS NULL' % lhs.accept(self, contextrels)
+            return '%s IS NULL' % lhs.accept(self)
         elif isinstance(rhs, Function) and rhs.name == 'IN':
             assert operator == '='
             operator = ' '
         if lhs is None:
-            return '%s%s'% (operator, rhs.accept(self, contextrels))
-        return '%s%s%s'% (lhs.accept(self, contextrels), operator,
-                          rhs.accept(self, contextrels))
+            return '%s%s'% (operator, rhs.accept(self))
+        return '%s%s%s'% (lhs.accept(self), operator, rhs.accept(self))
 
-    def visit_mathexpression(self, mexpr, contextrels=None):
+    def visit_mathexpression(self, mexpr):
         """generate SQL for a mathematic expression"""
         lhs, rhs = mexpr.get_parts()
         # check for string concatenation
@@ -969,17 +968,16 @@
                 operator = '||'
         except CoercionError:
             pass
-        return '(%s %s %s)'% (lhs.accept(self, contextrels), operator,
-                              rhs.accept(self, contextrels))
+        return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
 
-    def visit_function(self, func, contextrels=None):
+    def visit_function(self, func):
         """generate SQL name for a function"""
         # function_description will check function is supported by the backend
         sqlname = self.dbms_helper.func_sqlname(func.name)
-        return '%s(%s)' % (sqlname, ', '.join(c.accept(self, contextrels)
+        return '%s(%s)' % (sqlname, ', '.join(c.accept(self)
                                               for c in func.children))
 
-    def visit_constant(self, constant, contextrels=None):
+    def visit_constant(self, constant):
         """generate SQL name for a constant"""
         value = constant.value
         if constant.type is None:
@@ -1004,12 +1002,12 @@
             self._query_attrs[_id] = value
         return '%%(%s)s' % _id
 
-    def visit_variableref(self, variableref, contextrels=None):
+    def visit_variableref(self, variableref):
         """get the sql name for a variable reference"""
         # use accept, .variable may be a variable or a columnalias
-        return variableref.variable.accept(self, contextrels)
+        return variableref.variable.accept(self)
 
-    def visit_columnalias(self, colalias, contextrels=None):
+    def visit_columnalias(self, colalias):
         """get the sql name for a subquery column alias"""
         if colalias.name in self._varmap:
             sql = self._varmap[colalias.name]
@@ -1020,20 +1018,21 @@
             return sql
         return colalias._q_sql
 
-    def visit_variable(self, variable, contextrels=None):
+    def visit_variable(self, variable):
         """get the table name and sql string for a variable"""
-        if contextrels is None and variable.name in self._state.done:
+        #if contextrels is None and variable.name in self._state.done:
+        if variable.name in self._state.done:
             if self._in_wrapping_query:
                 return 'T1.%s' % self._state.aliases[variable.name]
             return variable._q_sql
         self._state.done.add(variable.name)
         vtablename = None
-        if contextrels is None and variable.name in self._varmap:
+        if not self._state.ignore_varmap and variable.name in self._varmap:
             sql, vtablename = self._var_info(variable)
         elif variable.stinfo['attrvar']:
             # attribute variable (systematically used in rhs of final
             # relation(s)), get table name and sql from any rhs relation
-            sql = self._linked_var_sql(variable, contextrels)
+            sql = self._linked_var_sql(variable)
         elif variable._q_invariant:
             # since variable is invariant, we know we won't found final relation
             principal = variable.stinfo['principal']
@@ -1056,7 +1055,7 @@
                                  self.dbms_helper.fti_uid_attr)
             elif principal in variable.stinfo['rhsrelations']:
                 if self.schema.rschema(principal.r_type).inlined:
-                    sql = self._linked_var_sql(variable, contextrels)
+                    sql = self._linked_var_sql(variable)
                 else:
                     sql = '%s.eid_to' % self._relation_table(principal)
             else:
@@ -1076,8 +1075,14 @@
         # generate extra join
         try:
             if not var.stinfo['principal'] is relation:
-                # need a predicable result for tests
-                return '%s=%s' % tuple(sorted((sql, var.accept(self))))
+                op = relation.operator()
+                if op == '=':
+                    # need a predicable result for tests
+                    args = sorted( (sql, var.accept(self)) )
+                    args.insert(1, op)
+                else:
+                    args = (sql, op, var.accept(self))
+                return '%s%s%s' % tuple(args)
         except KeyError:
             # no principal defined, relation is necessarily the principal and
             # so nothing to return here
@@ -1123,14 +1128,13 @@
             #self._state.done.add(var.name)
         return sql
 
-    def _linked_var_sql(self, variable, contextrels=None):
-        if contextrels is None:
+    def _linked_var_sql(self, variable):
+        if not self._state.ignore_varmap:
             try:
                 return self._varmap[variable.name]
             except KeyError:
                 pass
-        rel = (contextrels and contextrels.get(variable.name) or
-               variable.stinfo.get('principal') or
+        rel = (variable.stinfo.get('principal') or
                iter(variable.stinfo['rhsrelations']).next())
         linkedvar = rel.children[0].variable
         if rel.r_type == 'eid':
--- a/server/test/unittest_msplanner.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/test/unittest_msplanner.py	Tue Nov 10 18:06:47 2009 +0100
@@ -706,7 +706,7 @@
         self._test('Any V, MAX(VR) WHERE V is Card, V creation_date VR, '
                    '(V creation_date TODAY OR (V creation_date < TODAY AND NOT EXISTS('
                    'X is Card, X creation_date < TODAY, X creation_date >= VR)))',
-                   [('FetchStep', [('Any VR WHERE X creation_date < TODAY, X creation_date >= VR, X is Card',
+                   [('FetchStep', [('Any VR WHERE X creation_date < TODAY, X creation_date VR, X is Card',
                                     [{'X': 'Card', 'VR': 'Datetime'}])],
                      [self.cards, self.system], None,
                      {'VR': 'table0.C0', 'X.creation_date': 'table0.C0'}, []),
@@ -1349,7 +1349,7 @@
     def test_attr_unification_neq_1(self):
         self._test('Any X,Y WHERE X is Bookmark, Y is Card, X creation_date D, Y creation_date > D',
                    [('FetchStep',
-                     [('Any Y,D WHERE Y creation_date > D, Y is Card',
+                     [('Any Y,D WHERE Y creation_date D, Y is Card',
                        [{'D': 'Datetime', 'Y': 'Card'}])],
                      [self.cards,self.system], None,
                      {'D': 'table0.C1', 'Y': 'table0.C0', 'Y.creation_date': 'table0.C1'}, []),
--- a/server/test/unittest_multisources.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/test/unittest_multisources.py	Tue Nov 10 18:06:47 2009 +0100
@@ -184,7 +184,9 @@
     def test_attr_unification_1(self):
         n1 = self.execute('INSERT Note X: X type "AFFREF"')[0][0]
         n2 = self.execute('INSERT Note X: X type "AFFREU"')[0][0]
+        self.set_debug('DBG_SQL|DBG_MS')
         rset = self.execute('Any X,Y WHERE X is Note, Y is Affaire, X type T, Y ref T')
+        self.set_debug(None)
         self.assertEquals(len(rset), 1, rset.rows)
 
     def test_attr_unification_2(self):
--- a/server/test/unittest_rql2sql.py	Tue Nov 10 15:46:34 2009 +0100
+++ b/server/test/unittest_rql2sql.py	Tue Nov 10 18:06:47 2009 +0100
@@ -164,6 +164,9 @@
      '''SELECT _X.cw_eid
 FROM cw_Personne AS _X
 WHERE _X.cw_prenom=lulu AND NOT EXISTS(SELECT 1 FROM owned_by_relation AS rel_owned_by0, in_group_relation AS rel_in_group1, cw_CWGroup AS _G WHERE rel_owned_by0.eid_from=_X.cw_eid AND rel_in_group1.eid_from=rel_owned_by0.eid_to AND rel_in_group1.eid_to=_G.cw_eid AND ((_G.cw_name=lulufanclub) OR (_G.cw_name=managers)))'''),
+
+
+
 ]
 
 ADVANCED= [
@@ -867,10 +870,21 @@
     ]
 
 VIRTUAL_VARS = [
-    ("Personne P WHERE P travaille S, S tel T, S fax T, S is Societe;",
+
+    ('Any X WHERE X is CWUser, X creation_date > D1, Y creation_date D1, Y login "SWEB09"',
+     '''SELECT _X.cw_eid
+FROM cw_CWUser AS _X, cw_CWUser AS _Y
+WHERE _X.cw_creation_date>_Y.cw_creation_date AND _Y.cw_login=SWEB09'''),
+
+    ('Any X WHERE X is CWUser, Y creation_date D1, Y login "SWEB09", X creation_date > D1',
+     '''SELECT _X.cw_eid
+FROM cw_CWUser AS _X, cw_CWUser AS _Y
+WHERE _Y.cw_login=SWEB09 AND _X.cw_creation_date>_Y.cw_creation_date'''),
+
+    ('Personne P WHERE P travaille S, S tel T, S fax T, S is Societe',
      '''SELECT rel_travaille0.eid_from
 FROM cw_Societe AS _S, travaille_relation AS rel_travaille0
-WHERE rel_travaille0.eid_to=_S.cw_eid AND _S.cw_fax=_S.cw_tel'''),
+WHERE rel_travaille0.eid_to=_S.cw_eid AND _S.cw_tel=_S.cw_fax'''),
 
     ("Personne P where X eid 0, X creation_date D, P datenaiss < D, X is Affaire",
      '''SELECT _P.cw_eid