1 # copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
|
2 # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
3 # |
|
4 # This file is part of CubicWeb. |
|
5 # |
|
6 # CubicWeb is free software: you can redistribute it and/or modify it under the |
|
7 # terms of the GNU Lesser General Public License as published by the Free |
|
8 # Software Foundation, either version 2.1 of the License, or (at your option) |
|
9 # any later version. |
|
10 # |
|
11 # CubicWeb is distributed in the hope that it will be useful, but WITHOUT |
|
12 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS |
|
13 # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more |
|
14 # details. |
|
15 # |
|
16 # You should have received a copy of the GNU Lesser General Public License along |
|
17 # with CubicWeb. If not, see <http://www.gnu.org/licenses/>. |
|
18 """Some utilities for the CubicWeb server.""" |
|
19 from __future__ import print_function |
|
20 |
|
21 __docformat__ = "restructuredtext en" |
|
22 |
|
23 import sys |
|
24 import logging |
|
25 from threading import Timer, Thread |
|
26 from getpass import getpass |
|
27 |
|
28 from six import PY2, text_type |
|
29 from six.moves import input |
|
30 |
|
31 from passlib.utils import handlers as uh, to_hash_str |
|
32 from passlib.context import CryptContext |
|
33 |
|
34 from cubicweb.md5crypt import crypt as md5crypt |
|
35 |
|
36 |
|
class CustomMD5Crypt(uh.HasSalt, uh.GenericHandler):
    """passlib handler exposing CubicWeb's own md5crypt under the scheme
    name ``cubicwebmd5crypt``.

    The string form of a hash is ``<salt>$<checksum>`` (see `to_string`); the
    salt holds from 0 to 8 characters taken from `uh.H64_CHARS`.
    """
    name = 'cubicwebmd5crypt'
    setting_kwds = ('salt',)
    salt_chars = uh.H64_CHARS
    min_salt_size = 0
    max_salt_size = 8

    @classmethod
    def from_string(cls, hash):
        """Build a handler instance from the string form of a hash.

        The string is split with `uh.parse_mc2` using an empty prefix.

        :raise ValueError: if no checksum part could be extracted
        """
        parsed_salt, parsed_checksum = uh.parse_mc2(hash, u'')
        if parsed_checksum is not None:
            return cls(salt=parsed_salt, checksum=parsed_checksum)
        raise ValueError('missing checksum')

    def to_string(self):
        """Return the ``<salt>$<checksum>`` form of this hash (the checksum
        part is empty when no checksum is set).
        """
        salt = self.salt
        checksum = self.checksum or u''
        return to_hash_str(u'%s$%s' % (salt, checksum))

    # passlib 1.5 wants calc_checksum, 1.6 wants _calc_checksum
    def calc_checksum(self, secret):
        """Return the md5crypt digest of `secret` for this handler's salt,
        decoded to text.
        """
        salt_bytes = self.salt.encode('ascii')
        digest = md5crypt(secret, salt_bytes)
        return digest.decode('utf-8')
    _calc_checksum = calc_checksum
|
58 |
|
# Module-wide passlib context used to hash and verify passwords. It knows
# about sha512_crypt, the CubicWeb md5crypt handler defined in this module,
# des_crypt and ldap_salted_sha1; the md5crypt and des_crypt schemes are
# flagged as deprecated.
# NOTE(review): presumably passlib uses the first listed scheme
# (sha512_crypt) for newly generated hashes -- confirm against the
# CryptContext documentation.
_CRYPTO_CTX = CryptContext(['sha512_crypt', CustomMD5Crypt, 'des_crypt', 'ldap_salted_sha1'],
                           deprecated=['cubicwebmd5crypt', 'des_crypt'])
# public shortcut to the context's `verify_and_update` method
verify_and_update = _CRYPTO_CTX.verify_and_update
|
62 |
|
def crypt_password(passwd, salt=None):
    """Hash `passwd`, or check it against an already stored hash.

    :param passwd: the clear text password
    :param salt: the stored hash to check `passwd` against, or None to
      generate a new hash
    :return:
      * without `salt`: a freshly generated hash, encoded to ascii bytes
      * with an empty `salt`: the empty `salt` itself
      * otherwise: `salt` itself when `passwd` verifies against it, else an
        empty byte string
    """
    if salt is None:
        fresh_hash = _CRYPTO_CTX.encrypt(passwd)
        return fresh_hash.encode('ascii')
    if salt == '':
        # empty hash, accept any password for backwards compat
        return salt
    try:
        matches = _CRYPTO_CTX.verify(passwd, salt)
    except ValueError:  # e.g. couldn't identify hash
        matches = False
    # a wrong password (or an unusable hash) gives an empty byte string
    return salt if matches else b''
