author Alexandre Fayolle <>
Tue, 22 Mar 2011 15:11:38 +0100
changeset 7108 bcdf22734059
parent 6958 861251f125cf
child 7293 97505b798975
child 7499 96412cfc28e2
permissions -rw-r--r--
Abstract the support for ORDER BY and LIMIT/OFFSET SQL generation: not all DB engines support the same syntax for these features, MS SQLServer being the bad boy we try to support in CW. * Use two new methods of dbhelper to add LIMIT/OFFSET clauses and ORDER BY clauses * added unit tests for sqlserver backend * changed unittest_rql2sql to launch the backend tests even if the driver module is not installed on the machine, so that we can run the sqlserver tests on linux (and the mysql tests too) * adapt to the new interface closes #1154756

# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact --
# This file is part of CubicWeb.
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <>.
"""Source to query another RQL repository using pyro"""

__docformat__ = "restructuredtext en"
_ = unicode

import threading
from os.path import join
from time import mktime
from datetime import datetime
from base64 import b64decode

from Pyro.errors import PyroError, ConnectionClosedError

from logilab.common.configuration import REQUIRED
from logilab.common.optik_ext import check_yn

from yams.schema import role_name

from rql.nodes import Constant
from rql.utils import rqlvar_maker

from cubicweb import dbapi, server
from cubicweb import ValidationError, BadConnectionId, UnknownEid, ConnectionError
from cubicweb.schema import VIRTUAL_RTYPES
from cubicweb.cwconfig import register_persistent_options
from cubicweb.server.sources import (AbstractSource, ConnectionWrapper,
                                     TimedCache, dbg_st_search, dbg_results)
from cubicweb.server.msplanner import neged_relation

def uidtype(union, col, etype, args):
    """Return the ``uidtype`` of the term selected at column `col`.

    `union.locate_subquery` tells which select and which column index of
    that select hold the term; ``None`` is returned when the term has no
    ``uidtype`` attribute.
    """
    subselect, subcol = union.locate_subquery(col, etype, args)
    term = subselect.selection[subcol]
    return getattr(term, 'uidtype', None)

class ReplaceByInOperator(Exception):
    """Signal that a comparison has to be rewritten as ``IN (<eids>)``.

    The ids to list in the ``IN`` clause are carried by the `eids` attribute.
    """

    def __init__(self, eids):
        self.eids = eids

