[server/session,repo] turn InternalSession, hence repo.internal_session, into a context manager (closes #2393651)
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of CubicWeb.
#
# CubicWeb is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option)
# any later version.
#
# CubicWeb is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with CubicWeb. If not, see <http://www.gnu.org/licenses/>.
"""Some utilities for the CubicWeb server."""
__docformat__ = "restructuredtext en"
import sys
import logging
from threading import Timer, Thread
from getpass import getpass
from passlib.utils import handlers as uh, to_hash_str
from passlib.context import CryptContext
from cubicweb.md5crypt import crypt as md5crypt
class CustomMD5Crypt(uh.HasSalt, uh.GenericHandler):
    """passlib handler for md5crypt hashes serialized as ``<salt>$<checksum>``.

    The serialized form carries no scheme prefix. The checksum is computed by
    ``cubicweb.md5crypt.crypt`` from the secret and the ascii-encoded salt.
    """
    name = 'cubicwebmd5crypt'
    setting_kwds = ('salt',)
    min_salt_size = 0
    max_salt_size = 8
    salt_chars = uh.H64_CHARS

    @classmethod
    def from_string(cls, hash):
        # empty prefix: the stored value starts directly with the salt
        salt, checksum = uh.parse_mc2(hash, u'')
        if checksum is None:
            raise ValueError('missing checksum')
        return cls(salt=salt, checksum=checksum)

    def to_string(self):
        checksum = self.checksum or u''
        return to_hash_str(u'%s$%s' % (self.salt, checksum))

    # passlib 1.5 wants calc_checksum, 1.6 wants _calc_checksum
    def calc_checksum(self, secret):
        encoded_salt = self.salt.encode('ascii')
        return md5crypt(secret, encoded_salt).decode('utf-8')
    _calc_checksum = calc_checksum
# Password hashing context used by crypt_password() to create new hashes and
# to verify stored ones against the schemes listed below.
# NOTE(review): passlib presumably hashes new passwords with the first scheme
# (sha512_crypt) and accepts any listed scheme on verify -- confirm in the
# passlib CryptContext documentation before reordering.
_CRYPTO_CTX = CryptContext([
    'sha512_crypt',
    CustomMD5Crypt,
    'des_crypt',
    'ldap_salted_sha1',
])
def crypt_password(passwd, salt=None):
    """return the encrypted password using the given salt or a generated one

    :param passwd: the clear-text password.
    :param salt: ``None`` to compute a fresh hash of `passwd`; otherwise a
      previously stored hash that `passwd` is checked against.
    :return: the new hash when `salt` is None. Otherwise `salt` itself when
      `passwd` matches it (or when `salt` is the empty string, accepted for
      backward compatibility), and the empty string when it does not match
      or when `salt` is not a hash any configured scheme can handle.
    """
    if salt is None:
        return _CRYPTO_CTX.encrypt(passwd)
    # empty hash, accept any password for backwards compat
    if salt == '':
        return salt
    try:
        if _CRYPTO_CTX.verify(passwd, salt):
            return salt
    except ValueError:
        # the stored value is not recognized by any configured scheme
        # (corrupted or unsupported data): report a mismatch instead of
        # letting the authentication attempt crash
        pass
    # wrong password
    return ''
def cartesian_product(seqin):
    """Lazily enumerate the cartesian product of the sequences in `seqin`.

    Returns a generator yielding one list per combination, built by taking
    one item from each sequence in order; the first sequence varies slowest.
    An empty `seqin` yields a single empty list, and any empty sequence inside
    `seqin` makes the generator yield nothing.

    Adapted from ASPN Python Cookbook recipe 302478.
    """
    def _expand(remaining, prefix):
        """yield every completion of `prefix` using the `remaining` sequences"""
        if not remaining:
            # no sequence left to pick from: the combination is complete
            yield prefix
            return
        for element in remaining[0]:
            for combination in _expand(remaining[1:], prefix + [element]):
                yield combination
    return _expand(seqin, [])
def cleanup_solutions(rqlst, solutions):
    """Remove, in place, solution entries for variables unknown to `rqlst`.

    Each solution is a mapping keyed by variable name. Every key that is in
    neither ``rqlst.defined_vars`` nor ``rqlst.aliases`` is deleted from it.
    """
    for sol in solutions:
        # Collect the keys first: deleting while iterating the mapping itself
        # (or a live key view, as ``keys()`` returns on Python 3) raises
        # RuntimeError.
        unknown = [vname for vname in sol
                   if vname not in rqlst.defined_vars
                   and vname not in rqlst.aliases]
        for vname in unknown:
            del sol[vname]
def eschema_eid(session, eschema):
    """Return the eid of the CWEType entity matching the yams type `eschema`.

    ``eschema.eid`` is unset (None) when the schema was loaded from the
    file-system rather than from the database (e.g. during tests), so use
    this helper instead of reading the attribute directly. On a cache miss the
    eid is looked up through `session` and stored back on ``eschema.eid``.
    """
    if eschema.eid is not None:
        return eschema.eid
    rset = session.execute('Any X WHERE X is CWEType, X name %(name)s',
                           {'name': str(eschema)})
    eschema.eid = rset[0][0]
    return eschema.eid
# Default message printed by manager_userpasswd() before it prompts for a
# login. The trailing backslash continues the string literal onto the next
# physical line, so the value is a single line of text.
DEFAULT_MSG = 'we need a manager connection on the repository \
(the server doesn\'t have to run, even should better not)'
def manager_userpasswd(user=None, msg=DEFAULT_MSG, confirm=False,
                       passwdmsg='password'):
    """Interactively ask for manager credentials on the terminal.

    :param user: login to use; prompted for (until non-empty) when falsy.
    :param msg: text printed before the login prompt, only when the login has
      to be prompted for; pass a falsy value to print nothing.
    :param confirm: when true, ask for the password a second time and start
      over until both entries match.
    :param passwdmsg: label of the password prompt.
    :return: a ``(user, passwd)`` tuple. A prompted login is decoded to
      unicode; a login given by the caller is returned unchanged.
    """
    if not user:
        if msg:
            print(msg)
        while not user:
            user = raw_input('login: ')
        # sys.stdin.encoding may be None (e.g. Python 2 with stdin redirected
        # from a pipe or file); fall back to utf-8 rather than letting
        # unicode() fail with a TypeError
        user = unicode(user, sys.stdin.encoding or 'utf-8')
    passwd = getpass('%s: ' % passwdmsg)
    if confirm:
        while True:
            passwd2 = getpass('confirm password: ')
            if passwd == passwd2:
                break
            print('password doesn\'t match')
            # re-prompt with the caller's label, as for the first attempt
            passwd = getpass('%s: ' % passwdmsg)
    # XXX decode password using stdin encoding then encode it using the
    # application's encoding
    return user, passwd
# private sentinel distinguishing "attribute missing" from any real value
_MARKER = object()

def func_name(func):
    """Return a printable name for the callable `func`.

    ``__name__`` is tried first, then ``func_name``; when neither attribute
    exists, ``repr(func)`` is returned.
    """
    for attrname in ('__name__', 'func_name'):
        name = getattr(func, attrname, _MARKER)
        if name is not _MARKER:
            return name
    return repr(func)
class LoopTask(object):
    """threaded task restarting itself once executed

    Each run is scheduled with a :class:`threading.Timer`. After the wrapped
    function returns or raises an ``Exception``, the task schedules its next
    run as long as its tasks manager's ``running`` flag is true.
    """
    # Timer of the currently scheduled run; None until start() is called, so
    # that cancel()/join() are safe no-ops on a task that never started.
    _t = None

    def __init__(self, tasks_manager, interval, func, args):
        if interval <= 0:
            raise ValueError('Loop task interval must be > 0 '
                             '(current value: %f for %s)' % \
                             (interval, func_name(func)))
        self._tasks_manager = tasks_manager
        self.interval = interval
        def auto_restart_func(self=self, func=func, args=args):
            restart = True
            try:
                func(*args)
            except Exception:
                logger = logging.getLogger('cubicweb.repository')
                logger.exception('Unhandled exception in LoopTask %s', self.name)
                raise
            except BaseException:
                # non-Exception errors (e.g. SystemExit) stop the loop
                # without being propagated
                restart = False
            finally:
                if restart and tasks_manager.running:
                    self.start()
        self.func = auto_restart_func
        self.name = func_name(func)

    def __str__(self):
        return '%s (%s seconds)' % (self.name, self.interval)

    def start(self):
        """schedule the next run `interval` seconds from now"""
        self._t = Timer(self.interval, self.func)
        # Thread.name (2.6+) replaces the deprecated getName()/setName()
        self._t.name = '%s-%s[%d]' % (self._t.name, self.name, self.interval)
        self._t.start()

    def cancel(self):
        """cancel the pending run, if any"""
        if self._t is not None:
            self._t.cancel()

    def join(self):
        """wait for the current run's thread to finish, if it is alive"""
        # is_alive() (2.6+) replaces isAlive(), which Python 3.9 removed
        if self._t is not None and self._t.is_alive():
            self._t.join()
class RepoThread(Thread):
    """subclass of thread so it auto remove itself from a given list once
    executed

    :meth:`start` appends the thread to `running_threads` and marks it as a
    daemon. The thread removes itself from that list when `target` returns or
    raises; exceptions are logged to the 'cubicweb.repository' logger and then
    re-raised.
    """
    def __init__(self, target, running_threads):
        def auto_remove_func(self=self, func=target):
            try:
                func()
            except Exception:
                logger = logging.getLogger('cubicweb.repository')
                logger.exception('Unhandled exception in RepoThread %s', self._name)
                raise
            finally:
                self.running_threads.remove(self)
        Thread.__init__(self, target=auto_remove_func)
        self.running_threads = running_threads
        self._name = func_name(target)

    def start(self):
        # register before starting, so the thread can never try to remove
        # itself from the list before it has been added to it
        self.running_threads.append(self)
        try:
            self.daemon = True
            Thread.start(self)
        except BaseException:
            # the target never ran, so auto_remove_func will not unregister
            # this thread: undo the registration ourselves
            self.running_threads.remove(self)
            raise

    def getName(self):
        """return '<target name>(<thread name>)'"""
        return '%s(%s)' % (self._name, Thread.getName(self))
class TasksManager(object):
    """Object dedicated to managing background looping tasks.

    Tasks are registered with :meth:`add_looping_task` before :meth:`start`;
    :meth:`stop` cancels them and waits for them to finish.
    """

    def __init__(self):
        self.running = False
        # (interval, func, args) tuples waiting for start()
        self._tasks = []
        # LoopTask instances created by start()
        self._looping_tasks = []

    def add_looping_task(self, interval, func, *args):
        """register a function to be called every `interval` seconds.

        looping tasks can only be registered during repository initialization,
        once done this method will fail.
        """
        if self.running:
            raise RuntimeError("can't add looping task once the repository is started")
        self._tasks.append((interval, func, args))

    def start(self):
        """Start running looping tasks."""
        assert not self.running  # bw compat purpose mainly
        # Set the flag *before* scheduling anything. A LoopTask reschedules
        # itself only while ``running`` is true, so a short-interval task
        # firing before the flag was set would silently run only once. It also
        # lets stop() cancel the tasks already started if creating a later one
        # fails (e.g. ValueError for a non-positive interval).
        self.running = True
        while self._tasks:
            interval, func, args = self._tasks.pop()
            task = LoopTask(self, interval, func, args)
            self._looping_tasks.append(task)
            self.info('starting task %s with interval %.2fs', task.name,
                      interval)
            task.start()

    def stop(self):
        """Stop all running tasks.

        Returns once every task has been cancelled and none is running
        anymore.
        """
        if not self.running:
            return
        # Clear the flag first. A task whose callback is executing right now
        # checks ``running`` when it finishes; if the flag were still true it
        # would schedule a fresh timer that nobody would ever cancel.
        self.running = False
        while self._looping_tasks:
            looptask = self._looping_tasks.pop()
            self.info('canceling task %s...', looptask.name)
            looptask.cancel()
            looptask.join()
            self.info('task %s finished', looptask.name)
from logging import getLogger
from cubicweb import set_log_methods
# TasksManager calls self.info() without defining it; set_log_methods
# presumably attaches such logging helpers to the class, routed to the
# 'cubicweb.repository' logger (confirm in cubicweb.set_log_methods).
set_log_methods(TasksManager,
                logging.getLogger('cubicweb.repository'))