etwist/server.py
changeset 5155 1dea6e0fdfc1
parent 5062 5691fd8697cd
child 5216 4f4369e63f5e
--- a/etwist/server.py	Wed Mar 31 17:02:51 2010 +0200
+++ b/etwist/server.py	Tue Apr 06 16:50:53 2010 +0200
@@ -14,19 +14,25 @@
 from time import mktime
 from datetime import date, timedelta
 from urlparse import urlsplit, urlunsplit
+from cgi import FieldStorage, parse_header
 
 from twisted.internet import reactor, task, threads
 from twisted.internet.defer import maybeDeferred
-from twisted.web2 import channel, http, server, iweb
-from twisted.web2 import static, resource, responsecode
+from twisted.web import http, server
+from twisted.web import static, resource
+from twisted.web.server import NOT_DONE_YET
+
+from logilab.common.decorators import monkeypatch
 
 from cubicweb import ConfigurationError, CW_EVENT_MANAGER
 from cubicweb.web import (AuthenticationError, NotFound, Redirect,
                           RemoteCallFailed, DirectResponse, StatusResponse,
                           ExplicitLogin)
+
 from cubicweb.web.application import CubicWebPublisher
 
 from cubicweb.etwist.request import CubicWebTwistedRequestAdapter
+from cubicweb.etwist.http import HTTPResponse
 
 def daemonize():
     # XXX unix specific
@@ -67,8 +73,20 @@
     return baseurl
 
 
-class LongTimeExpiringFile(static.File):
-    """overrides static.File and sets a far futre ``Expires`` date
+class ForbiddenDirectoryLister(resource.Resource):
+    def render(self, request):
+        return HTTPResponse(twisted_request=request,
+                            code=http.FORBIDDEN,
+                            stream='Access forbidden')
+
+class File(static.File):
+    """Prevent from listing directories"""
+    def directoryListing(self):
+        return ForbiddenDirectoryLister()
+
+
+class LongTimeExpiringFile(File):
+    """overrides static.File and sets a far future ``Expires`` date
     on the resouce.
 
     versions handling is done by serving static files by different
@@ -79,22 +97,19 @@
       etc.
 
     """
-    def renderHTTP(self, request):
+    def render(self, request):
         def setExpireHeader(response):
-            response = iweb.IResponse(response)
             # Don't provide additional resource information to error responses
             if response.code < 400:
                 # the HTTP RFC recommands not going further than 1 year ahead
                 expires = date.today() + timedelta(days=6*30)
                 response.headers.setHeader('Expires', mktime(expires.timetuple()))
             return response
-        d = maybeDeferred(super(LongTimeExpiringFile, self).renderHTTP, request)
+        d = maybeDeferred(super(LongTimeExpiringFile, self).render, request)
         return d.addCallback(setExpireHeader)
 
 
-class CubicWebRootResource(resource.PostableResource):
-    addSlash = False
-
+class CubicWebRootResource(resource.Resource):
     def __init__(self, config, debug=None):
         self.debugmode = debug
         self.config = config
@@ -104,6 +119,7 @@
         self.base_url = config['base-url']
         self.https_url = config['https-url']
         self.versioned_datadir = 'data%s' % config.instance_md5_version()
+        self.children = {}
 
     def init_publisher(self):
         config = self.config
@@ -145,35 +161,35 @@
         except select.error:
             return
 
-    def locateChild(self, request, segments):
+    def getChild(self, path, request):
         """Indicate which resource to use to process down the URL's path"""
-        if segments:
-            if segments[0] == 'https':
-                segments = segments[1:]
-            if len(segments) >= 2:
-                if segments[0] in (self.versioned_datadir, 'data', 'static'):
-                    # Anything in data/, static/ is treated as static files
-                    if segments[0] == 'static':
-                        # instance static directory
-                        datadir = self.config.static_directory
-                    elif segments[1] == 'fckeditor':
-                        fckeditordir = self.config.ext_resources['FCKEDITOR_PATH']
-                        return static.File(fckeditordir), segments[2:]
-                    else:
-                        # cube static data file
-                        datadir = self.config.locate_resource(segments[1])
-                        if datadir is None:
-                            return None, []
-                    self.debug('static file %s from %s', segments[-1], datadir)
-                    if segments[0] == 'data':
-                        return static.File(str(datadir)), segments[1:]
-                    else:
-                        return LongTimeExpiringFile(datadir), segments[1:]
-                elif segments[0] == 'fckeditor':
-                    fckeditordir = self.config.ext_resources['FCKEDITOR_PATH']
-                    return static.File(fckeditordir), segments[1:]
+        pre_path = request.prePathURL()
+        # XXX testing pre_path[0] not enough?
+        if any(s in pre_path
+               for s in (self.versioned_datadir, 'data', 'static')):
+            # Anything in data/, static/ is treated as static files
+
+            if 'static' in pre_path:
+                # instance static directory
+                datadir = self.config.static_directory
+            elif 'fckeditor' in pre_path:
+                fckeditordir = self.config.ext_resources['FCKEDITOR_PATH']
+                return File(fckeditordir)
+            else:
+                # cube static data file
+                datadir = self.config.locate_resource(path)
+                if datadir is None:
+                    return self
+            self.info('static file %s from %s', path, datadir)
+            if 'data' in pre_path:
+                return File(os.path.join(datadir, path))
+            else:
+                return LongTimeExpiringFile(datadir)
+        elif path == 'fckeditor':
+            fckeditordir = self.config.ext_resources['FCKEDITOR_PATH']
+            return File(fckeditordir)
         # Otherwise we use this single resource
