cubicweb/server/utils.py
changeset 11057 0b59724cb3f2
parent 10907 9ae707db5265
child 11129 97095348b3ee
equal deleted inserted replaced
11052:058bb3dc685f 11057:0b59724cb3f2
       
     1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
       
     2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
       
     3 #
       
     4 # This file is part of CubicWeb.
       
     5 #
       
     6 # CubicWeb is free software: you can redistribute it and/or modify it under the
       
     7 # terms of the GNU Lesser General Public License as published by the Free
       
     8 # Software Foundation, either version 2.1 of the License, or (at your option)
       
     9 # any later version.
       
    10 #
       
    11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT
       
    12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
       
    13 # FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
       
    14 # details.
       
    15 #
       
    16 # You should have received a copy of the GNU Lesser General Public License along
       
    17 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
       
    18 """Some utilities for the CubicWeb server."""
       
    19 from __future__ import print_function
       
    20 
       
    21 __docformat__ = "restructuredtext en"
       
    22 
       
    23 import sys
       
    24 import logging
       
    25 from threading import Timer, Thread
       
    26 from getpass import getpass
       
    27 
       
    28 from six import PY2, text_type
       
    29 from six.moves import input
       
    30 
       
    31 from passlib.utils import handlers as uh, to_hash_str
       
    32 from passlib.context import CryptContext
       
    33 
       
    34 from cubicweb.md5crypt import crypt as md5crypt
       
    35 
       
    36 
       
    37 class CustomMD5Crypt(uh.HasSalt, uh.GenericHandler):
       
    38     name = 'cubicwebmd5crypt'
       
    39     setting_kwds = ('salt',)
       
    40     min_salt_size = 0
       
    41     max_salt_size = 8
       
    42     salt_chars = uh.H64_CHARS
       
    43 
       
    44     @classmethod
       
    45     def from_string(cls, hash):
       
    46         salt, chk = uh.parse_mc2(hash, u'')
       
    47         if chk is None:
       
    48             raise ValueError('missing checksum')
       
    49         return cls(salt=salt, checksum=chk)
       
    50 
       
    51     def to_string(self):
       
    52         return to_hash_str(u'%s$%s' % (self.salt, self.checksum or u''))
       
    53 
       
    54     # passlib 1.5 wants calc_checksum, 1.6 wants _calc_checksum
       
    55     def calc_checksum(self, secret):
       
    56         return md5crypt(secret, self.salt.encode('ascii')).decode('utf-8')
       
    57     _calc_checksum = calc_checksum
       
    58 
       
    59 _CRYPTO_CTX = CryptContext(['sha512_crypt', CustomMD5Crypt, 'des_crypt', 'ldap_salted_sha1'],
       
    60                            deprecated=['cubicwebmd5crypt', 'des_crypt'])
       
    61 verify_and_update = _CRYPTO_CTX.verify_and_update
       
    62 
       
    63 def crypt_password(passwd, salt=None):
       
    64     """return the encrypted password using the given salt or a generated one
       
    65     """
       
    66     if salt is None:
       
    67         return _CRYPTO_CTX.encrypt(passwd).encode('ascii')
       
    68     # empty hash, accept any password for backwards compat
       
    69     if salt == '':
       
    70         return salt
       
    71     try:
       
    72         if _CRYPTO_CTX.verify(passwd, salt):
       
    73             return salt
       
    74     except ValueError: # e.g. couldn't identify hash
       
    75         pass
       
    76     # wrong password
       
    77     return b''
       
    78 
       
    79 
       
    80 def eschema_eid(cnx, eschema):
       
    81     """get eid of the CWEType entity for the given yams type. You should use
       
    82     this because when schema has been loaded from the file-system, not from the
       
    83     database, (e.g. during tests), eschema.eid is not set.
       
    84     """
       
    85     if eschema.eid is None:
       
    86         eschema.eid = cnx.execute(
       
    87             'Any X WHERE X is CWEType, X name %(name)s',
       
    88             {'name': text_type(eschema)})[0][0]
       
    89     return eschema.eid
       
    90 
       
    91 
       
    92 DEFAULT_MSG = 'we need a manager connection on the repository \
       
    93 (the server doesn\'t have to run, even should better not)'
       
    94 
       
    95 def manager_userpasswd(user=None, msg=DEFAULT_MSG, confirm=False,
       
    96                        passwdmsg='password'):
       
    97     if not user:
       
    98         if msg:
       
    99             print(msg)
       
   100         while not user:
       
   101             user = input('login: ')
       
   102         if PY2:
       
   103             user = unicode(user, sys.stdin.encoding)
       
   104     passwd = getpass('%s: ' % passwdmsg)
       
   105     if confirm:
       
   106         while True:
       
   107             passwd2 = getpass('confirm password: ')
       
   108             if passwd == passwd2:
       
   109                 break
       
   110             print('password doesn\'t match')
       
   111             passwd = getpass('password: ')
       
   112     # XXX decode password using stdin encoding then encode it using appl'encoding
       
   113     return user, passwd
       
   114 
       
   115 
       
   116 _MARKER = object()
       
   117 def func_name(func):
       
   118     name = getattr(func, '__name__', _MARKER)
       
   119     if name is _MARKER:
       
   120         name = getattr(func, 'func_name', _MARKER)
       
   121     if name is _MARKER:
       
   122         name = repr(func)
       
   123     return name
       
   124 
       
   125 class LoopTask(object):
       
   126     """threaded task restarting itself once executed"""
       
   127     def __init__(self, tasks_manager, interval, func, args):
       
   128         if interval < 0:
       
   129             raise ValueError('Loop task interval must be >= 0 '
       
   130                              '(current value: %f for %s)' % \
       
   131                              (interval, func_name(func)))
       
   132         self._tasks_manager = tasks_manager
       
   133         self.interval = interval
       
   134         def auto_restart_func(self=self, func=func, args=args):
       
   135             restart = True
       
   136             try:
       
   137                 func(*args)
       
   138             except Exception:
       
   139                 logger = logging.getLogger('cubicweb.repository')
       
   140                 logger.exception('Unhandled exception in LoopTask %s', self.name)
       
   141                 raise
       
   142             except BaseException:
       
   143                 restart = False
       
   144             finally:
       
   145                 if restart and tasks_manager.running:
       
   146                     self.start()
       
   147         self.func = auto_restart_func
       
   148         self.name = func_name(func)
       
   149 
       
   150     def __str__(self):
       
   151         return '%s (%s seconds)' % (self.name, self.interval)
       
   152 
       
   153     def start(self):
       
   154         self._t = Timer(self.interval, self.func)
       
   155         self._t.setName('%s-%s[%d]' % (self._t.getName(), self.name, self.interval))
       
   156         self._t.start()
       
   157 
       
   158     def cancel(self):
       
   159         self._t.cancel()
       
   160 
       
   161     def join(self):
       
   162         if self._t.isAlive():
       
   163             self._t.join()
       
   164 
       
   165 
       
   166 class RepoThread(Thread):
       
   167     """subclass of thread so it auto remove itself from a given list once
       
   168     executed
       
   169     """
       
   170     def __init__(self, target, running_threads):
       
   171         def auto_remove_func(self=self, func=target):
       
   172             try:
       
   173                 func()
       
   174             except Exception:
       
   175                 logger = logging.getLogger('cubicweb.repository')
       
   176                 logger.exception('Unhandled exception in RepoThread %s', self._name)
       
   177                 raise
       
   178             finally:
       
   179                 self.running_threads.remove(self)
       
   180         Thread.__init__(self, target=auto_remove_func)
       
   181         self.running_threads = running_threads
       
   182         self._name = func_name(target)
       
   183 
       
   184     def start(self):
       
   185         self.running_threads.append(self)
       
   186         self.daemon = True
       
   187         Thread.start(self)
       
   188 
       
   189     def getName(self):
       
   190         return '%s(%s)' % (self._name, Thread.getName(self))
       
   191 
       
   192 class TasksManager(object):
       
   193     """Object dedicated manage background task"""
       
   194 
       
   195     def __init__(self):
       
   196         self.running = False
       
   197         self._tasks = []
       
   198         self._looping_tasks = []
       
   199 
       
   200     def add_looping_task(self, interval, func, *args):
       
   201         """register a function to be called every `interval` seconds.
       
   202 
       
   203         If interval is negative, no looping task is registered.
       
   204         """
       
   205         if interval < 0:
       
   206             self.debug('looping task %s ignored due to interval %f < 0',
       
   207                        func_name(func), interval)
       
   208             return
       
   209         task = LoopTask(self, interval, func, args)
       
   210         if self.running:
       
   211             self._start_task(task)
       
   212         else:
       
   213             self._tasks.append(task)
       
   214 
       
   215     def _start_task(self, task):
       
   216         self._looping_tasks.append(task)
       
   217         self.info('starting task %s with interval %.2fs', task.name,
       
   218                   task.interval)
       
   219         task.start()
       
   220 
       
   221     def start(self):
       
   222         """Start running looping task"""
       
   223         assert self.running == False # bw compat purpose maintly
       
   224         while self._tasks:
       
   225             task = self._tasks.pop()
       
   226             self._start_task(task)
       
   227         self.running = True
       
   228 
       
   229     def stop(self):
       
   230         """Stop all running task.
       
   231 
       
   232         returns when all task have been cancel and none are running anymore"""
       
   233         if self.running:
       
   234             while self._looping_tasks:
       
   235                 looptask = self._looping_tasks.pop()
       
   236                 self.info('canceling task %s...', looptask.name)
       
   237                 looptask.cancel()
       
   238                 looptask.join()
       
   239                 self.info('task %s finished', looptask.name)
       
   240 
       
# TasksManager calls self.debug() / self.info(), which its class body does
# not define; set_log_methods presumably attaches them, bound to the
# 'cubicweb.repository' logger (see cubicweb.set_log_methods).
from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(TasksManager, getLogger('cubicweb.repository'))