class PyroRQLSource(AbstractSource):
    """External repository source, using Pyro connection"""

    # boolean telling if modification hooks should be called when something is
    # modified in this source
    should_call_hooks = False
    # boolean telling if the repository should connect to this source during
    # migration
    connect_for_migration = False

    # NOTE(review): this copy of the option table is incomplete. Every entry
    # seems to have lost its leading ``('<option-name>',`` line and its
    # closing ``}),`` line, and the tuple itself is never closed, so the class
    # body does not parse as shown. The missing names are presumably the keys
    # read from ``self.config`` elsewhere in this class ('pyro-ns-id',
    # 'cubicweb-user', 'base-url', 'skip-external-entities', 'pyro-ns-host',
    # 'pyro-ns-group', 'synchronization-interval') -- confirm against the
    # upstream file before editing.
    options = (
        # XXX pyro-ns host/port
         {'type' : 'string',
          'default': REQUIRED,
          'help': 'identifier of the repository in the pyro name server',
          'group': 'pyro-source', 'level': 0,
         {'type' : 'string',
          'default': REQUIRED,
          'help': 'user to use for connection on the distant repository',
          'group': 'pyro-source', 'level': 0,
         # NOTE(review): the help text of this 'password' option is the same
         # as the user option above; it presumably should describe the
         # password -- confirm.
         {'type' : 'password',
          'default': '',
          'help': 'user to use for connection on the distant repository',
          'group': 'pyro-source', 'level': 0,
         {'type' : 'string',
          'default': '',
          'help': 'url of the web site for the distant repository, if you want '
          'to generate external link to entities from this repository',
          'group': 'pyro-source', 'level': 1,
         {'type' : 'yn',
          'default': False,
          'help': 'should entities not local to the source be considered or not',
          'group': 'pyro-source', 'level': 0,
         {'type' : 'string',
          'default': None,
          'help': 'Pyro name server\'s host. If not set, default to the value \
from all_in_one.conf. It may contains port information using <host>:<port> notation.',
          'group': 'pyro-source', 'level': 1,
         {'type' : 'string',
          'default': None,
          'help': 'Pyro name server\'s group where the repository will be \
registered. If not set, default to the value from all_in_one.conf.',
          'group': 'pyro-source', 'level': 2,
         {'type' : 'time',
          'default': '5min',
          'help': 'interval between synchronization with the external \
repository (default to 5 minutes).',
          'group': 'pyro-source', 'level': 2,


    # the parent's public keys, extended with 'base-url'
    PUBLIC_KEYS = AbstractSource.PUBLIC_KEYS + ('base-url',)
    _conn = None

    def __init__(self, repo, source_config, eid=None):
        # NOTE(review): the ``update_config(...)`` call below is left unclosed
        # in this copy (the end of the ``check_conf_dict`` argument list is
        # missing), so the constructor does not parse as shown.
        AbstractSource.__init__(self, repo, source_config, eid)
        self.update_config(None, self.check_conf_dict(eid, source_config,
        # cache of query results, keyed on the query string (see
        # syntax_tree_search)
        self._query_cache = TimedCache(1800)

    def update_config(self, source_entity, processed_config):
        """Install `processed_config` as the configuration of this source.

        A non-empty 'base-url' is normalised in place so that it ends with a
        slash, the 'skip-external-entities' flag is cached on the instance
        and, when `source_entity` is given, its ``latest_retrieval`` value is
        copied on the source.
        """
        # XXX get it through pyro if unset
        url = processed_config.get('base-url')
        if url and not url.endswith('/'):
            processed_config['base-url'] = url + '/'
        self.config = processed_config
        self._skip_externals = processed_config['skip-external-entities']
        if source_entity is None:
            return
        self.latest_retrieval = source_entity.latest_retrieval

    def reset_caches(self):
        """Throw away cached query results (called during tests).

        The query cache is replaced by a brand new ``TimedCache(1800)``.
        """
        fresh_cache = TimedCache(1800)
        self._query_cache = fresh_cache

    def init(self, activated, source_entity):
        """Hook called by the repository once it is ready to handle requests.

        Nothing is done for a source which is not activated. Otherwise
        `synchronize` is registered as a looping task of the repository, run
        every 'synchronization-interval', and the latest retrieval date is
        read from `source_entity`.
        """
        if not activated:
            return
        period = self.config['synchronization-interval']
        self.repo.looping_task(period, self.synchronize)
        self.latest_retrieval = source_entity.latest_retrieval

    def load_mapping(self, session=None):
        """Reset the schema mapping of this source to its initial state.

        No entity nor relation type is supported, 'owned_by' and 'created_by'
        are never crossed, and the index of schema configurations is empty.
        The source must already have an eid. `session` is not used here.
        """
        self.support_entities, self.support_relations = {}, {}
        self.dont_cross_relations = set(['owned_by', 'created_by'])
        self.cross_relations = set()
        assert self.eid is not None
        # eid of a schema configuration entity -> mapped entity/relation type
        self._schemacfg_idx = dict()

    # options accepted on a mapped entity type (checked by add_schema_config
    # through _check_options)
    etype_options = set(('write',))
    # options accepted on a mapped relation type (checked by add_schema_config
    # through _check_options)
    rtype_options = set(('maycross', 'dontcross', 'write',))

    def _check_options(self, schemacfg, allowedoptions):
        """Return the set of options carried by `schemacfg`.

        `schemacfg.options` is a colon separated string (or a false value
        when no option is set); each item is stripped of surrounding
        whitespace.

        :param schemacfg: schema configuration entity being checked
        :param allowedoptions: set of option names accepted for this entity
        :return: the set of option names found, possibly empty
        :raise ValidationError: on the 'options' subject relation of
          `schemacfg` when some option is not in `allowedoptions`
        """
        if schemacfg.options:
            options = set(w.strip() for w in schemacfg.options.split(':'))
        else:
            # no option set: without this branch `options` would either be
            # unbound or wipe out the options parsed just above
            options = set()
        unknown = options - allowedoptions
        if unknown:
            # interpolate *after* translating, so that the message catalog is
            # looked up with the constant message id
            msg = _('unknown option(s): %s') % ', '.join(sorted(unknown))
            raise ValidationError(schemacfg.eid, {role_name('options', 'subject'): msg})
        return options

    def add_schema_config(self, schemacfg, checkonly=False):
        """added CWSourceSchemaConfig, modify mapping accordingly"""
        # With `checkonly` set, the configuration is validated but the mapping
        # (support_entities / support_relations / _schemacfg_idx) is left
        # untouched.
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown: the `try:` line and the right-hand side of the
        # `ertype` assignment are missing, and two `if` statements further
        # down have lost their bodies. Restore it from the upstream file
        # before editing.
            ertype =
        except AttributeError:
            msg = schemacfg._cw._("attribute/relation can't be mapped, only "
                                  "entity and relation types")
            raise ValidationError(schemacfg.eid, {role_name('cw_for_schema', 'subject'): msg})
        if schemacfg.schema.__regid__ == 'CWEType':
            options = self._check_options(schemacfg, self.etype_options)
            if not checkonly:
                # the stored value tells whether the type is writable
                self.support_entities[ertype] = 'write' in options
        else: # CWRType
            if ertype in ('is', 'is_instance_of', 'cw_source') or ertype in VIRTUAL_RTYPES:
                # NOTE(review): `rtype` is not bound anywhere in this method;
                # `ertype` looks like the intended name.
                msg = schemacfg._cw._('%s relation should not be in mapped') % rtype
                raise ValidationError(schemacfg.eid, {role_name('cw_for_schema', 'subject'): msg})
            options = self._check_options(schemacfg, self.rtype_options)
            if 'dontcross' in options:
                # NOTE(review): the two messages below are translated with
                # `schemacfg._` whereas the ones above use `schemacfg._cw._`
                # -- confirm which one is right.
                if 'maycross' in options:
                    msg = schemacfg._("can't mix dontcross and maycross options")
                    raise ValidationError(schemacfg.eid, {role_name('options', 'subject'): msg})
                if 'write' in options:
                    msg = schemacfg._("can't mix dontcross and write options")
                    raise ValidationError(schemacfg.eid, {role_name('options', 'subject'): msg})
                # NOTE(review): the body of this `if` is missing
                if not checkonly:
            elif not checkonly:
                self.support_relations[ertype] = 'write' in options
                # NOTE(review): the body of this `if` is missing
                if 'maycross' in options:
        if not checkonly:
            # add to an index to ease deletion handling
            self._schemacfg_idx[schemacfg.eid] = ertype

    def del_schema_config(self, schemacfg, checkonly=False):
        """deleted CWSourceSchemaConfig, modify mapping accordingly"""
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown. Judging from the indentation, the body of
        # `if checkonly:` and a `try:` line are missing, as well as the
        # `else:` separating the entity type branch from the relation type
        # one, the body of the innermost `if`, the `except` clause and the
        # end of the `self.error(...)` call. Restore it from the upstream
        # file before editing.
        if checkonly:
            # the index filled by add_schema_config gives the mapped type
            ertype = self._schemacfg_idx[schemacfg.eid]
            if ertype[0].isupper():
                del self.support_entities[ertype]
                if ertype in self.support_relations:
                    del self.support_relations[ertype]
                    if ertype in self.cross_relations:
            self.error('while updating mapping consequently to removal of %s',

    def local_eid(self, cnx, extid, session):
        """Map the external id `extid` to an eid of this repository.

        `cnx.describe(extid)` gives the entity type, the uri of the source
        holding the entity in the external repository and its id there.
        Return an ``(eid, local)`` pair:

        * ``(eid, False)`` when that uri is a source this repository also
          knows: the resolution is delegated to that source
        * ``(None, None)`` when the uri is unknown here and external entities
          are skipped
        * ``(eid, True)`` otherwise (entity from the external 'system'
          source, or unknown uri while externals are not skipped)
        """
        etype, dexturi, dextid = cnx.describe(extid)
        if dexturi != 'system':
            if dexturi in self.repo.sources_by_uri:
                othersource = self.repo.sources_by_uri[dexturi]
                othercnx = session.pool.connection(othersource.uri)
                return othersource.local_eid(othercnx, dextid, session)[0], False
            if self._skip_externals:
                return None, None
        return self.repo.extid2eid(self, str(extid), etype, session), True

    # NOTE(review): this copy of the method is incomplete and does not parse
    # as shown: the end of the docstring is fused with the arguments of a
    # call whose beginning is missing, every `try:` line is gone together
    # with the matching `except`/`finally` lines, the body of the first
    # `except AttributeError:` is missing, and two calls (`self.extid2eid(`
    # and `repo.delete_info(`) are left unclosed. Restore it from the
    # upstream file before editing.
    def synchronize(self, mtime=None):
        """synchronize content known by this repository with content in the
        external repository
        """'synchronizing pyro source %s', self.uri)
        cnx = self.get_connection()
            extrepo = cnx._repo
        except AttributeError:
            # fake connection wrapper returned when we can't connect to the
            # external source (hence we've no chance to synchronize...)
        etypes = self.support_entities.keys()
        # without an explicit `mtime`, resume from the latest retrieval
        if mtime is None:
            mtime = self.latest_retrieval
        updatetime, modified, deleted = extrepo.entities_modified_since(
            etypes, mtime)
        repo = self.repo
        session = repo.internal_session()
        source = repo.system_source
            for etype, extid in modified:
                    eid = self.local_eid(cnx, extid, session)[0]
                    if eid is not None:
                        rset = session.eid_rset(eid, etype)
                        entity = rset.get_entity(0, 0)
                        source.index_entity(session, entity)
                    self.exception('while updating %s with external id %s of source %s',
                                   etype, extid, self.uri)
            for etype, extid in deleted:
                    eid = self.extid2eid(str(extid), etype, session,
                    # entity has been deleted from external repository but is not known here
                    if eid is not None:
                        entity = session.entity_from_eid(eid, etype)
                        repo.delete_info(session, entity, self.uri, extid,
                    self.exception('while updating %s with external id %s of source %s',
                                   etype, extid, self.uri)
            # remember, in memory and on the source entity, when the
            # external repository was last asked for modifications
            self.latest_retrieval = updatetime
            session.execute('SET X latest_retrieval %(date)s WHERE X eid %(x)s',
                            {'x': self.eid, 'date': self.latest_retrieval})

    def _get_connection(self):
        """open and return a connection to the source"""
        # the name server host and group default to the repository's own
        # configuration when unset (or empty) in the source configuration
        nshost = self.config.get('pyro-ns-host') or self.repo.config['pyro-ns-host']
        # NOTE(review): the line below is fused with the arguments of a call
        # whose beginning is missing (presumably a log call -- confirm), so
        # the method does not parse as shown. No credential is visible in the
        # `dbapi.connect(...)` call either, although 'cubicweb-user' is read
        # from the configuration just above; some arguments may be missing
        # too.
        nsgroup = self.config.get('pyro-ns-group') or self.repo.config['pyro-ns-group']'connecting to instance :%s.%s for user %s',
                  nsgroup, self.config['pyro-ns-id'], self.config['cubicweb-user'])
        #cnxprops = ConnectionProperties(cnxtype=self.config['cnx-type'])
        return dbapi.connect(database=self.config['pyro-ns-id'],
                             host=nshost, group=nsgroup,
                             setvreg=False) #cnxprops=cnxprops)

    def get_connection(self):
        """Return a connection to the external repository.

        When the connection can't be established (`ConnectionError` or
        `PyroError`), the failure is logged as critical and a
        `ConnectionWrapper` is returned instead of letting the exception
        propagate, so that callers always get an object back.
        """
        try:
            return self._get_connection()
        except (ConnectionError, PyroError):
            # keep the traceback in the log, the exception itself is swallowed
            self.critical("can't get connection to source %s", self.uri,
                          exc_info=True)
            return ConnectionWrapper()

    # NOTE(review): this copy of the method is incomplete and does not parse
    # as shown: the docstring is never closed, and both `except` clauses have
    # lost their `try:` line and their body. What remains suggests that a
    # valid connection yields None while any other case ends with
    # `self.get_connection()`. Restore it from the upstream file before
    # editing.
    def check_connection(self, cnx):
        """check connection validity, return None if the connection is still valid
        else a new connection
        # we have to transfer manually thread ownership. This can be done safely
        # since the pool to which belong the connection is affected to one
        # session/thread and can't be called simultaneously
        except AttributeError:
            # inmemory connection
        if not isinstance(cnx, ConnectionWrapper):
                return # ok
            except (BadConnectionId, ConnectionClosedError):
        # try to reconnect
        return self.get_connection()

    # NOTE(review): this copy of the method is incomplete and does not parse
    # as shown: the signature is cut after `cachekey=None,` (the body uses a
    # `varmap` name which presumably was one of the lost parameters) and the
    # `try:` line matching `except KeyError:` is missing.
    def syntax_tree_search(self, session, union, args=None, cachekey=None,
        assert dbg_st_search(self.uri, union, varmap, args, cachekey)
        # results are cached under the query string, arguments included
        rqlkey = union.as_string(kwargs=args)
            results = self._query_cache[rqlkey]
        except KeyError:
            # not cached: run the query and remember its results
            results = self._syntax_tree_search(session, union, args)
            self._query_cache[rqlkey] = results
        assert dbg_results(results)
        return results

    # NOTE(review): this copy of the method is incomplete and does not parse
    # as shown: the docstring is never closed, the `try:` lines matching the
    # two `except` clauses are missing, and so are the body of the
    # `if (etype is None ...)` test (presumably collecting the column index
    # into `needtranslation` -- confirm) and several `else:` / `break` /
    # `continue` lines of the row translation loop. What remains shows the
    # overall flow: translate the query with RQL2RQL, run it on the cursor of
    # this source, then replace external ids by local eids in the columns
    # needing it, walking the rows backwards so that rows can be deleted
    # while iterating. Restore it from the upstream file before editing.
    def _syntax_tree_search(self, session, union, args):
        """return result from this source for a rql query (actually from a rql
        syntax tree and a solution dictionary mapping each used variable to a
        possible type). If cachekey is given, the query necessary to fetch the
        results (but not the results themselves) may be cached using this key.
        if not args is None:
            args = args.copy()
        # get cached cursor anyway
        cu = session.pool[self.uri]
        if cu is None:
            # this is a ConnectionWrapper instance
            msg = session._("can't connect to source %s, some data may be missing")
            session.set_shared_data('sources_error', msg % self.uri)
            return []
        translator = RQL2RQL(self)
            rql = translator.generate(session, union, args)
        except UnknownEid, ex:
            if server.DEBUG:
                print '  unknown eid', ex, 'no results'
            return []
        if server.DEBUG & server.DBG_RQL:
            print '  translated rql', rql
            rset = cu.execute(rql, args)
        except Exception, ex:
            msg = session._("error while querying source %s, some data may be missing")
            session.set_shared_data('sources_error', msg % self.uri)
            return []
        descr = rset.description
        if rset:
            needtranslation = []
            rows = rset.rows
            for i, etype in enumerate(descr[0]):
                if (etype is None or not self.schema.eschema(etype).final
                    or uidtype(union, i, etype, args)):
            if needtranslation:
                cnx = session.pool.connection(self.uri)
                for rowindex in xrange(rset.rowcount - 1, -1, -1):
                    row = rows[rowindex]
                    localrow = False
                    for colindex in needtranslation:
                        if row[colindex] is not None: # optional variable
                            eid, local = self.local_eid(cnx, row[colindex], session)
                            if local:
                                localrow = True
                            if eid is not None:
                                row[colindex] = eid
                                # skip this row
                                del rows[rowindex]
                                del descr[rowindex]
                        # skip row if it only contains eids of entities which
                        # are actually from a source we also know locally,
                        # except if some args specified (XXX should actually
                        # check if there are some args local to the source)
                        if not (translator.has_local_eid or localrow):
                            del rows[rowindex]
                            del descr[rowindex]
            results = rows
            results = []
        return results

    def _entity_relations_and_kwargs(self, session, entity):
        """Return what is needed to write the cached attributes of `entity`.

        The result is a ``(relations, kwargs)`` pair: one ``X attr %(attr)s``
        RQL snippet per attribute found in ``entity.cw_attr_cache``, and the
        query arguments holding the attribute values plus, under 'x', the
        external id of the entity.
        """
        substitutions = {'x': self.eid2extid(entity.eid, session)}
        setclauses = []
        for attr, value in entity.cw_attr_cache.iteritems():
            setclauses.append('X %s %%(%s)s' % (attr, attr))
            substitutions[attr] = value
        return setclauses, substitutions

    def add_entity(self, session, entity):
        """Entity creation is not implemented by this source.

        :raise NotImplementedError: always
        """
        raise NotImplementedError()

    def update_entity(self, session, entity):
        """Write the cached attribute values of `entity` to the external
        repository, using a ``SET ... WHERE X eid %(x)s`` query run on the
        cursor of this source.
        """
        setclauses, rqlargs = self._entity_relations_and_kwargs(session, entity)
        cursor = session.pool[self.uri]
        rql = 'SET %s WHERE X eid %%(x)s' % ','.join(setclauses)
        cursor.execute(rql, rqlargs)

    def delete_entity(self, session, entity):
        """Remove `entity` from the external repository.

        The entity is designated by its external id and deleted with a
        ``DELETE <type> X`` query on its registry id.
        """
        cursor = session.pool[self.uri]
        rql = 'DELETE %s X WHERE X eid %%(x)s' % entity.__regid__
        rqlargs = {'x': self.eid2extid(entity.eid, session)}
        cursor.execute(rql, rqlargs)

    def add_relation(self, session, subject, rtype, object):
        """Create the `rtype` relation in the external repository.

        `subject` and `object` are local eids; both are converted to
        external ids before the ``SET`` query is run.
        """
        cursor = session.pool[self.uri]
        rql = 'SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype
        rqlargs = {'x': self.eid2extid(subject, session),
                   'y': self.eid2extid(object, session)}
        cursor.execute(rql, rqlargs)

    def delete_relation(self, session, subject, rtype, object):
        """Remove the `rtype` relation from the external repository.

        `subject` and `object` are local eids; both are converted to
        external ids before the ``DELETE`` query is run.
        """
        cursor = session.pool[self.uri]
        rql = 'DELETE X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype
        rqlargs = {'x': self.eid2extid(subject, session),
                   'y': self.eid2extid(object, session)}
        cursor.execute(rql, rqlargs)

class RQL2RQL(object):
    """Visitor turning a local RQL syntax tree into the RQL string to execute
    on the distant repository behind `source`.
    """

    def __init__(self, source):
        """Bind the translator to `source`; no comparison operator is being
        processed yet.
        """
        self.current_operator = None
        self.source = source

    def _accept_children(self, node):
        """Visit every child of `node` and return the generated RQL snippets.

        :param node: syntax tree node whose ``children`` are visited in order
        :return: list of what each ``child.accept(self)`` returned, leaving
          out the children for which the visit returned None
        """
        res = []
        for child in node.children:
            rql = child.accept(self)
            if rql is not None:
                # the collecting statement had been lost, leaving an `if`
                # without a body and a result list which was never filled
                res.append(rql)
        return res

    def generate(self, session, rqlst, args):
        """Translate the union `rqlst` into RQL for the distant repository.

        The per-translation state is reset first. `args` is kept by reference
        in ``self.kwargs``, so substitution values translated while visiting
        constants are changed in the caller's dictionary.
        """
        self.has_local_eid = self.need_translation = False
        self.kwargs = args
        self._session = session
        return self.visit_union(rqlst)

    def visit_union(self, node):
        """Return the RQL of a union: the only subquery as is, or every
        subquery wrapped in parentheses and joined with ``UNION``.
        """
        subqueries = self._accept_children(node)
        if len(subqueries) <= 1:
            return subqueries[0]
        return ' UNION '.join(['(%s)' % rql for rql in subqueries])

    def visit_select(self, node):
        """return the tree as an encoded rql string"""
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown; the gaps are flagged where they occur below.
        # Restore it from the upstream file before editing.
        self._varmaker = rqlvar_maker(defined=node.defined_vars.copy())
        self._const_var = {}
        # NOTE(review): an `else:` line appears to be missing before the
        # second assignment; as written `base` always ends up as 'Any' for a
        # distinct select and is unbound otherwise.
        if node.distinct:
            base = 'DISTINCT Any'
            base = 'Any'
        s = ['%s %s' % (base, ','.join(v.accept(self) for v in node.selection))]
        if node.groupby:
            s.append('GROUPBY %s' % ', '.join(group.accept(self)
                                              for group in node.groupby))
        if node.orderby:
            s.append('ORDERBY %s' % ', '.join(self.visit_sortterm(term)
                                              for term in node.orderby))
        if node.limit is not None:
            s.append('LIMIT %s' % node.limit)
        if node.offset:
            s.append('OFFSET %s' % node.offset)
        restrictions = []
        if node.where is not None:
            nr = node.where.accept(self)
            # NOTE(review): the body of this `if` is missing (presumably
            # adding `nr` to `restrictions` -- confirm)
            if nr is not None:
        if restrictions:
            s.append('WHERE %s' % ','.join(restrictions))

        if node.having:
            s.append('HAVING %s' % ', '.join(term.accept(self)
                                             for term in node.having))
        subqueries = []
        for subquery in node.with_:
            # NOTE(review): the generator below has lost its element
            # expression and the end of the `append(...)` call is missing
            subqueries.append('%s BEING (%s)' % (','.join( for ca in subquery.aliases),
        if subqueries:
            s.append('WITH %s' % (','.join(subqueries)))
        return ' '.join(s)

    def visit_and(self, node):
        """Return the translated children joined with ``', '``, or None when
        no child produced anything.
        """
        parts = self._accept_children(node)
        if not parts:
            return None
        return ', '.join(parts)

    def visit_or(self, node):
        """Return the translated alternatives of an ``OR`` node.

        Several alternatives are wrapped in parentheses and joined with
        ``OR``, a single one is returned as is, and None is returned when no
        child produced anything.
        """
        alternatives = self._accept_children(node)
        if not alternatives:
            return None
        if len(alternatives) == 1:
            return alternatives[0]
        return ' OR '.join(['(%s)' % rql for rql in alternatives])

    def visit_not(self, node):
        """Return ``NOT (<rql>)`` for the first child of `node`, or None when
        that child produced nothing.
        """
        negated = node.children[0].accept(self)
        if not negated:
            return None
        return 'NOT (%s)' % negated

    def visit_exists(self, node):
        """Return ``EXISTS(<rql>)`` for the first child of `node`, or None
        when that child produced nothing.
        """
        inner = node.children[0].accept(self)
        if not inner:
            return None
        return 'EXISTS(%s)' % inner

    def visit_relation(self, node):
        # Translate one relation: left hand side, relation type and right
        # hand side, followed by an additional eid restriction when a
        # constant subject had to be reintroduced as a variable.
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown: every `try:` line is missing, as are the `else:`
        # lines, and the three `if neged_relation(node):` tests have lost
        # their bodies as well as the statement which followed them. Restore
        # it from the upstream file before editing.
            if isinstance(node.children[0], Constant):
                # simplified rqlst, reintroduce eid relation
                    restr, lhs = self.process_eid_const(node.children[0])
                except UnknownEid:
                    # can safely skip not relation with an unsupported eid
                    if neged_relation(node):
                lhs = node.children[0].accept(self)
                restr = None
        except UnknownEid:
            # can safely skip not relation with an unsupported eid
            if neged_relation(node):
            # XXX what about optional relation or outer NOT EXISTS()
        if node.optional in ('left', 'both'):
            lhs += '?'
        # eids on the right hand side must be translated for the `eid`
        # relation and for every non final relation
        if node.r_type == 'eid' or not self.source.schema.rschema(node.r_type).final:
            self.need_translation = True
            self.current_operator = node.operator()
            if isinstance(node.children[0], Constant):
                self.current_etypes = (node.children[0].uidtype,)
                self.current_etypes = node.children[0].variable.stinfo['possibletypes']
            rhs = node.children[1].accept(self)
        except UnknownEid:
            # can safely skip not relation with an unsupported eid
            if neged_relation(node):
            # XXX what about optional relation or outer NOT EXISTS()
        except ReplaceByInOperator, ex:
            # the comparison can't be sent as is: list the matching ids
            rhs = 'IN (%s)' % ','.join(eid for eid in ex.eids)
        self.need_translation = False
        self.current_operator = None
        if node.optional in ('right', 'both'):
            rhs += '?'
        if restr is not None:
            return '%s %s %s, %s' % (lhs, node.r_type, rhs, restr)
        return '%s %s %s' % (lhs, node.r_type, rhs)

    def visit_comparison(self, node):
        # the '=' and 'IS' operators are left implicit: only the operand is
        # generated
        if node.operator in ('=', 'IS'):
            return node.children[0].accept(self)
        # NOTE(review): the end of this statement (second element of the
        # tuple and closing parenthesis) is missing from this copy, so the
        # method does not parse as shown.
        return '%s %s' % (node.operator.encode(),

    def visit_mathexpression(self, node):
        # NOTE(review): the end of this statement (the last two elements of
        # the tuple and the closing parenthesis) is missing from this copy,
        # so the method does not parse as shown.
        return '(%s %s %s)' % (node.children[0].accept(self),

    def visit_function(self, node):
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown: the `try:` line is missing, what is done with `rql`
        # and the body of the `except` clause are missing too, and the first
        # element of the final tuple has been lost (presumably the function
        # name -- confirm). The commented-out test below was truncated the
        # same way. What remains shows that the UnknownEid caught for a child
        # is re-raised when no child could be translated at all.
        #if == 'IN':
        res = []
        for child in node.children:
                rql = child.accept(self)
            except UnknownEid, ex:
        if not res:
            raise ex
        return '%s(%s)' % (, ', '.join(res))

    def visit_constant(self, node):
        """Return the RQL for a constant, translating eids to external ids.

        Translation only happens when the translator is in "need
        translation" mode or when the constant has a ``uidtype``:

        * an 'Int' constant is replaced by the matching external id
        * for a 'Substitute' constant, the value found in the query
          arguments is replaced in place, once per argument name

        Any other constant is generated unchanged.
        """
        if not (self.need_translation or node.uidtype):
            return node.as_string()
        if node.type == 'Int':
            self.has_local_eid = True
            return str(self.eid2extid(node.value))
        if node.type == 'Substitute':
            argname = node.value
            # an argument may be referenced several times, but its value
            # must be translated only once
            if argname not in self._const_var:
                self.kwargs[argname] = self.eid2extid(self.kwargs[argname])
                self._const_var[argname] = None
                self.has_local_eid = True
        return node.as_string()

    def visit_variableref(self, node):
        """Return the RQL name of the variable referenced by `node`.

        :param node: variable reference of the syntax tree
        :return: the variable name, used as is in the generated query
        """
        # without an explicit return the visit yields None, which the other
        # visit methods can't join into the generated query string
        return node.name

    def visit_sortterm(self, node):
        """Return the RQL of a sort term, suffixed with ``DESC`` when the
        sort is not ascending.
        """
        term = node.term.accept(self)
        if not node.asc:
            return '%s DESC' % term
        return term

    def process_eid_const(self, const):
        # Return a `(restriction, variable)` pair for an eid constant used as
        # the subject of a relation. The variable allocated for a given value
        # is remembered in `self._const_var`; when the value has already been
        # seen, the restriction is None.
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown: the `try:` line and the matching `except` line are
        # missing, and the right-hand side of the `var` assignment has been
        # lost (presumably a new variable name taken from `self._varmaker` --
        # confirm).
        value = const.eval(self.kwargs)
            return None, self._const_var[value]
            var =
            # force the translation of the eid while generating the constant
            self.need_translation = True
            restr = '%s eid %s' % (var, self.visit_constant(const))
            self.need_translation = False
            self._const_var[value] = var
            return restr, var

    def eid2extid(self, eid):
        # Return the external id matching the local `eid`, as given by the
        # source. When the eid is unknown and a comparison operator other
        # than '=' is being processed, ReplaceByInOperator is raised with the
        # external ids of the entities of this source which match the
        # comparison locally.
        # NOTE(review): this copy of the method is incomplete and does not
        # parse as shown: the `try:` line matching `except UnknownEid:` is
        # missing. Nothing is visible either for the cases where no
        # replacement applies (operator unset or '=', or no matching row);
        # the method may continue past the last visible line.
            return self.source.eid2extid(eid, self._session)
        except UnknownEid:
            operator = self.current_operator
            if operator is not None and operator != '=':
                # deal with query like "X eid > 12"
                # The problem is that eid order in the external source may
                # differ from the local source
                # So search for all eids from this source matching the condition
                # locally and then to replace the "> 12" branch by "IN (eids)"
                # XXX we may have to insert a huge number of eids...)
                # NOTE(review): the query is built by string interpolation;
                # confirm that the source uri, the entity types, the operator
                # and the eid can never carry untrusted text.
                sql = "SELECT extid FROM entities WHERE source='%s' AND type IN (%s) AND eid%s%s"
                etypes = ','.join("'%s'" % etype for etype in self.current_etypes)
                cu = self._session.system_sql(sql % (self.source.uri, etypes,
                                                      operator, eid))
                # XXX buggy cu.rowcount which may be zero while there are some
                # results
                rows = cu.fetchall()
                if rows:
                    # the first column of each row is base64 decoded to get
                    # the external id
                    raise ReplaceByInOperator((b64decode(r[0]) for r in rows))