server/sources/pyrorql.py
author Sylvain Thénault <sylvain.thenault@logilab.fr>
Mon, 25 Oct 2010 17:15:51 +0200
changeset 6632 78878f5a8166
parent 6626 5c20a005bddc
child 6652 592c88c8f018
permissions -rw-r--r--
cleanup

# copyright 2003-2010 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""Source to query another RQL repository using pyro"""

__docformat__ = "restructuredtext en"
# `_` is bound to the plain `unicode` constructor: calling `_('...')` in this
# module only converts the literal to a unicode string (presumably it also
# serves as a marker for translatable messages -- to be confirmed against the
# i18n tooling)
_ = unicode

import threading
from os.path import join
from time import mktime
from datetime import datetime
from base64 import b64decode

from Pyro.errors import PyroError, ConnectionClosedError

from logilab.common.configuration import REQUIRED
from logilab.common.optik_ext import check_yn

from rql.nodes import Constant
from rql.utils import rqlvar_maker

from cubicweb import dbapi, server
from cubicweb import BadConnectionId, UnknownEid, ConnectionError
from cubicweb.cwconfig import register_persistent_options
from cubicweb.server.sources import (AbstractSource, ConnectionWrapper,
                                     TimedCache, dbg_st_search, dbg_results)
from cubicweb.server.msplanner import neged_relation

def uidtype(union, col, etype, args):
    """Return the `uidtype` attribute of the term selected at column `col`.

    The subquery holding the column is found using `union.locate_subquery`;
    None is returned when the selected term has no `uidtype` attribute.
    """
    located = union.locate_subquery(col, etype, args)
    subquery, subcol = located
    term = subquery.selection[subcol]
    return getattr(term, 'uidtype', None)

def load_mapping_file(mappingfile):
    """Execute the python file `mappingfile` and return the mapping it defines.

    The returned dictionary holds exactly four keys:

    * 'support_entities' (dict, must be defined by the file)
    * 'support_relations' (dict, defaults to an empty one)
    * 'dont_cross_relations' (set, 'owned_by' and 'created_by' are always
      added to it)
    * 'cross_relations' (set, defaults to an empty one)

    An AssertionError is raised when the file content doesn't satisfy the
    checks below. The file is executed as regular python code, so it must
    come from a trusted location.
    """
    mapping = {}
    execfile(mappingfile, mapping)
    # names which are only by-products of the execution
    mapping.pop('__builtins__', None)
    mapping.pop('__doc__', None)
    # optional definitions
    for key, default in (('support_relations', {}),
                         ('dont_cross_relations', set()),
                         ('cross_relations', set())):
        mapping.setdefault(key, default)

    # do some basic checks of the mapping content
    assert 'support_entities' in mapping, \
           'mapping file should at least define support_entities'
    for key, expected_type in (('support_entities', dict),
                               ('support_relations', dict),
                               ('dont_cross_relations', set),
                               ('cross_relations', set)):
        assert isinstance(mapping[key], expected_type)
    known = frozenset(('support_entities', 'support_relations',
                       'dont_cross_relations', 'cross_relations'))
    unknown = set(mapping) - known
    assert not unknown, 'unknown mapping attribute(s): %s' % unknown
    # relations that are necessarily not crossed
    mapping['dont_cross_relations'].update(('owned_by', 'created_by'))
    for rtype in ('is', 'is_instance_of', 'cw_source'):
        assert rtype not in mapping['dont_cross_relations'], \
               '%s relation should not be in dont_cross_relations' % rtype
        assert rtype not in mapping['support_relations'], \
               '%s relation should not be in support_relations' % rtype
    return mapping


class ReplaceByInOperator(Exception):
    """Exception raised by `RQL2RQL.eid2extid` when a restriction on an eid
    should be replaced by an ``IN (...)`` expression listing external ids.

    :param eids: iterable of external ids, stored as is on the `eids`
      attribute (`RQL2RQL.visit_relation` joins them with commas)
    """
    def __init__(self, eids):
        # feed the base class so that `args` is properly set: this is what
        # repr() and pickling of the exception rely on
        Exception.__init__(self, eids)
        self.eids = eids