|
78 |
|
79 |
|
def eschema_eid(cnx, eschema):
    """Return the eid of the CWEType entity matching the given yams type.

    Use this rather than reading `eschema.eid` directly: when the schema has
    been loaded from the file-system and not from the database (e.g. during
    tests), `eschema.eid` is not set. In that case the eid is fetched through
    `cnx` and stored on `eschema` so later calls don't query again.
    """
    if eschema.eid is not None:
        return eschema.eid
    rset = cnx.execute('Any X WHERE X is CWEType, X name %(name)s',
                       {'name': text_type(eschema)})
    eschema.eid = rset[0][0]
    return eschema.eid
|
90 |
|
91 |
|
# Default message printed by `manager_userpasswd` before it asks for a login.
# The trailing backslash continues the string literal itself, so both source
# lines make up a single-line message.
DEFAULT_MSG = 'we need a manager connection on the repository \
(the server doesn\'t have to run, even should better not)'
|
94 |
|
def manager_userpasswd(user=None, msg=DEFAULT_MSG, confirm=False,
                       passwdmsg='password'):
    """Interactively ask for a login / password pair.

    :param user: login to use; when empty or None, it is asked on the
      terminal until a non-empty value is entered
    :param msg: message printed before asking for the login (only printed
      when the login actually has to be asked and `msg` is not empty)
    :param confirm: when true, the password has to be typed a second time and
      both are asked again until the two entries match
    :param passwdmsg: label used to build the password prompt
    :return: a ``(user, passwd)`` tuple
    """
    if not user:
        if msg:
            print(msg)
        while not user:
            user = input('login: ')
        if PY2:
            # the stream encoding may be None (e.g. stdin is not a terminal),
            # which unicode() rejects: fall back on the default encoding
            encoding = sys.stdin.encoding or sys.getdefaultencoding()
            user = unicode(user, encoding)
    passwd_prompt = '%s: ' % passwdmsg
    passwd = getpass(passwd_prompt)
    if confirm:
        while True:
            passwd2 = getpass('confirm password: ')
            if passwd == passwd2:
                break
            print('password doesn\'t match')
            # ask again with the same label as the initial prompt
            passwd = getpass(passwd_prompt)
    # XXX decode password using stdin encoding then encode it using appl'encoding
    return user, passwd
|
114 |
|
115 |
|
# sentinel telling "attribute is missing" apart from any real attribute value
_MARKER = object()


def func_name(func):
    """Return a displayable name for the given callable.

    The `__name__` attribute is used when it exists, else the `func_name`
    attribute, else the object's `repr()`. An attribute which exists is
    returned as is, whatever its value.
    """
    for attribute in ('__name__', 'func_name'):
        found = getattr(func, attribute, _MARKER)
        if found is not _MARKER:
            return found
    return repr(func)
|
124 |
|
class LoopTask(object):
    """threaded task restarting itself once executed

    Each run is scheduled through a `threading.Timer` firing after `interval`
    seconds. Once the wrapped function returned (or raised an `Exception`,
    which is logged then re-raised), a new timer is scheduled as long as the
    tasks manager's `running` attribute is true. Any other `BaseException` is
    swallowed and stops the loop.

    :param tasks_manager: object whose `running` attribute tells whether the
      task should keep restarting
    :param interval: delay in seconds before each run, must be >= 0
    :param func: the callable to run
    :param args: positional arguments given to `func`
    :raise ValueError: if `interval` is negative
    """
    # current timer, None until `start` has been called
    _t = None

    def __init__(self, tasks_manager, interval, func, args):
        if interval < 0:
            raise ValueError('Loop task interval must be >= 0 '
                             '(current value: %f for %s)' %
                             (interval, func_name(func)))
        self._tasks_manager = tasks_manager
        self.interval = interval

        def auto_restart_func(self=self, func=func, args=args):
            restart = True
            try:
                func(*args)
            except Exception:
                logger = logging.getLogger('cubicweb.repository')
                logger.exception('Unhandled exception in LoopTask %s', self.name)
                raise
            except BaseException:
                # e.g. interpreter shutdown: don't schedule another run
                restart = False
            finally:
                if restart and tasks_manager.running:
                    self.start()
        self.func = auto_restart_func
        self.name = func_name(func)

    def __str__(self):
        return '%s (%s seconds)' % (self.name, self.interval)

    def start(self):
        """Schedule the next run of the task in `interval` seconds."""
        self._t = Timer(self.interval, self.func)
        # `name` property rather than the deprecated getName() / setName()
        self._t.name = '%s-%s[%d]' % (self._t.name, self.name, self.interval)
        self._t.start()

    def cancel(self):
        """Cancel the pending timer, if any."""
        if self._t is not None:
            self._t.cancel()

    def join(self):
        """Wait for the current timer thread to finish, if it is alive."""
        timer = self._t
        # is_alive(): the isAlive() spelling has been removed in Python 3.9
        if timer is not None and timer.is_alive():
            timer.join()
