[server] some pep8 in rql2sql
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 06 Jun 2016 21:17:33 +0200
changeset 11347 b4dcfd734686
parent 11346 69c17d011f74
child 11348 70337ad23145
[server] some pep8 in rql2sql
cubicweb/server/sources/rql2sql.py
--- a/cubicweb/server/sources/rql2sql.py	Tue Jun 21 17:51:11 2016 +0200
+++ b/cubicweb/server/sources/rql2sql.py	Mon Jun 06 21:17:33 2016 +0200
@@ -1,4 +1,4 @@
-# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# copyright 2003-2016 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
 #
 # This file is part of CubicWeb.
@@ -66,24 +66,27 @@
 from cubicweb.rqlrewrite import cleanup_solutions
 from cubicweb.server.sqlutils import SQL_PREFIX
 
-ColumnAlias._q_invariant = False # avoid to check for ColumnAlias / Variable
+get_func_descr = SQL_FUNCTIONS_REGISTRY.get_function
 
-FunctionDescr.source_execute = None
+ColumnAlias._q_invariant = False  # avoid to check for ColumnAlias / Variable
+
 
 def default_update_cb_stack(self, stack):
     stack.append(self.source_execute)
 FunctionDescr.update_cb_stack = default_update_cb_stack
+FunctionDescr.source_execute = None
 
-get_func_descr = SQL_FUNCTIONS_REGISTRY.get_function
+
+def length_source_execute(source, session, value):
+    return len(value.getvalue())
 
 LENGTH = get_func_descr('LENGTH')
-def length_source_execute(source, session, value):
-    return len(value.getvalue())
 LENGTH.source_execute = length_source_execute
 
+
 def _new_var(select, varname):
     newvar = select.get_variable(varname)
-    if not 'relations' in newvar.stinfo:
+    if 'relations' not in newvar.stinfo:
         # not yet initialized
         newvar.prepare_annotation()
         newvar.stinfo['scope'] = select
@@ -91,20 +94,22 @@
         select.selection.append(VariableRef(newvar))
     return newvar
 
+
 def _fill_to_wrap_rel(var, newselect, towrap, schema):
     for rel in var.stinfo['relations'] - var.stinfo['rhsrelations']:
         rschema = schema.rschema(rel.r_type)
         if rschema.inlined:
-            towrap.add( (var, rel) )
+            towrap.add((var, rel))
             for vref in rel.children[1].iget_nodes(VariableRef):
                 newivar = _new_var(newselect, vref.name)
                 _fill_to_wrap_rel(vref.variable, newselect, towrap, schema)
         elif rschema.final:
-            towrap.add( (var, rel) )
+            towrap.add((var, rel))
             for vref in rel.children[1].iget_nodes(VariableRef):
                 newivar = _new_var(newselect, vref.name)
                 newivar.stinfo['attrvar'] = (var, rel.r_type)
 
+
 def rewrite_unstable_outer_join(select, solutions, unstable, schema):
     """if some optional variables are unstable, they should be selected in a
     subquery. This function check this and rewrite the rql syntax tree if
@@ -156,6 +161,7 @@
         select.add_subquery(SubQuery(aliases, myunion), check=False)
     return modified
 
+
 def _new_solutions(rqlst, solutions):
     """first filter out subqueries variables from solutions"""
     newsolutions = []
@@ -163,10 +169,11 @@
         asol = {}
         for vname in rqlst.defined_vars:
             asol[vname] = origsol[vname]
-        if not asol in newsolutions:
+        if asol not in newsolutions:
             newsolutions.append(asol)
     return newsolutions
 
+
 def remove_unused_solutions(rqlst, solutions, schema):
     """cleanup solutions: remove solutions where invariant variables are taking
     different types
@@ -191,7 +198,7 @@
                 thisexistssols = [newsols[0].copy()]
                 thisexistsvars = set()
                 existssols[var.scope] = thisexistssols, thisexistsvars
-            for i in range(len(newsols)-1, 0, -1):
+            for i in range(len(newsols) - 1, 0, -1):
                 if vtype != newsols[i][vname]:
                     thisexistssols.append(newsols.pop(i))
                     thisexistsvars.add(vname)
@@ -200,8 +207,8 @@
             for i in range(1, len(newsols)):
                 if vtype != newsols[i][vname]:
                     unstable.add(vname)