class PyroRQLSource(AbstractSource):
    """External repository source, using Pyro connection"""

    # boolean telling if modification hooks should be called when something is
    # modified in this source
    should_call_hooks = False
    # boolean telling if the repository should connect to this source during
    # migration
    connect_for_migration = False

    # configuration options of the source: (option name, option definition)
    options = (
        # XXX pyro-ns host/port
        ('pyro-ns-id',
         {'type' : 'string',
          'default': REQUIRED,
          'help': 'identifier of the repository in the pyro name server',
          'group': 'pyro-source', 'level': 0,
          }),
        ('mapping-file',
         {'type' : 'string',
          'default': REQUIRED,
          'help': 'path to a python file with the schema mapping definition',
          'group': 'pyro-source', 'level': 1,
          }),
        ('cubicweb-user',
         {'type' : 'string',
          'default': REQUIRED,
          'help': 'user to use for connection on the distant repository',
          'group': 'pyro-source', 'level': 0,
          }),
        ('cubicweb-password',
         {'type' : 'password',
          'default': '',
          'help': 'password to use for connection on the distant repository',
          'group': 'pyro-source', 'level': 0,
          }),
        ('base-url',
         {'type' : 'string',
          'default': '',
          'help': 'url of the web site for the distant repository, if you want '
          'to generate external link to entities from this repository',
          'group': 'pyro-source', 'level': 1,
          }),
        ('skip-external-entities',
         {'type' : 'yn',
          'default': False,
          'help': 'should entities not local to the source be considered or not',
          'group': 'pyro-source', 'level': 0,
          }),
        ('pyro-ns-host',
         {'type' : 'string',
          'default': None,
          'help': 'Pyro name server\'s host. If not set, default to the value \
from all_in_one.conf. It may contains port information using <host>:<port> notation.',
          'group': 'pyro-source', 'level': 1,
          }),
        ('pyro-ns-group',
         {'type' : 'string',
          'default': None,
          'help': 'Pyro name server\'s group where the repository will be \
registered. If not set, default to the value from all_in_one.conf.',
          'group': 'pyro-source', 'level': 2,
          }),
        ('synchronization-interval',
         {'type' : 'int',
          'default': 5*60,
          'help': 'interval between synchronization with the external \
repository (default to 5 minutes).',
          'group': 'pyro-source', 'level': 2,
          }),

    )

    # 'base-url' is added to the keys inherited from AbstractSource
    PUBLIC_KEYS = AbstractSource.PUBLIC_KEYS + ('base-url',)
    _conn = None

    def __init__(self, repo, source_config, *args, **kwargs):
        """Initialize the source from its configuration.

        :param repo: the repository this source belongs to; its
          `config.apphome` is used to resolve a relative mapping file path
        :param source_config: configuration dictionary of the source (see
          `options`); 'base-url' is normalized in place so that it ends with
          a '/'

        The schema mapping is read from the 'mapping-file' option. When that
        file can't be read, the source is flagged as disabled and gets an
        empty mapping instead of raising.
        """
        AbstractSource.__init__(self, repo, source_config, *args, **kwargs)
        # os.path.join() discards the previous components when it meets an
        # absolute one: an absolute mapping file path is hence kept unchanged
        # while a relative one is resolved against the instance's home. This
        # also copes with an empty value, which used to raise IndexError
        mappingfile = join(repo.config.apphome, source_config['mapping-file'])
        try:
            mapping = load_mapping_file(mappingfile)
        except IOError:
            self.disabled = True
            self.error('cant read mapping file %s, source disabled',
                       mappingfile)
            self.support_entities = {}
            self.support_relations = {}
            self.dont_cross_relations = set()
            self.cross_relations = set()
        else:
            self.support_entities = mapping['support_entities']
            self.support_relations = mapping['support_relations']
            self.dont_cross_relations = mapping['dont_cross_relations']
            self.cross_relations = mapping['cross_relations']
        baseurl = source_config.get('base-url')
        if baseurl and not baseurl.endswith('/'):
            source_config['base-url'] += '/'
        self.config = source_config
        # persistent property used to remember when the source has been
        # synchronized for the last time
        myoptions = (('%s.latest-update-time' % self.uri,
                      {'type' : 'int', 'sitewide': True,
                       'default': 0,
                       'help': _('timestamp of the latest source synchronization.'),
                       'group': 'sources',
                       }),)
        register_persistent_options(myoptions)
        self._query_cache = TimedCache(1800)
        self._skip_externals = check_yn(None, 'skip-external-entities',
                                        source_config.get('skip-external-entities', False))

    def reset_caches(self):
        """Forget every cached query result by installing a brand new cache
        (method called during test).
        """
        fresh_cache = TimedCache(1800)
        self._query_cache = fresh_cache

    def last_update_time(self):
        """Return the date of the latest synchronization, as a datetime built
        from the timestamp stored in the source's 'latest-update-time'
        CWProperty.

        The property is created with value 0 (and committed) when it doesn't
        exist yet. An internal session is opened for the lookup and always
        closed before returning.
        """
        propkey = u'sources.%s.latest-update-time' % self.uri
        session = self.repo.internal_session()
        try:
            rset = session.execute('Any V WHERE X is CWProperty, X value V, X pkey %(k)s',
                                   {'k': propkey})
            if rset:
                assert len(rset) == 1
                timestamp = int(rset[0][0])
            else:
                # first time we're asked for it: record the property
                session.execute('INSERT CWProperty X: X pkey %(k)s, X value %(v)s',
                                {'k': propkey, 'v': u'0'})
                session.commit()
                timestamp = 0
            return datetime.fromtimestamp(timestamp)
        finally:
            session.close()

    def init(self):
        """Register the source's periodic tasks (method called by the
        repository once ready to handle request):

        * synchronization with the external repository, every
          'synchronization-interval' seconds (5 minutes by default)
        * removal of expired entries of the query cache
        """
        sync_interval = int(self.config.get('synchronization-interval', 5*60))
        self.repo.looping_task(sync_interval, self.synchronize)
        cache = self._query_cache
        cleanup_interval = cache.ttl.seconds/10
        self.repo.looping_task(cleanup_interval, cache.clear_expired)

    def map_entity_source(self, exturi):
        """Return True when an entity whose source is `exturi` in the external
        repository should be considered by this source, i.e. when:

        * it comes from the external repository's 'system' source, or
        * `exturi` isn't one of this repository's sources and external
          entities are not skipped
        """
        if exturi == 'system':
            return True
        if exturi in self.repo.sources_by_uri:
            return False
        return not self._skip_externals

    def synchronize(self, mtime=None):
        """synchronize content known by this repository with content in the
        external repository

        :param mtime: only consider entities modified since this date; when
          None, the date of the latest synchronization is used (see
          `last_update_time`)

        Entities reported as modified are (re)indexed through the system
        source, entities reported as deleted are removed from this
        repository when they are known here. An error on one entity is
        logged and doesn't stop the synchronization. At last, the date
        returned by the external repository is recorded as the latest
        synchronization date.

        Nothing is done when no connection to the external repository can be
        obtained.
        """
        self.info('synchronizing pyro source %s', self.uri)
        cnx = self.get_connection()
        try:
            extrepo = cnx._repo
        except AttributeError:
            # fake connection wrapper returned when we can't connect to the
            # external source (hence we've no chance to synchronize...)
            return
        etypes = self.support_entities.keys()
        if mtime is None:
            mtime = self.last_update_time()
        updatetime, modified, deleted = extrepo.entities_modified_since(etypes,
                                                                        mtime)
        self._query_cache.clear()
        repo = self.repo
        session = repo.internal_session()
        source = repo.system_source
        try:
            for etype, extid in modified:
                try:
                    exturi = cnx.describe(extid)[1]
                    if self.map_entity_source(exturi):
                        eid = self.extid2eid(str(extid), etype, session)
                        rset = session.eid_rset(eid, etype)
                        entity = rset.get_entity(0, 0)
                        entity.complete(entity.e_schema.indexable_attributes())
                        source.index_entity(session, entity)
                # catch Exception and not everything, so that
                # KeyboardInterrupt / SystemExit are not swallowed
                except Exception:
                    self.exception('while updating %s with external id %s of source %s',
                                   etype, extid, self.uri)
                    continue
            for etype, extid in deleted:
                try:
                    eid = self.extid2eid(str(extid), etype, session,
                                         insert=False)
                    # eid is None when the entity deleted from the external
                    # repository is not known here: nothing to do in that case
                    if eid is not None:
                        entity = session.entity_from_eid(eid, etype)
                        repo.delete_info(session, entity, self.uri, extid,
                                         scleanup=True)
                except Exception:
                    self.exception('while updating %s with external id %s of source %s',
                                   etype, extid, self.uri)
                    continue
            session.execute('SET X value %(v)s WHERE X pkey %(k)s',
                            {'k': u'sources.%s.latest-update-time' % self.uri,
                             'v': unicode(int(mktime(updatetime.timetuple())))})
            session.commit()
        finally:
            session.close()

    def _get_connection(self):
        """Open and return a new dbapi connection to the external repository.

        The pyro name server's host and group are taken from the source's
        configuration, falling back on the repository's configuration when
        they are not set there.
        """
        srcconfig = self.config
        nshost = srcconfig.get('pyro-ns-host')
        if not nshost:
            nshost = self.repo.config['pyro-ns-host']
        nsgroup = srcconfig.get('pyro-ns-group')
        if not nsgroup:
            nsgroup = self.repo.config['pyro-ns-group']
        nsid = srcconfig['pyro-ns-id']
        login = srcconfig['cubicweb-user']
        self.info('connecting to instance :%s.%s for user %s',
                  nsgroup, nsid, login)
        return dbapi.connect(database=nsid, login=login,
                             password=srcconfig['cubicweb-password'],
                             host=nshost, group=nsgroup,
                             setvreg=False)

    def get_connection(self):
        """Return a connection to the external repository, or a
        `ConnectionWrapper` instance when the connection can't be
        established (the failure is logged as critical).
        """
        try:
            cnx = self._get_connection()
        except (ConnectionError, PyroError):
            self.critical("can't get connection to source %s", self.uri,
                          exc_info=1)
            cnx = ConnectionWrapper()
        return cnx

    def check_connection(self, cnx):
        """Check `cnx` validity: return None when the connection is still
        usable, else the result of a new connection attempt (see
        `get_connection`).
        """
        # thread ownership has to be transferred manually. This is safe since
        # the pool holding the connection is given to a single session/thread
        # and can't be used simultaneously by another one
        try:
            cnx._repo._transferThread(threading.currentThread())
        except AttributeError:
            # inmemory connection
            pass
        if not isinstance(cnx, ConnectionWrapper):
            try:
                cnx.check()
            except (BadConnectionId, ConnectionClosedError):
                # connection lost, fall through to reconnection
                pass
            else:
                return None
        # either a fake or a broken connection: try to reconnect
        return self.get_connection()

    def syntax_tree_search(self, session, union, args=None, cachekey=None,
                           varmap=None):
        """Return the results of the `union` syntax tree for this source.

        Results are looked up in the query cache first, using the string
        form of the query (with `args` substituted) as key; on a miss they
        are computed by `_syntax_tree_search` and stored in the cache.
        """
        assert dbg_st_search(self.uri, union, varmap, args, cachekey)
        rqlkey = union.as_string(kwargs=args)
        try:
            results = self._query_cache[rqlkey]
        except KeyError:
            results = self._query_cache[rqlkey] = self._syntax_tree_search(
                session, union, args)
        assert dbg_results(results)
        return results

    def _syntax_tree_search(self, session, union, args):
        """return result from this source for a rql query (actually from a rql
        syntax tree and a solution dictionary mapping each used variable to a
        possible type).

        The query is translated by `RQL2RQL`, executed on the external
        repository, then external eids found in the results are replaced by
        the matching local eids. Rows holding an entity which should not be
        considered (see `map_entity_source`) are dropped.

        An empty list is returned when the source can't be reached, when the
        query refers to an eid unknown to this source or when the external
        execution fails (in the first and last cases, a message is stored in
        the session's 'sources_error' shared data).
        """
        # work on a copy, the translation may modify the arguments
        if not args is None:
            args = args.copy()
        # get cached cursor anyway
        cu = session.pool[self.uri]
        if cu is None:
            # this is a ConnectionWrapper instance
            msg = session._("can't connect to source %s, some data may be missing")
            session.set_shared_data('sources_error', msg % self.uri)
            return []
        try:
            rql = RQL2RQL(self).generate(session, union, args)
        except UnknownEid, ex:
            if server.DEBUG:
                print '  unknown eid', ex, 'no results'
            return []
        if server.DEBUG & server.DBG_RQL:
            print '  translated rql', rql
        try:
            rset = cu.execute(rql, args)
        except Exception, ex:
            self.exception(str(ex))
            msg = session._("error while querying source %s, some data may be missing")
            session.set_shared_data('sources_error', msg % self.uri)
            return []
        descr = rset.description
        if rset:
            # indexes of the columns holding eids to translate, computed from
            # the description of the first row
            needtranslation = []
            rows = rset.rows
            for i, etype in enumerate(descr[0]):
                if (etype is None or not self.schema.eschema(etype).final
                    or uidtype(union, i, etype, args)):
                    needtranslation.append(i)
            if needtranslation:
                cnx = session.pool.connection(self.uri)
                # go through rows backward, so that deleting a row doesn't
                # shift the index of rows which are still to be processed
                for rowindex in xrange(rset.rowcount - 1, -1, -1):
                    row = rows[rowindex]
                    for colindex in needtranslation:
                        if row[colindex] is not None: # optional variable
                            etype = descr[rowindex][colindex]
                            exttype, exturi, extid = cnx.describe(row[colindex])
                            if self.map_entity_source(exturi):
                                eid = self.extid2eid(str(row[colindex]), etype,
                                                     session)
                                row[colindex] = eid
                            else:
                                # skip this row
                                del rows[rowindex]
                                del descr[rowindex]
                                break
            results = rows
        else:
            results = []
        return results

    def _entity_relations_and_kwargs(self, session, entity):
        """Return a 2-uple (relations, kwargs) usable to write `entity`'s
        attributes to the external repository:

        * `relations` is a list of 'X <attr> %(<attr>)s' rql snippets, one
          per item of the entity
        * `kwargs` maps each attribute to its value, plus 'x' to the
          entity's external id
        """
        items = list(entity.iteritems())
        kwargs = {'x': self.eid2extid(entity.eid, session)}
        kwargs.update(items)
        relations = ['X %s %%(%s)s' % (attr, attr) for attr, _value in items]
        return relations, kwargs

    def add_entity(self, session, entity):
        """Adding a new entity to the source is not supported: always raise
        NotImplementedError.
        """
        raise NotImplementedError()

    def update_entity(self, session, entity):
        """Write `entity`'s attributes to the external repository, then drop
        the query cache and the entity's caches.
        """
        relations, kwargs = self._entity_relations_and_kwargs(session, entity)
        cursor = session.pool[self.uri]
        rql = 'SET %s WHERE X eid %%(x)s' % ','.join(relations)
        cursor.execute(rql, kwargs)
        self._query_cache.clear()
        entity.clear_all_caches()

    def delete_entity(self, session, entity):
        """Remove `entity` from the external repository, then drop the query
        cache.
        """
        cursor = session.pool[self.uri]
        rql = 'DELETE %s X WHERE X eid %%(x)s' % entity.__regid__
        kwargs = {'x': self.eid2extid(entity.eid, session)}
        cursor.execute(rql, kwargs)
        self._query_cache.clear()

    def add_relation(self, session, subject, rtype, object):
        """Add the `rtype` relation between the entities of eid `subject` and
        `object` in the external repository, then drop the query cache and
        the caches of both entities.
        """
        cursor = session.pool[self.uri]
        rql = 'SET X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype
        kwargs = {'x': self.eid2extid(subject, session),
                  'y': self.eid2extid(object, session)}
        cursor.execute(rql, kwargs)
        self._query_cache.clear()
        for eid in (subject, object):
            session.entity_from_eid(eid).clear_all_caches()

    def delete_relation(self, session, subject, rtype, object):
        """Remove the `rtype` relation between the entities of eid `subject`
        and `object` from the external repository, then drop the query cache
        and the caches of both entities.
        """
        cursor = session.pool[self.uri]
        rql = 'DELETE X %s Y WHERE X eid %%(x)s, Y eid %%(y)s' % rtype
        kwargs = {'x': self.eid2extid(subject, session),
                  'y': self.eid2extid(object, session)}
        cursor.execute(rql, kwargs)
        self._query_cache.clear()
        for eid in (subject, object):
            session.entity_from_eid(eid).clear_all_caches()


