server/mssteps.py
changeset 9448 3e7cad3967c5
parent 9447 0636c4960259
child 9449 287a05ec7ab1
equal deleted inserted replaced
9447:0636c4960259 9448:3e7cad3967c5
     1 # copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
    18 """Defines the different querier steps usable in plans.
       
    19 
       
    20 FIXME : this code needs refactoring. Some problems :
       
    21 * get data from the parent plan, the latest step, temporary table...
       
    22 * each step has its own members (this is not necessarily bad, but a bit messy
       
    23   for now)
       
    24 """
       
    25 __docformat__ = "restructuredtext en"
       
    26 
       
    27 from rql.nodes import VariableRef, Variable, Function
       
    28 
       
    29 from cubicweb.server.ssplanner import (LimitOffsetMixIn, Step, OneFetchStep,
       
    30                                     varmap_test_repr, offset_result)
       
    31 
       
    32 AGGR_TRANSFORMS = {'COUNT':'SUM', 'MIN':'MIN', 'MAX':'MAX', 'SUM': 'SUM'}
       
    33 
       
    34 class remove_and_restore_clauses(object):
       
    35     def __init__(self, union, keepgroup):
       
    36         self.union = union
       
    37         self.keepgroup = keepgroup
       
    38         self.clauses = None
       
    39 
       
    40     def __enter__(self):
       
    41         self.clauses = clauses = []
       
    42         for select in self.union.children:
       
    43             if self.keepgroup:
       
    44                 having, orderby = select.having, select.orderby
       
    45                 select.having, select.orderby = (), ()
       
    46                 clauses.append( (having, orderby) )
       
    47             else:
       
    48                 groupby, having, orderby = select.groupby, select.having, select.orderby
       
    49                 select.groupby, select.having, select.orderby = (), (), ()
       
    50                 clauses.append( (groupby, having, orderby) )
       
    51 
       
    52     def __exit__(self, exctype, exc, traceback):
       
    53         for i, select in enumerate(self.union.children):
       
    54             if self.keepgroup:
       
    55                 select.having, select.orderby = self.clauses[i]
       
    56             else:
       
    57                 select.groupby, select.having, select.orderby = self.clauses[i]
       
    58 
       
    59 
       
    60 class FetchStep(OneFetchStep):
       
    61     """step consisting in fetching data from sources, and storing result in
       
    62     a temporary table
       
    63     """
       
    64     def __init__(self, plan, union, sources, table, keepgroup, inputmap=None):
       
    65         OneFetchStep.__init__(self, plan, union, sources)
       
    66         # temporary table to store step result
       
    67         self.table = table
       
    68         # should groupby clause be kept or not
       
    69         self.keepgroup = keepgroup
       
    70         # variables mapping to use as input
       
    71         self.inputmap = inputmap
       
    72         # output variable mapping
       
    73         srqlst = union.children[0] # sample select node
       
    74         # add additional information to the output mapping
       
    75         self.outputmap = plan.init_temp_table(table, srqlst.selection,
       
    76                                               srqlst.solutions[0])
       
    77         for vref in srqlst.selection:
       
    78             if not isinstance(vref, VariableRef):
       
    79                 continue
       
    80             var = vref.variable
       
    81             if var.stinfo.get('attrvars'):
       
    82                 for lhsvar, rtype in var.stinfo['attrvars']:
       
    83                     if lhsvar.name in srqlst.defined_vars:
       
    84                         key = '%s.%s' % (lhsvar.name, rtype)
       
    85                         self.outputmap[key] = self.outputmap[var.name]
       
    86             else:
       
    87                 rschema = self.plan.schema.rschema
       
    88                 for rel in var.stinfo['rhsrelations']:
       
    89                     if rschema(rel.r_type).inlined:
       
    90                         lhsvar = rel.children[0]
       
    91                         if lhsvar.name in srqlst.defined_vars:
       
    92                             key = '%s.%s' % (lhsvar.name, rel.r_type)
       
    93                             self.outputmap[key] = self.outputmap[var.name]
       
    94 
       
    95     def execute(self):
       
    96         """execute this step"""
       
    97         self.execute_children()
       
    98         plan = self.plan
       
    99         plan.create_temp_table(self.table)
       
   100         union = self.union
       
   101         with remove_and_restore_clauses(union, self.keepgroup):
       
   102             for source in self.sources:
       
   103                 source.flying_insert(self.table, plan.session, union, plan.args,
       
   104                                      self.inputmap)
       
   105 
       
   106     def mytest_repr(self):
       
   107         """return a representation of this step suitable for test"""
       
   108         with remove_and_restore_clauses(self.union, self.keepgroup):
       
   109             try:
       
   110                 inputmap = varmap_test_repr(self.inputmap, self.plan.tablesinorder)
       
   111                 outputmap = varmap_test_repr(self.outputmap, self.plan.tablesinorder)
       
   112             except AttributeError:
       
   113                 inputmap = self.inputmap
       
   114                 outputmap = self.outputmap
       
   115             return (self.__class__.__name__,
       
   116                     sorted((r.as_string(kwargs=self.plan.args), r.solutions)
       
   117                            for r in self.union.children),
       
   118                     sorted(self.sources), inputmap, outputmap)
       
   119 
       
   120 
       
   121 class AggrStep(LimitOffsetMixIn, Step):
       
   122     """step consisting in making aggregat from temporary data in the system
       
   123     source
       
   124     """
       
   125     def __init__(self, plan, selection, select, table, outputtable=None):
       
   126         Step.__init__(self, plan)
       
   127         # original selection
       
   128         self.selection = selection
       
   129         # original Select RQL tree
       
   130         self.select = select
       
   131         # table where are located temporary results
       
   132         self.table = table
       
   133         # optional table where to write results
       
   134         self.outputtable = outputtable
       
   135         if outputtable is not None:
       
   136             plan.init_temp_table(outputtable, selection, select.solutions[0])
       
   137 
       
   138         #self.inputmap = inputmap
       
   139 
       
   140     def mytest_repr(self):
       
   141         """return a representation of this step suitable for test"""
       
   142         try:
       
   143             # rely on a monkey patch (cf unittest_querier)
       
   144             table = self.plan.tablesinorder[self.table]
       
   145             outputtable = self.outputtable and self.plan.tablesinorder[self.outputtable]
       
   146         except AttributeError:
       
   147             # not monkey patched
       
   148             table = self.table
       
   149             outputtable = self.outputtable
       
   150         sql = self.get_sql().replace(self.table, table)
       
   151         return (self.__class__.__name__, sql, outputtable)
       
   152 
       
   153     def execute(self):
       
   154         """execute this step"""
       
   155         self.execute_children()
       
   156         sql = self.get_sql()
       
   157         if self.outputtable:
       
   158             self.plan.create_temp_table(self.outputtable)
       
   159             sql = 'INSERT INTO %s %s' % (self.outputtable, sql)
       
   160             self.plan.syssource.doexec(self.plan.session, sql, self.plan.args)
       
   161         else:
       
   162             return self.plan.sqlexec(sql, self.plan.args)
       
   163 
       
   164     def get_sql(self):
       
   165         self.inputmap = inputmap = self.children[-1].outputmap
       
   166         dbhelper=self.plan.syssource.dbhelper
       
   167         # get the select clause
       
   168         clause = []
       
   169         for i, term in enumerate(self.selection):
       
   170             try:
       
   171                 var_name = inputmap[term.as_string()]
       
   172             except KeyError:
       
   173                 var_name = 'C%s' % i
       
   174             if isinstance(term, Function):
       
   175                 # we have to translate some aggregat function
       
   176                 # (for instance COUNT -> SUM)
       
   177                 orig_name = term.name
       
   178                 try:
       
   179                     term.name = AGGR_TRANSFORMS[term.name]
       
   180                     # backup and reduce children
       
   181                     orig_children = term.children
       
   182                     term.children = [VariableRef(Variable(var_name))]
       
   183                     clause.append(term.accept(self))
       
   184                     # restaure the tree XXX necessary?
       
   185                     term.name = orig_name
       
   186                     term.children = orig_children
       
   187                 except KeyError:
       
   188                     clause.append(var_name)
       
   189             else:
       
   190                 clause.append(var_name)
       
   191                 for vref in term.iget_nodes(VariableRef):
       
   192                     inputmap[vref.name] = var_name
       
   193         # XXX handle distinct with non selected sort term
       
   194         if self.select.distinct:
       
   195             sql = ['SELECT DISTINCT %s' % ', '.join(clause)]
       
   196         else:
       
   197             sql = ['SELECT %s' % ', '.join(clause)]
       
   198         sql.append("FROM %s" % self.table)
       
   199         # get the group/having clauses
       
   200         if self.select.groupby:
       
   201             clause = [inputmap[var.name] for var in self.select.groupby]
       
   202             grouped = set(var.name for var in self.select.groupby)
       
   203             sql.append('GROUP BY %s' % ', '.join(clause))
       
   204         else:
       
   205             grouped = None
       
   206         if self.select.having:
       
   207             clause = [term.accept(self) for term in self.select.having]
       
   208             sql.append('HAVING %s' % ', '.join(clause))
       
   209         # get the orderby clause
       
   210         if self.select.orderby:
       
   211             clause = []
       
   212             for sortterm in self.select.orderby:
       
   213                 sqlterm = sortterm.term.accept(self)
       
   214                 if sortterm.asc:
       
   215                     clause.append(sqlterm)
       
   216                 else:
       
   217                     clause.append('%s DESC' % sqlterm)
       
   218                 if grouped is not None:
       
   219                     for vref in sortterm.iget_nodes(VariableRef):
       
   220                         if not vref.name in grouped:
       
   221                             sql[-1] += ', ' + self.inputmap[vref.name]
       
   222                             grouped.add(vref.name)
       
   223             sql = dbhelper.sql_add_order_by(' '.join(sql),
       
   224                                             clause,
       
   225                                             None, False,
       
   226                                             self.limit or self.offset)
       
   227         else:
       
   228             sql = ' '.join(sql)
       
   229             clause = None
       
   230 
       
   231         sql = dbhelper.sql_add_limit_offset(sql, self.limit, self.offset, clause)
       
   232         return sql
       
   233 
       
   234     def visit_function(self, function):
       
   235         """generate SQL name for a function"""
       
   236         try:
       
   237             return self.children[0].outputmap[str(function)]
       
   238         except KeyError:
       
   239             return '%s(%s)' % (function.name,
       
   240                                ','.join(c.accept(self) for c in function.children))
       
   241 
       
   242     def visit_variableref(self, variableref):
       
   243         """get the sql name for a variable reference"""
       
   244         try:
       
   245             return self.inputmap[variableref.name]
       
   246         except KeyError: # XXX duh? explain
       
   247             return variableref.variable.name
       
   248 
       
   249     def visit_constant(self, constant):
       
   250         """generate SQL name for a constant"""
       
   251         assert constant.type == 'Int'
       
   252         return str(constant.value)
       
   253 
       
   254 
       
   255 class UnionStep(LimitOffsetMixIn, Step):
       
   256     """union results of child in-memory steps (e.g. OneFetchStep / AggrStep)"""
       
   257 
       
   258     def execute(self):
       
   259         """execute this step"""
       
   260         result = []
       
   261         limit = olimit = self.limit
       
   262         offset = self.offset
       
   263         assert offset != 0
       
   264         if offset is not None:
       
   265             limit = limit + offset
       
   266         for step in self.children:
       
   267             if limit is not None:
       
   268                 if offset is None:
       
   269                     limit = olimit - len(result)
       
   270                 step.set_limit_offset(limit, None)
       
   271             result_ = step.execute()
       
   272             if offset is not None:
       
   273                 offset, result_ = offset_result(offset, result_)
       
   274             result += result_
       
   275             if limit is not None:
       
   276                 if len(result) >= olimit:
       
   277                     return result[:olimit]
       
   278         return result
       
   279 
       
   280     def mytest_repr(self):
       
   281         """return a representation of this step suitable for test"""
       
   282         return (self.__class__.__name__, self.limit, self.offset)
       
   283 
       
   284 
       
   285 class IntersectStep(UnionStep):
       
   286     """return intersection of results of child in-memory steps (e.g. OneFetchStep / AggrStep)"""
       
   287 
       
   288     def execute(self):
       
   289         """execute this step"""
       
   290         result = set()
       
   291         for step in self.children:
       
   292             result &= frozenset(step.execute())
       
   293         result = list(result)
       
   294         if self.offset:
       
   295             result = result[self.offset:]
       
   296         if self.limit:
       
   297             result = result[:self.limit]
       
   298         return result
       
   299 
       
   300 
       
   301 class UnionFetchStep(Step):
       
   302     """union results of child steps using temporary tables (e.g. FetchStep)"""
       
   303 
       
   304     def execute(self):
       
   305         """execute this step"""
       
   306         self.execute_children()
       
   307 
       
   308 
       
   309 __all__ = ('FetchStep', 'AggrStep', 'UnionStep', 'UnionFetchStep', 'IntersectStep')