-    # remove unstable variables from exists solutions: the possible types of these variables are not
-    # properly represented in exists solutions, so we have to remove and reinject them later
+    # remove unstable variables from exists solutions: the possible types of these variables are
+    # not properly represented in exists solutions, so we have to remove and reinject them later
     # according to the outer solution (see `iter_exists_sols`)
     for sols, _ in existssols.values():
         for vname in unstable:
@@ -211,7 +218,7 @@
         # filter out duplicates
         newsols_ = []
         for sol in newsols:
-            if not sol in newsols_:
+            if sol not in newsols_:
                 newsols_.append(sol)
         newsols = newsols_
         # reinsert solutions for invariants
@@ -230,6 +237,7 @@
             newsols = _new_solutions(rqlst, newsols)
     return newsols, existssols, unstable
 
+
 def relation_info(relation):
     lhs, rhs = relation.get_variable_parts()
     try:
@@ -239,7 +247,7 @@
         lhsconst = lhs
         lhs = None
     except KeyError:
-        lhsconst = None # ColumnAlias
+        lhsconst = None  # ColumnAlias
     try:
         rhs = rhs.variable
         rhsconst = rhs.stinfo['constnode']
@@ -247,9 +255,10 @@
         rhsconst = rhs
         rhs = None
     except KeyError:
-        rhsconst = None # ColumnAlias
+        rhsconst = None  # ColumnAlias
     return lhs, lhsconst, rhs, rhsconst
 
+
 def sort_term_selection(sorts, rqlst, groups):
     # XXX beurk
     if isinstance(rqlst, list):
@@ -271,6 +280,7 @@
                     if not any(vref.is_equivalent(g) for g in groups):
                         groups.append(vref)
 
+
 def fix_selection_and_group(rqlst, needwrap, selectsortterms,
                             sorts, groups, having):
     if selectsortterms and sorts:
@@ -280,7 +290,7 @@
         # when a query is grouped, ensure sort terms are grouped as well
         for sortterm in sorts:
             term = sortterm.term
-            if not (isinstance(term, Constant) or \
+            if not (isinstance(term, Constant) or
                     (isinstance(term, Function) and
                      get_func_descr(term.name).aggregat)):
                 for vref in term.iget_nodes(VariableRef):
@@ -302,10 +312,11 @@
                         selectedidx.add(vref.name)
                         rqlst.selection.append(vref)
 
+
 def iter_mapped_var_sels(stmt, variable):
     # variable is a Variable or ColumnAlias node mapped to a source side
     # callback
-    if not (len(variable.stinfo['rhsrelations']) <= 1 and # < 1 on column alias
+    if not (len(variable.stinfo['rhsrelations']) <= 1 and  # < 1 on column alias
             variable.stinfo['selected']):
         raise QueryError("can't use %s as a restriction variable"
                          % variable.name)
@@ -315,6 +326,7 @@
             raise QueryError()
         yield selectidx, vrefs[0]
 
+
 def update_source_cb_stack(state, stmt, node, stack):
     while True:
         node = node.parent
@@ -401,7 +413,7 @@
             self.restrictions.append(restr)
 
     def iter_exists_sols(self, exists):
-        if not exists in self.existssols:
+        if exists not in self.existssols:
             yield 1
             return
         thisexistssols, thisexistsvars = self.existssols[exists]
@@ -552,7 +564,7 @@
             else:
                 self.outer_tables[tablealias][1].extend(pending_conditions)
         else:
-            assert not tablealias in self.outer_pending
+            assert tablealias not in self.outer_pending
 
     def add_outer_join_condition(self, tablealias, condition):
         try:
@@ -568,9 +580,9 @@
         assert leftalias != rightalias, leftalias
         outer_tables = self.outer_tables
         louter, lconditions, lchain = outer_tables.get(leftalias,
-                                                      (None, None, None))
+                                                       (None, None, None))
         router, rconditions, rchain = outer_tables.get(rightalias,
-                                                      (None, None, None))
+                                                       (None, None, None))
         if lchain is None and rchain is None:
             # create a new outer chaine
             chain = [leftalias, rightalias]
@@ -603,9 +615,9 @@
             # the condition if it's ok
             lidx = lchain.index(leftalias)
             ridx = lchain.index(rightalias)
-            if (outertype == 'FULL' and router != 'FULL') \
-                   or (lidx < ridx and router != 'LEFT') \
-                   or (ridx < lidx and louter != 'RIGHT'):
+            if ((outertype == 'FULL' and router != 'FULL')
+                    or (lidx < ridx and router != 'LEFT')
+                    or (ridx < lidx and louter != 'RIGHT')):
                 raise BadRQLQuery()
             # merge conditions
             if lidx < ridx:
@@ -685,11 +697,10 @@
         # put them in fakehaving if they don't share an Or node as ancestor
         # with another comparison containing an aggregat function
         for compnode in tocheck:
-            parents = set()
             p = compnode.parent
             oor = None
             while not isinstance(p, Select):
-                if p in ors or p is None: # p is None for nodes already in fakehaving
+                if p in ors or p is None:  # p is None for nodes already in fakehaving
                     break
                 if isinstance(p, (Or, Not)):
                     oor = p
@@ -719,9 +730,10 @@
         self.schema = schema
         self.dbhelper = dbhelper
         self.dbencoding = dbhelper.dbencoding
-        self.keyword_map = {'NOW' : self.dbhelper.sql_current_timestamp,
-                            'TODAY': self.dbhelper.sql_current_date,
-                            }
+        self.keyword_map = {
+            'NOW': self.dbhelper.sql_current_timestamp,
+            'TODAY': self.dbhelper.sql_current_date,
+        }
         if not self.dbhelper.union_parentheses_support:
             self.union_sql = self.noparen_union_sql
         self._lock = threading.Lock()
@@ -752,7 +764,7 @@
         finally:
             self._lock.release()
 
-    def union_sql(self, union, needalias=False): # pylint: disable=E0202
+    def union_sql(self, union, needalias=False):  # pylint: disable=E0202
         if len(union.children) == 1:
             return self.select_sql(union.children[0], needalias)
         sqls = ('(%s)' % self.select_sql(select, needalias)
@@ -793,8 +805,8 @@
             for vref in restr.get_nodes(VariableRef):
                 vscope = vref.variable.scope
                 if vscope is select:
-                    continue # ignore select scope, so restriction is added to
-                             # the inner most scope possible
+                    # ignore select scope, so restriction is added to the innermost possible scope
+                    continue
                 if scope is None:
                     scope = vscope
                 elif vscope is not scope:
@@ -863,7 +875,7 @@
                                       if not isinstance(term, Constant))
             if needwrap:
                 sql = '%s FROM (%s) AS T1' % (
-                    self._selection_sql(outerselection, distinct,needalias),
+                    self._selection_sql(outerselection, distinct, needalias),
                     sql)
             if groups:
                 sql += '\nGROUP BY %s' % groups
@@ -899,7 +911,7 @@
     def _subqueries_sql(self, select, state):
         for i, subquery in enumerate(select.with_):
             sql = self.union_sql(subquery.query, needalias=True)
-            tablealias = '_T%s' % i # XXX nested subqueries
+            tablealias = '_T%s' % i  # XXX nested subqueries
             sql = '(%s) AS %s' % (sql, tablealias)
             state.subtables[tablealias] = (0, sql)
             latest_state = self._state
@@ -1029,7 +1041,6 @@
             sql = 'SELECT 1 FROM %s WHERE %s' % (tables, restriction)
         return sql
 
-
     def visit_relation(self, relation):
         """generate SQL for a relation"""
         rtype = relation.r_type
@@ -1039,8 +1050,8 @@
         lhs, rhs = relation.get_parts()
         rschema = self.schema.rschema(rtype)
         if rschema.final:
-            if rtype == 'eid' and lhs.variable._q_invariant and \
-                   lhs.variable.stinfo['constnode']:
+            if (rtype == 'eid' and lhs.variable._q_invariant
+                    and lhs.variable.stinfo['constnode']):
                 # special case where this restriction is already generated by
                 # some other relation
                 return ''
@@ -1176,15 +1187,15 @@
                 leftalias = leftvar.stinfo['principal']._q_sqltable
             else:
                 # search for relation on which we should join
+                rschema = self.schema.rschema
                 for orelation in leftvar.stinfo['relations']:
-                    if (orelation is not relation and
-                        not self.schema.rschema(orelation.r_type).final):
+                    if orelation is not relation and not rschema(orelation.r_type).final:
                         break
                 else:
                     for orelation in rightvar.stinfo['relations']:
-                        if (orelation is not relation and
-                            not self.schema.rschema(orelation.r_type).final
-                            and orelation.optional):
+                        if (orelation is not relation
+                                and not rschema(orelation.r_type).final
+                                and orelation.optional):
                             break
                     else:
                         # unexpected
@@ -1211,7 +1222,7 @@
             # relation.eid_from) join, now we've to do (relation.eid_to /
             # cw_Y.eid)
             leftalias = rightalias