-        return self, ()
+        return self
 
     def render(self, request):
         """Render a page from the root resource"""
@@ -183,7 +199,8 @@
         if self.config['profile']: # default profiler don't trace threads
             return self.render_request(request)
         else:
-            return threads.deferToThread(self.render_request, request)
+            deferred = threads.deferToThread(self.render_request, request)
+            return NOT_DONE_YET
 
     def render_request(self, request):
         origpath = request.path
@@ -209,12 +226,12 @@
         try:
             self.appli.connect(req)
         except AuthenticationError:
-            return self.request_auth(req)
+            return self.request_auth(request=req)
         except Redirect, ex:
-            return self.redirect(req, ex.location)
+            return self.redirect(request=req, location=ex.location)
         if https and req.cnx.anonymous_connection:
             # don't allow anonymous on https connection
-            return self.request_auth(req)
+            return self.request_auth(request=req)
         if self.url_rewriter is not None:
             # XXX should occur before authentication?
             try:
@@ -231,234 +248,115 @@
         except DirectResponse, ex:
             return ex.response
         except StatusResponse, ex:
-            return http.Response(stream=ex.content, code=ex.status,
-                                 headers=req.headers_out or None)
+            return HTTPResponse(stream=ex.content, code=ex.status,
+                                twisted_request=req._twreq,
+                                headers=req.headers_out)
         except RemoteCallFailed, ex:
             req.set_header('content-type', 'application/json')
-            return http.Response(stream=ex.dumps(),
-                                 code=responsecode.INTERNAL_SERVER_ERROR)
+            return HTTPResponse(twisted_request=req._twreq, code=http.INTERNAL_SERVER_ERROR,
+                                stream=ex.dumps(), headers=req.headers_out)
         except NotFound:
             result = self.appli.notfound_content(req)
-            return http.Response(stream=result, code=responsecode.NOT_FOUND,
-                                 headers=req.headers_out or None)
+            return HTTPResponse(twisted_request=req._twreq, code=http.NOT_FOUND,
+                                stream=result, headers=req.headers_out)
+
         except ExplicitLogin:  # must be before AuthenticationError
-            return self.request_auth(req)
+            return self.request_auth(request=req)
         except AuthenticationError, ex:
             if self.config['auth-mode'] == 'cookie' and getattr(ex, 'url', None):
-                return self.redirect(req, ex.url)
+                return self.redirect(request=req, location=ex.url)
             # in http we have to request auth to flush current http auth
             # information
-            return self.request_auth(req, loggedout=True)
+            return self.request_auth(request=req, loggedout=True)
         except Redirect, ex:
-            return self.redirect(req, ex.location)
+            return self.redirect(request=req, location=ex.location)
         # request may be referenced by "onetime callback", so clear its entity
         # cache to avoid memory usage
         req.drop_entity_cache()
-        return http.Response(stream=result, code=responsecode.OK,
-                             headers=req.headers_out or None)
 
-    def redirect(self, req, location):
-        req.headers_out.setHeader('location', str(location))
-        self.debug('redirecting to %s', location)
-        # 303 See other
-        return http.Response(code=303, headers=req.headers_out)
+        return HTTPResponse(twisted_request=req._twreq, code=http.OK,
+                            stream=result, headers=req.headers_out)
 
-    def request_auth(self, req, loggedout=False):
-        if self.https_url and req.base_url() != self.https_url:
-            req.headers_out.setHeader('location', self.https_url + 'login')
-            return http.Response(code=303, headers=req.headers_out)
+    def redirect(self, request, location):
+        self.debug('redirecting to %s', str(location))
+        request.headers_out.setHeader('location', str(location))
+        # 303 See other
+        return HTTPResponse(twisted_request=request._twreq, code=303,
+                            headers=request.headers_out)
+
+    def request_auth(self, request, loggedout=False):
+        if self.https_url and request.base_url() != self.https_url:
+            return self.redirect(request, self.https_url + 'login')
         if self.config['auth-mode'] == 'http':