|
164 |
|
165 |
|
class RepoThread(Thread):
    """Thread which registers itself in a given list when started and removes
    itself from that list once its target has been executed.

    :param target: the callable to run in the thread, called without argument
    :param running_threads: the list in which the thread (un)registers itself
    """

    def __init__(self, target, running_threads):
        def run_then_unregister():
            try:
                target()
            except Exception:
                log = logging.getLogger('cubicweb.repository')
                log.exception('Unhandled exception in RepoThread %s', self._name)
                raise
            finally:
                # whatever the outcome, the thread is not running anymore
                self.running_threads.remove(self)

        super(RepoThread, self).__init__(target=run_then_unregister)
        self.running_threads = running_threads
        self._name = func_name(target)

    def start(self):
        """Register the thread, flag it as daemon and start it."""
        self.running_threads.append(self)
        self.daemon = True
        super(RepoThread, self).start()

    def getName(self):
        """Return ``<_name>(<name given by Thread.getName>)``."""
        base_name = super(RepoThread, self).getName()
        return '%s(%s)' % (self._name, base_name)
|
191 |
|
class TasksManager(object):
    """Object dedicated to managing background tasks.

    Looping tasks registered before `start` is called are kept aside and
    started by `start`; the ones registered while the manager is running are
    started right away. The `running` attribute is read by the `LoopTask`
    instances to decide whether they should schedule themselves again.
    """

    def __init__(self):
        self.running = False
        # tasks registered while not running, waiting for `start`
        self._tasks = []
        # tasks which have been started
        self._looping_tasks = []

    def add_looping_task(self, interval, func, *args):
        """register a function to be called every `interval` seconds.

        If interval is negative, no looping task is registered.
        """
        if interval < 0:
            self.debug('looping task %s ignored due to interval %f < 0',
                       func_name(func), interval)
            return
        task = LoopTask(self, interval, func, args)
        if self.running:
            self._start_task(task)
        else:
            self._tasks.append(task)

    def _start_task(self, task):
        """Record `task` as started and schedule its first run."""
        self._looping_tasks.append(task)
        self.info('starting task %s with interval %.2fs', task.name,
                  task.interval)
        task.start()

    def start(self):
        """Start running looping task"""
        assert not self.running  # bw compat purpose mainly
        # set the flag before scheduling anything: a task with a very short
        # interval may complete before this method returns, and it only
        # schedules itself again if it sees the manager as running
        self.running = True
        while self._tasks:
            task = self._tasks.pop()
            self._start_task(task)

    def stop(self):
        """Stop all running task.

        returns when all tasks have been canceled and none are running
        anymore"""
        if self.running:
            # clear the flag first: canceling a timer has no effect on a task
            # which is already executing, and such a task would schedule
            # itself again once done if the manager still looked running
            self.running = False
            while self._looping_tasks:
                looptask = self._looping_tasks.pop()
                self.info('canceling task %s...', looptask.name)
                looptask.cancel()
                looptask.join()
                self.info('task %s finished', looptask.name)
|
240 |
|
from logging import getLogger
from cubicweb import set_log_methods
# TasksManager calls self.debug() / self.info() without defining them:
# presumably set_log_methods installs those logging methods on the class,
# backed by the 'cubicweb.repository' logger -- TODO confirm against
# cubicweb.set_log_methods.
set_log_methods(TasksManager, getLogger('cubicweb.repository'))
|