-            rightsql = rightvar.accept(self) # accept before using var_table
+            rightvar.accept(self)  # accept before using var_table
             rightalias = self._var_table(rightvar)
             if rightalias is None:
                 if rightvar.stinfo['principal'] is not relation:
@@ -1226,7 +1237,6 @@
         # here
         return ''
 
-
     def _visit_outer_join_inlined_relation(self, relation, rschema):
         lhsvar, lhsconst, rhsvar, rhsconst = relation_info(relation)
         assert not (lhsconst and rhsconst), "doesn't make sense"
@@ -1288,13 +1298,13 @@
             # 1. visited relation is ored
             # 2. variable's principal is not this relation and not 1.
             if ored or (principal is not None and principal is not relation
-                        and not getattr(principal, 'ored', lambda : 0)()):
+                        and not getattr(principal, 'ored', lambda: 0)()):
                 # we have to generate unification expression
                 if principal is relation:
                     # take care if ored case and principal is the relation to
                     # use the right relation in the unification term
                     _rel = [rel for rel in var.stinfo['rhsrelations']
-                            if not rel is principal][0]
+                            if rel is not principal][0]
                 else:
                     _rel = relation
                 lhssql = self._inlined_var_sql(_rel.children[0].variable,
@@ -1335,7 +1345,7 @@
             if rel._q_needcast == 'TODAY':
                 sql = 'DATE(%s)%s' % (lhssql, rhssql)
             # XXX which cast function should be used
-            #elif rel._q_needcast == 'NOW':
+            # elif rel._q_needcast == 'NOW':
             #    sql = 'TIMESTAMP(%s)%s' % (lhssql, rhssql)
             else:
                 sql = '%s%s' % (lhssql, rhssql)
@@ -1381,7 +1391,7 @@
 
     def visit_comparison(self, cmp):
         """generate SQL for a comparison"""
-        optional = getattr(cmp, 'optional', None) # rql < 0.30
+        optional = cmp.optional
         if len(cmp.children) == 2:
             # simplified expression from HAVING clause
             lhs, rhs = cmp.children
@@ -1409,9 +1419,9 @@
             operator = ' '
         if sql is None:
             if lhs is None:
-                sql = '%s%s'% (operator, rhs.accept(self))
+                sql = '%s%s' % (operator, rhs.accept(self))
             else:
-                sql = '%s%s%s'% (lhs.accept(self), operator, rhs.accept(self))
+                sql = '%s%s%s' % (lhs.accept(self), operator, rhs.accept(self))
         if optional is None:
             return sql
         leftvars = cmp.children[0].get_nodes(VariableRef)
@@ -1447,16 +1457,17 @@
         if operator == '%':
             operator = '%%'
         try:
-            if mexpr.operator == '+' and mexpr.get_type(self._state.solution, self._args) == 'String':
+            if (mexpr.operator == '+'
+                    and mexpr.get_type(self._state.solution, self._args) == 'String'):
                 return '(%s)' % self.dbhelper.sql_concat_string(lhs.accept(self),
                                                                 rhs.accept(self))
         except CoercionError:
             pass
-        return '(%s %s %s)'% (lhs.accept(self), operator, rhs.accept(self))
+        return '(%s %s %s)' % (lhs.accept(self), operator, rhs.accept(self))
 
     def visit_unaryexpression(self, uexpr):
         """generate SQL for a unary expression"""
-        return '%s%s'% (uexpr.operator, uexpr.children[0].accept(self))
+        return '%s%s' % (uexpr.operator, uexpr.children[0].accept(self))
 
     def visit_function(self, func):
         """generate SQL name for a function"""
@@ -1513,7 +1524,6 @@
 
     def visit_variable(self, variable):
         """get the table name and sql string for a variable"""
-        #if contextrels is None and variable.name in self._state.done:
         if variable.name in self._state.done:
             if self._in_wrapping_query:
                 return 'T1.%s' % self._state.aliases[variable.name]
@@ -1570,7 +1580,7 @@
                 op = relation.operator()
                 if op == '=':
                     # need a predicable result for tests
-                    args = sorted( (sql, var.accept(self)) )
+                    args = sorted((sql, var.accept(self)))
                     args.insert(1, op)
                 else:
                     args = (sql, op, var.accept(self))
@@ -1630,5 +1640,5 @@
     # tables handling #########################################################
 
     def _var_table(self, var):
-        var.accept(self)#.visit_variable(var)
+        var.accept(self)
         return var._q_sqltable