-            code = responsecode.UNAUTHORIZED
+            code = http.UNAUTHORIZED
         else:
-            code = responsecode.FORBIDDEN
+            code = http.FORBIDDEN
         if loggedout:
-            if req.https:
-                req._base_url =  self.base_url
-                req.https = False
-            content = self.appli.loggedout_content(req)
+            if request.https:
+                request._base_url =  self.base_url
+                request.https = False
+            content = self.appli.loggedout_content(request)
         else:
-            content = self.appli.need_login_content(req)
-        return http.Response(code, req.headers_out, content)
+            content = self.appli.need_login_content(request)
+        return HTTPResponse(twisted_request=request._twreq,
+                            stream=content, code=code,
+                            headers=request.headers_out)
 
-from twisted.internet import defer
-from twisted.web2 import fileupload
+#TODO
+# # XXX max upload size in the configuration
 
-# XXX set max file size to 100Mo: put max upload size in the configuration
-# line below for twisted >= 8.0, default param value for earlier version
-resource.PostableResource.maxSize = 100*1024*1024
-def parsePOSTData(request, maxMem=100*1024, maxFields=1024,
-                  maxSize=100*1024*1024):
-    if request.stream.length == 0:
-        return defer.succeed(None)
+@monkeypatch(http.Request)
+def requestReceived(self, command, path, version):
+    """Called by channel when all data has been received.
 
-    ctype = request.headers.getHeader('content-type')
-
-    if ctype is None:
-        return defer.succeed(None)
-
-    def updateArgs(data):
-        args = data
-        request.args.update(args)
-
-    def updateArgsAndFiles(data):
-        args, files = data
-        request.args.update(args)
-        request.files.update(files)
-
-    def error(f):
-        f.trap(fileupload.MimeFormatError)
-        raise http.HTTPError(responsecode.BAD_REQUEST)
-
-    if ctype.mediaType == 'application' and ctype.mediaSubtype == 'x-www-form-urlencoded':
-        d = fileupload.parse_urlencoded(request.stream, keep_blank_values=True)
-        d.addCallbacks(updateArgs, error)
-        return d
-    elif ctype.mediaType == 'multipart' and ctype.mediaSubtype == 'form-data':
-        boundary = ctype.params.get('boundary')
-        if boundary is None:
-            return defer.fail(http.HTTPError(
-                http.StatusResponse(responsecode.BAD_REQUEST,
-                                    "Boundary not specified in Content-Type.")))
-        d = fileupload.parseMultipartFormData(request.stream, boundary,
-                                              maxMem, maxFields, maxSize)
-        d.addCallbacks(updateArgsAndFiles, error)
-        return d
+    This method is not intended for users.
+    """
+    self.content.seek(0,0)
+    self.args = {}
+    self.files = {}
+    self.stack = []
+    self.method, self.uri = command, path
+    self.clientproto = version
+    x = self.uri.split('?', 1)
+    if len(x) == 1:
+        self.path = self.uri
     else:
-        raise http.HTTPError(responsecode.BAD_REQUEST)
-
-server.parsePOSTData = parsePOSTData
+        self.path, argstring = x
+        self.args = http.parse_qs(argstring, 1)
+    # cache the client and server information, we'll need this later to be
+    # serialized and sent with the request so CGIs will work remotely
+    self.client = self.channel.transport.getPeer()
+    self.host = self.channel.transport.getHost()
+    # Argument processing
+    ctype = self.getHeader('content-type')
+    if self.method == "POST" and ctype:
+        key, pdict = parse_header(ctype)
+        if key == 'application/x-www-form-urlencoded':
+            self.args.update(http.parse_qs(self.content.read(), 1))
+        elif key == 'multipart/form-data':
+            self.content.seek(0,0)
+            form = FieldStorage(self.content, self.received_headers,
+                                environ={'REQUEST_METHOD': 'POST'},
+                                keep_blank_values=1,
+                                strict_parsing=1)
+            for key in form:
+                value = form[key]
+                if isinstance(value, list):
+                    self.args[key] = [v.value for v in value]
+                elif value.filename:
+                    if value.done != -1: # -1 is transfer has been interrupted
+                        self.files[key] = (value.filename, value.file)
+                    else:
+                        self.files[key] = (None, None)
+                else:
+                    self.args[key] = value.value
+    self.process()
 
 
 from logging import getLogger
 from cubicweb import set_log_methods
