server/utils.py
author Aurelien Campeas <aurelien.campeas@logilab.fr>
Wed, 06 Jun 2012 10:26:34 +0200
changeset 8433 ff9d6d269877
parent 8425 b86bdc343c18
child 8446 cae198371548
permissions -rw-r--r--
[server/session,repo] turn InternalSession, hence repo.internal_session, into a context manager (closes #2393651)

# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
"""Some utilities for the CubicWeb server."""

__docformat__ = "restructuredtext en"

import sys
import logging
from threading import Timer, Thread
from getpass import getpass

from passlib.utils import handlers as uh, to_hash_str
from passlib.context import CryptContext

from cubicweb.md5crypt import crypt as md5crypt


class CustomMD5Crypt(uh.HasSalt, uh.GenericHandler):
    """passlib handler delegating the hashing to `cubicweb.md5crypt`.

    The string form produced by `to_string` is ``<salt>$<checksum>``;
    `from_string` parses that same form back into a handler instance.
    """
    name = 'cubicwebmd5crypt'
    setting_kwds = ('salt',)
    salt_chars = uh.H64_CHARS
    min_salt_size = 0
    max_salt_size = 8

    @classmethod
    def from_string(cls, hash):
        """Build a handler from a stored hash string.

        Raise `ValueError` when the parsed hash carries no checksum.
        """
        salt, checksum = uh.parse_mc2(hash, u'')
        if checksum is None:
            raise ValueError('missing checksum')
        return cls(salt=salt, checksum=checksum)

    def to_string(self):
        """Serialize the salt and the checksum (empty when there is none)."""
        checksum = self.checksum or u''
        return to_hash_str(u'%s$%s' % (self.salt, checksum))

    # passlib 1.5 wants calc_checksum, 1.6 wants _calc_checksum
    def calc_checksum(self, secret):
        """Return the md5crypt digest of `secret` for this handler's salt."""
        ascii_salt = self.salt.encode('ascii')
        return md5crypt(secret, ascii_salt).decode('utf-8')
    _calc_checksum = calc_checksum

# Shared passlib context used by `crypt_password` to hash and verify passwords;
# the list names every scheme the context knows about.
# NOTE(review): passlib documents the first listed scheme as the default for
# new hashes when no explicit default is given -- confirm for the passlib
# version in use.
_CRYPTO_CTX = CryptContext(['sha512_crypt', CustomMD5Crypt, 'des_crypt', 'ldap_salted_sha1'])

def crypt_password(passwd, salt=None):
    """Hash `passwd`, or check it against an already stored hash.

    :param passwd: the clear-text password
    :param salt: `None` to get a freshly computed hash of `passwd`;
      otherwise the stored hash that `passwd` must be checked against

    :return: a new hash when `salt` is None. Else `salt` itself when it is
      empty or when `passwd` matches it, and the empty string when the
      password is wrong or the stored hash is not recognized.
    """
    if salt is None:
        return _CRYPTO_CTX.encrypt(passwd)
    # empty hash, accept any password for backwards compat
    if salt == '':
        return salt
    try:
        if _CRYPTO_CTX.verify(passwd, salt):
            return salt
    except ValueError:
        # the stored hash could not be identified by any configured scheme:
        # nothing can match it, so handle it like a wrong password instead
        # of letting the error escape to the caller
        pass
    # wrong password
    return ''

def cartesian_product(seqin):
    """return a generator over the cartesian product of the sequences
    listed in `seqin`, each combination being yielded as a list

    for more details, see :
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/302478
    """
    def expand(remaining, prefix):
        """recursively extend `prefix` with an item of each remaining
        sequence
        """
        if not remaining:
            # every sequence contributed an item: the combination is complete
            yield prefix
            return
        head, tail = remaining[0], remaining[1:]
        for element in head:
            # recurse on the other sequences with the extended combination
            for combination in expand(tail, prefix + [element]):
                yield combination
    return expand(seqin, [])


def cleanup_solutions(rqlst, solutions):
    """Remove in place, from each solution, the variables unknown to `rqlst`.

    :param rqlst: object exposing `defined_vars` and `aliases` containers
    :param solutions: iterable of dictionaries keyed by variable name; every
      key found neither in `rqlst.defined_vars` nor in `rqlst.aliases` is
      deleted
    """
    for sol in solutions:
        # loop over a snapshot of the keys since entries are deleted on the
        # way: iterating a live dictionary view while deleting would fail
        for vname in list(sol):
            if not (vname in rqlst.defined_vars or vname in rqlst.aliases):
                del sol[vname]


def eschema_eid(session, eschema):
    """return the eid of the CWEType entity matching the given yams type

    Use this instead of reading `eschema.eid` directly: when the schema has
    been loaded from the file-system rather than from the database (e.g.
    during tests), `eschema.eid` is not set. In that case it is fetched
    through `session` and stored on `eschema` for later calls.
    """
    if eschema.eid is not None:
        return eschema.eid
    rset = session.execute('Any X WHERE X is CWEType, X name %(name)s',
                           {'name': str(eschema)})
    eschema.eid = rset[0][0]
    return eschema.eid


DEFAULT_MSG = ('we need a manager connection on the repository '
               '(the server doesn\'t have to run, even should better not)')

def manager_userpasswd(user=None, msg=DEFAULT_MSG, confirm=False,
                       passwdmsg='password'):
    """Interactively ask for a login and a password on the terminal.

    :param user: login to use; when empty it is prompted for until a
      non-empty answer is given, then decoded to unicode. A given login is
      returned untouched.
    :param msg: message printed before the login prompt (skipped when empty
      or when `user` is given)
    :param confirm: when true, ask for the password twice until both
      entries are the same
    :param passwdmsg: label of the password prompt

    :return: a `(user, passwd)` tuple
    """
    if not user:
        if msg:
            print(msg)
        while not user:
            user = raw_input('login: ')
        # stdin declares no encoding when it is not a terminal (e.g. piped
        # input): fall back on the interpreter default rather than handing
        # None over to unicode()
        encoding = getattr(sys.stdin, 'encoding', None) or sys.getdefaultencoding()
        user = unicode(user, encoding)
    passwd = getpass('%s: ' % passwdmsg)
    if confirm:
        while True:
            passwd2 = getpass('confirm password: ')
            if passwd == passwd2:
                break
            print('password doesn\'t match')
            # ask again with the same label as the first prompt
            passwd = getpass('%s: ' % passwdmsg)
    # XXX decode password using stdin encoding then encode it using appl'encoding
    return user, passwd


# sentinel telling "attribute is missing" apart from any actual value
_MARKER = object()
def func_name(func):
    """return a name for `func`: its `__name__` attribute if it has one,
    else its `func_name` attribute, else its repr
    """
    for attribute in ('__name__', 'func_name'):
        name = getattr(func, attribute, _MARKER)
        if name is not _MARKER:
            return name
    return repr(func)

class LoopTask(object):
    """threaded task restarting itself once executed

    Each run happens in a fresh `Timer` thread firing after `interval`
    seconds. When a run is over, the next one is scheduled as long as the
    tasks manager is still flagged as running.
    """
    def __init__(self, tasks_manager, interval, func, args):
        """
        :param tasks_manager: object whose `running` attribute tells whether
          the task may reschedule itself
        :param interval: delay in seconds before each run, must be > 0
        :param func: the callable to run
        :param args: positional arguments given to `func`

        :raise ValueError: if `interval` is not strictly positive
        """
        if interval <= 0:
            raise ValueError('Loop task interval must be > 0 '
                             '(current value: %f for %s)' % \
                             (interval, func_name(func)))
        self._tasks_manager = tasks_manager
        self.interval = interval
        # no timer exists until start() is called
        self._t = None
        def auto_restart_func(self=self, func=func, args=args):
            restart = True
            try:
                func(*args)
            except Exception:
                # logged and propagated, though the task is still rescheduled
                logger = logging.getLogger('cubicweb.repository')
                logger.exception('Unhandled exception in LoopTask %s', self.name)
                raise
            except BaseException:
                # anything else (not an Exception subclass) is swallowed and
                # prevents the task from being rescheduled
                restart = False
            finally:
                if restart and tasks_manager.running:
                    self.start()
        self.func = auto_restart_func
        self.name = func_name(func)

    def __str__(self):
        return '%s (%s seconds)' % (self.name, self.interval)

    def start(self):
        """Schedule the next run, `interval` seconds from now."""
        timer = Timer(self.interval, self.func)
        # tag the thread name with the task name and interval
        timer.name = '%s-%s[%d]' % (timer.name, self.name, self.interval)
        self._t = timer
        timer.start()

    def cancel(self):
        """Cancel the pending timer, if any."""
        if self._t is not None:
            self._t.cancel()

    def join(self):
        """Wait for the current timer thread, if any, to terminate."""
        # is_alive() rather than its legacy isAlive() alias, which recent
        # Python versions no longer provide
        if self._t is not None and self._t.is_alive():
            self._t.join()


class RepoThread(Thread):
    """subclass of thread so it auto remove itself from a given list once
    executed

    The thread appends itself to `running_threads` when started and removes
    itself once its target has returned or raised.
    """
    def __init__(self, target, running_threads):
        """
        :param target: callable without argument to run in the thread
        :param running_threads: list in which the thread registers itself
          while it is running
        """
        def auto_remove_func(self=self, func=target):
            try:
                func()
            except Exception:
                logger = logging.getLogger('cubicweb.repository')
                logger.exception('Unhandled exception in RepoThread %s', self._name)
                raise
            finally:
                # unregister whatever the outcome of the target
                self.running_threads.remove(self)
        Thread.__init__(self, target=auto_remove_func)
        self.running_threads = running_threads
        self._name = func_name(target)

    def start(self):
        """Flag the thread as daemon, register it and start it.

        A thread which fails to start is not left in `running_threads`.
        """
        # may raise (e.g. thread already started): do it before registering
        self.daemon = True
        # register before starting, since the thread unregisters itself as
        # soon as its target is done
        self.running_threads.append(self)
        try:
            Thread.start(self)
        except Exception:
            # never started, hence it will never unregister by itself
            self.running_threads.remove(self)
            raise

    def getName(self):
        """Return the target name followed by the regular thread name."""
        return '%s(%s)' % (self._name, Thread.getName(self))

class TasksManager(object):
    """Object dedicated to managing background tasks"""

    def __init__(self):
        # true from start() until stop(); looping tasks check it before
        # rescheduling themselves
        self.running = False
        # (interval, func, args) tuples registered but not started yet
        self._tasks = []
        # LoopTask instances created by start()
        self._looping_tasks = []

    def add_looping_task(self, interval, func, *args):
        """register a function to be called every `interval` seconds.

        looping tasks can only be registered during repository initialization,
        once done this method will fail.

        :raise RuntimeError: if the manager is currently running
        """
        if self.running:
            raise RuntimeError("can't add looping task once the repository is started")
        self._tasks.append( (interval, func, args) )

    def start(self):
        """Start running the registered looping tasks"""
        assert not self.running # bw compat purpose mainly

        while self._tasks:
            interval, func, args = self._tasks.pop()
            task = LoopTask(self, interval, func, args)
            self._looping_tasks.append(task)
            self.info('starting task %s with interval %.2fs', task.name,
                      interval)
            task.start()

        self.running = True

    def stop(self):
        """Stop all running tasks.

        returns when all tasks have been canceled and none is running anymore
        """
        if self.running:
            # clear the flag first: a task being executed right now checks it
            # once done, and would otherwise reschedule itself behind the
            # cancel()/join() below
            self.running = False
            while self._looping_tasks:
                looptask = self._looping_tasks.pop()
                self.info('canceling task %s...', looptask.name)
                looptask.cancel()
                looptask.join()
                self.info('task %s finished', looptask.name)

from logging import getLogger
from cubicweb import set_log_methods
# Hook TasksManager up with the 'cubicweb.repository' logger.
# NOTE(review): presumably this is what provides the `info` method that
# TasksManager calls on itself -- confirm in cubicweb.set_log_methods.
set_log_methods(TasksManager, getLogger('cubicweb.repository'))