server/cwzmq.py
changeset 8350 e1c05bf6fdeb
parent 8211 543e1579ba0d
child 8388 c6c624cea870
--- a/server/cwzmq.py	Tue Apr 10 17:03:19 2012 +0200
+++ b/server/cwzmq.py	Wed Apr 04 16:51:09 2012 +0200
@@ -18,12 +18,16 @@
 # with CubicWeb.  If not, see <http://www.gnu.org/licenses/>.
 
 from threading import Thread
+import cPickle
+import traceback
+
 import zmq
 from zmq.eventloop import ioloop
 import zmq.eventloop.zmqstream
 
 from logging import getLogger
 from cubicweb import set_log_methods
+from cubicweb.server.server import QuitEvent
 
 ctx = zmq.Context()
 
@@ -105,5 +109,133 @@
         self.ioloop.add_callback(lambda: self.stream.setsockopt(zmq.SUBSCRIBE, topic))
 
 
+class ZMQRepositoryServer(object):
+
+    def __init__(self, repository):
+        """make the repository available as a PyRO object"""
+        self.address = None
+        self.repo = repository
+        self.socket = None
+        self.stream = None
+        self.loop = None
+
+        # event queue
+        self.events = []
+
+    def connect(self, address):
+        self.address = address
+
+    def run(self):
+        """enter the service loop"""
+        # start repository looping tasks
+        self.socket = ctx.socket(zmq.REP)
+        self.loop = ioloop.IOLoop()
+        self.stream = zmq.eventloop.zmqstream.ZMQStream(self.socket, io_loop=self.loop)
+        self.stream.bind(self.address)
+        self.info('ZMQ server bound on: %s', self.address)
+
+        self.stream.on_recv(self.process_cmds)
+
+        try:
+            self.loop.start()
+        except zmq.ZMQError:
+            self.warning('ZMQ event loop killed')
+        self.quit()
+
+    def trigger_events(self):
+        """trigger ready events"""
+        for event in self.events[:]:
+            if event.is_ready():
+                self.info('starting event %s', event)
+                event.fire(self)
+                try:
+                    event.update()
+                except Finished:
+                    self.events.remove(event)
+
+    def process_cmd(self, cmd):
+        """Delegate the given command to the repository.
+
+        ``cmd`` is a list of (method_name, args, kwargs)
+        where ``args`` is a list of positional arguments
+        and ``kwargs`` is a dictionnary of named arguments.
+
+        >>> rset = delegate_to_repo(["execute", [sessionid], {'rql': rql}])
+
+        :note1: ``kwargs`` may be ommited
+
+            >>> rset = delegate_to_repo(["execute", [sessionid, rql]])
+
+        :note2: both ``args`` and ``kwargs`` may be omitted
+
+            >>> schema = delegate_to_repo(["get_schema"])
+            >>> schema = delegate_to_repo("get_schema") # also allowed
+
+        """
+        cmd = cPickle.loads(cmd)
+        if not cmd:
+            raise AttributeError('function name required')
+        if isinstance(cmd, basestring):
+            cmd = [cmd]
+        if len(cmd) < 2:
+            cmd.append(())
+        if len(cmd) < 3:
+            cmd.append({})
+        cmd  = list(cmd) + [(), {}]
+        funcname, args, kwargs = cmd[:3]
+        result = getattr(self.repo, funcname)(*args, **kwargs)
+        return result
+
+    def process_cmds(self, cmds):
+        """Callback intended to be used with ``on_recv``.
+
+        Call ``delegate_to_repo`` on each command and send a pickled of
+        each result recursively.
+
+        Any exception are catched, pickled and sent.
+        """
+        try:
+            for cmd in cmds:
+                result = self.process_cmd(cmd)
+                self.send_data(result)
+        except Exception, exc:
+            traceback.print_exc()
+            self.send_data(exc)
+
+    def send_data(self, data):
+        self.socket.send_pyobj(data)
+
+    def quit(self, shutdown_repo=False):
+        """stop the server"""
+        self.info('Quitting ZMQ server')
+        try:
+            self.loop.stop()
+            self.stream.on_recv(None)
+            self.stream.close()
+        except Exception, e:
+            print e
+            pass
+        if shutdown_repo and not self.repo.shutting_down:
+            event = QuitEvent()
+            event.fire(self)
+
+    # server utilitities ######################################################
+
+    def install_sig_handlers(self):
+        """install signal handlers"""
+        import signal
+        self.info('installing signal handlers')
+        signal.signal(signal.SIGINT, lambda x, y, s=self: s.quit(shutdown_repo=True))
+        signal.signal(signal.SIGTERM, lambda x, y, s=self: s.quit(shutdown_repo=True))
+
+
+    # these are overridden by set_log_methods below
+    # only defining here to prevent pylint from complaining
+    @classmethod
+    def info(cls, msg, *a, **kw):
+        pass
+
+
 set_log_methods(Publisher, getLogger('cubicweb.zmq.pub'))
 set_log_methods(Subscriber, getLogger('cubicweb.zmq.sub'))
+set_log_methods(ZMQRepositoryServer, getLogger('cubicweb.zmq.repo'))