-set_log_methods(CubicWebRootResource, getLogger('cubicweb.twisted'))
-
-
-listiterator = type(iter([]))
-
-def _gc_debug(all=True):
-    import gc
-    from pprint import pprint
-    from cubicweb.appobject import AppObject
-    gc.collect()
-    count = 0
-    acount = 0
-    fcount = 0
-    rcount = 0
-    ccount = 0
-    scount = 0
-    ocount = {}
-    from rql.stmts import Union
-    from cubicweb.schema import CubicWebSchema
-    from cubicweb.rset import ResultSet
-    from cubicweb.dbapi import Connection, Cursor
-    from cubicweb.req import RequestSessionBase
-    from cubicweb.server.repository import Repository
-    from cubicweb.server.sources.native import NativeSQLSource
-    from cubicweb.server.session import Session
-    from cubicweb.devtools.testlib import CubicWebTC
-    from logilab.common.testlib import TestSuite
-    from optparse import Values
-    import types, weakref
-    for obj in gc.get_objects():
-        if isinstance(obj, RequestSessionBase):
-            count += 1
-            if isinstance(obj, Session):
-                print '   session', obj, referrers(obj, True)
-        elif isinstance(obj, AppObject):
-            acount += 1
-        elif isinstance(obj, ResultSet):
-            rcount += 1
-            #print '   rset', obj, referrers(obj)
-        elif isinstance(obj, Repository):
-            print '   REPO', obj, referrers(obj, True)
-        #elif isinstance(obj, NativeSQLSource):
-        #    print '   SOURCe', obj, referrers(obj)
-        elif isinstance(obj, CubicWebTC):
-            print '   TC', obj, referrers(obj)
-        elif isinstance(obj, TestSuite):
-            print '   SUITE', obj, referrers(obj)
-        #elif isinstance(obj, Values):
-        #    print '   values', '%#x' % id(obj), referrers(obj, True)
-        elif isinstance(obj, Connection):
-            ccount += 1
-            #print '   cnx', obj, referrers(obj)
-        #elif isinstance(obj, Cursor):
-        #    ccount += 1
-        #    print '   cursor', obj, referrers(obj)
-        elif isinstance(obj, file):
-            fcount += 1
-        #    print '   open file', file.name, file.fileno
-        elif isinstance(obj, CubicWebSchema):
-            scount += 1
-            print '   schema', obj, referrers(obj)
-        elif not isinstance(obj, (type, tuple, dict, list, set, frozenset,
-                                  weakref.ref, weakref.WeakKeyDictionary,
-                                  listiterator,
-                                  property, classmethod,
-                                  types.ModuleType, types.MemberDescriptorType,
-                                  types.FunctionType, types.MethodType)):
-            try:
-                ocount[obj.__class__] += 1
-            except KeyError:
-                ocount[obj.__class__] = 1
-            except AttributeError:
-                pass
-    if count:
-        print ' NB REQUESTS/SESSIONS', count
-    if acount:
-        print ' NB APPOBJECTS', acount
-    if ccount:
-        print ' NB CONNECTIONS', ccount
-    if rcount:
-        print ' NB RSETS', rcount
-    if scount:
-        print ' NB SCHEMAS', scount
-    if fcount:
-        print ' NB FILES', fcount
-    if all:
-        ocount = sorted(ocount.items(), key=lambda x: x[1], reverse=True)[:20]
-        pprint(ocount)
-    if gc.garbage:
-        print 'UNREACHABLE', gc.garbage
-
-def referrers(obj, showobj=False):
-    try:
-        return sorted(set((type(x), showobj and x or getattr(x, '__name__', '%#x' % id(x)))
-                          for x in _referrers(obj)))
-    except TypeError:
-        s = set()
-        unhashable = []
-        for x in _referrers(obj):
-            try:
-                s.add(x)
-            except TypeError:
-                unhashable.append(x)
-        return sorted(s) + unhashable
-
-def _referrers(obj, seen=None, level=0):
-    import gc, types
-    from cubicweb.schema import CubicWebRelationSchema, CubicWebEntitySchema
-    interesting = []
-    if seen is None:
-        seen = set()
-    for x in gc.get_referrers(obj):
-        if id(x) in seen:
-            continue
-        seen.add(id(x))
-        if isinstance(x, types.FrameType):
-            continue
-        if isinstance(x, (CubicWebRelationSchema, CubicWebEntitySchema)):
-            continue
-        if isinstance(x, (list, tuple, set, dict, listiterator)):
-            if level >= 5:
-                pass
-                #interesting.append(x)
-            else:
-                interesting += _referrers(x, seen, level+1)
-        else:
-            interesting.append(x)
-    return interesting
+LOGGER = getLogger('cubicweb.twisted')
+set_log_methods(CubicWebRootResource, LOGGER)
 
 def run(config, debug):
     # create the site
@@ -466,7 +364,7 @@
     website = server.Site(root_resource)
     # serve it via standard HTTP on port set in the configuration
     port = config['port'] or 8080
-    reactor.listenTCP(port, channel.HTTPFactory(website))
+    reactor.listenTCP(port, website)
     logger = getLogger('cubicweb.twisted')
     if not debug:
         if sys.platform == 'win32':