class RQL2RQL(object):
    """translate a local rql query to be executed on a distant repository"""
    def __init__(self, source):
        self.source = source
        self.current_operator = None

    def _accept_children(self, node):
        res = []
        for child in node.children:
            rql = child.accept(self)
            if rql is not None:
                res.append(rql)
        return res

    def generate(self, session, rqlst, args):
        self._session = session
        self.kwargs = args
        self.need_translation = False
        return self.visit_union(rqlst)

    def visit_union(self, node):
        s = self._accept_children(node)
        if len(s) > 1:
            return ' UNION '.join('(%s)' % q for q in s)
        return s[0]

    def visit_select(self, node):
        """Return the rql string for a select node.

        Clauses are emitted in the following order: selection, GROUPBY,
        ORDERBY, LIMIT, OFFSET, WHERE, HAVING then WITH. The variable maker
        and the constants cache are reset for each select.
        """
        self._varmaker = rqlvar_maker(defined=node.defined_vars.copy())
        self._const_var = {}
        if node.distinct:
            keyword = 'DISTINCT Any'
        else:
            keyword = 'Any'
        selected = [term.accept(self) for term in node.selection]
        clauses = ['%s %s' % (keyword, ','.join(selected))]
        if node.groupby:
            groups = [group.accept(self) for group in node.groupby]
            clauses.append('GROUPBY %s' % ', '.join(groups))
        if node.orderby:
            sortterms = [self.visit_sortterm(term) for term in node.orderby]
            clauses.append('ORDERBY %s' % ', '.join(sortterms))
        if node.limit is not None:
            clauses.append('LIMIT %s' % node.limit)
        if node.offset:
            clauses.append('OFFSET %s' % node.offset)
        if node.where is not None:
            restriction = node.where.accept(self)
            # the whole restriction may have been skipped by the translation
            if restriction is not None:
                clauses.append('WHERE %s' % restriction)

        if node.having:
            havings = [term.accept(self) for term in node.having]
            clauses.append('HAVING %s' % ', '.join(havings))
        subqueries = []
        for subquery in node.with_:
            aliases = ','.join([alias.name for alias in subquery.aliases])
            subqueries.append('%s BEING (%s)' % (aliases,
                                                 self.visit_union(subquery.query)))
        if subqueries:
            clauses.append('WITH %s' % (','.join(subqueries)))
        return ' '.join(clauses)

    def visit_and(self, node):
        res = self._accept_children(node)
        if res:
            return ', '.join(res)
        return

    def visit_or(self, node):
        res = self._accept_children(node)
        if len(res) > 1:
            return ' OR '.join('(%s)' % rql for rql in res)
        elif res:
            return res[0]
        return

    def visit_not(self, node):
        rql = node.children[0].accept(self)
        if rql:
            return 'NOT (%s)' % rql
        return

    def visit_exists(self, node):
        rql = node.children[0].accept(self)
        if rql:
            return 'EXISTS(%s)' % rql
        return

    def visit_relation(self, node):
        try:
            if isinstance(node.children[0], Constant):
                # simplified rqlst, reintroduce eid relation
                try:
                    restr, lhs = self.process_eid_const(node.children[0])
                except UnknownEid:
                    # can safely skip not relation with an unsupported eid
                    if neged_relation(node):
                        return
                    raise
            else:
                lhs = node.children[0].accept(self)
                restr = None
        except UnknownEid:
            # can safely skip not relation with an unsupported eid
            if neged_relation(node):
                return
            # XXX what about optional relation or outer NOT EXISTS()
            raise
        if node.optional in ('left', 'both'):
            lhs += '?'
        if node.r_type == 'eid' or not self.source.schema.rschema(node.r_type).final:
            self.need_translation = True
            self.current_operator = node.operator()
            if isinstance(node.children[0], Constant):
                self.current_etypes = (node.children[0].uidtype,)
            else:
                self.current_etypes = node.children[0].variable.stinfo['possibletypes']
        try:
            rhs = node.children[1].accept(self)
        except UnknownEid:
            # can safely skip not relation with an unsupported eid
            if neged_relation(node):
                return
            # XXX what about optional relation or outer NOT EXISTS()
            raise
        except ReplaceByInOperator, ex:
            rhs = 'IN (%s)' % ','.join(eid for eid in ex.eids)
        self.need_translation = False
        self.current_operator = None
        if node.optional in ('right', 'both'):
            rhs += '?'
        if restr is not None:
            return '%s %s %s, %s' % (lhs, node.r_type, rhs, restr)
        return '%s %s %s' % (lhs, node.r_type, rhs)

    def visit_comparison(self, node):
        if node.operator in ('=', 'IS'):
            return node.children[0].accept(self)
        return '%s %s' % (node.operator.encode(),
                          node.children[0].accept(self))

    def visit_mathexpression(self, node):
        return '(%s %s %s)' % (node.children[0].accept(self),
                               node.operator.encode(),
                               node.children[1].accept(self))

    def visit_function(self, node):
        #if node.name == 'IN':
        res = []
        for child in node.children:
            try:
                rql = child.accept(self)
            except UnknownEid, ex:
                continue
            res.append(rql)
        if not res:
            raise ex
        return '%s(%s)' % (node.name, ', '.join(res))

    def visit_constant(self, node):
        if self.need_translation or node.uidtype:
            if node.type == 'Int':
                return str(self.eid2extid(node.value))
            if node.type == 'Substitute':
                key = node.value
                # ensure we have not yet translated the value...
                if not key in self._const_var:
                    self.kwargs[key] = self.eid2extid(self.kwargs[key])
                    self._const_var[key] = None
        return node.as_string()

    def visit_variableref(self, node):
        """get the sql name for a variable reference"""
        return node.name

    def visit_sortterm(self, node):
        if node.asc:
            return node.term.accept(self)
        return '%s DESC' % node.term.accept(self)

    def process_eid_const(self, const):
        value = const.eval(self.kwargs)
        try:
            return None, self._const_var[value]
        except:
            var = self._varmaker.next()
            self.need_translation = True
            restr = '%s eid %s' % (var, self.visit_constant(const))
            self.need_translation = False
            self._const_var[value] = var
            return restr, var

    def eid2extid(self, eid):
        try:
            return self.source.eid2extid(eid, self._session)
        except UnknownEid:
            operator = self.current_operator
            if operator is not None and operator != '=':
                # deal with query like "X eid > 12"
                #
                # The problem is that eid order in the external source may
                # differ from the local source
                #
                # So search for all eids from this source matching the condition
                # locally and then to replace the "> 12" branch by "IN (eids)"
                #
                # XXX we may have to insert a huge number of eids...)
                sql = "SELECT extid FROM entities WHERE source='%s' AND type IN (%s) AND eid%s%s"
                etypes = ','.join("'%s'" % etype for etype in self.current_etypes)
                cu = self._session.system_sql(sql % (self.source.uri, etypes,
                                                      operator, eid))
                # XXX buggy cu.rowcount which may be zero while there are some
                # results
                rows = cu.fetchall()
                if rows:
                    raise ReplaceByInOperator((b64decode(r[0]) for r in rows))
            raise