[Web-Request] Use rich header (closes #2204164)
authorPierre-Yves David <pierre-yves.david@logilab.fr>
Thu, 15 Mar 2012 18:34:59 +0100
changeset 8314 cfd6ab461849
parent 8313 386b6313de28
child 8315 166e6d5d8e17
[Web-Request] Use rich header (closes #2204164) Unify header management. All web request use the Headers class now (imported from twisted). Code dedicated to header management have been merged into the base WebRequest class.
devtools/fake.py
etwist/request.py
web/request.py
wsgi/handler.py
wsgi/request.py
--- a/devtools/fake.py	Thu Mar 15 17:54:40 2012 +0100
+++ b/devtools/fake.py	Thu Mar 15 18:34:59 2012 +0100
@@ -64,7 +64,6 @@
         self._url = kwargs.pop('url', None) or 'view?rql=Blop&vid=blop'
         super(FakeRequest, self).__init__(*args, **kwargs)
         self._session_data = {}
-        self._headers_in = Headers()
 
     def set_cookie(self, name, value, maxage=300, expires=None, secure=False):
         super(FakeRequest, self).set_cookie(name, value, maxage, expires, secure)
@@ -92,32 +91,23 @@
             return url
         return url.split('?', 1)[0]
 
-    def get_header(self, header, default=None, raw=True):
-        """return the value associated with the given input header, raise
-        KeyError if the header is not set
-        """
-        if raw:
-            return self._headers_in.getRawHeaders(header, [default])[0]
-        return self._headers_in.getHeader(header, default)
-
-    ## extend request API to control headers in / out values
     def set_request_header(self, header, value, raw=False):
-        """set an input HTTP header"""
+        """set an incoming HTTP header (For test purpose only)"""
         if isinstance(value, basestring):
             value = [value]
-        if raw:
+        if raw: #
+            # adding encoded header is important, else page content
+            # will be reconverted back to unicode and apart unefficiency, this
+            # may cause decoding problem (e.g. when downloading a file)
             self._headers_in.setRawHeaders(header, value)
-        else:
-            self._headers_in.setHeader(header, value)
+        else: #
+            self._headers_in.setHeader(header, value) #
 
     def get_response_header(self, header, default=None, raw=False):
-        """return the value associated with the given input header,
-        raise KeyError if the header is not set
-        """
-        if raw:
-            return self.headers_out.getRawHeaders(header, default)[0]
-        else:
-            return self.headers_out.getHeader(header, default)
+        """return output header (For test purpose only"""
+        if raw: #
+            return self.headers_out.getRawHeaders(header, [default])[0]
+        return self.headers_out.getHeader(header, default)
 
     def validate_cache(self):
         pass
--- a/etwist/request.py	Thu Mar 15 17:54:40 2012 +0100
+++ b/etwist/request.py	Thu Mar 15 18:34:59 2012 +0100
@@ -33,16 +33,13 @@
 class CubicWebTwistedRequestAdapter(CubicWebRequestBase):
     def __init__(self, req, vreg, https):
         self._twreq = req
-        super(CubicWebTwistedRequestAdapter, self).__init__(vreg, https, req.args)
+        super(CubicWebTwistedRequestAdapter, self).__init__(
+            vreg, https, req.args, headers=req.received_headers)
         for key, (name, stream) in req.files.iteritems():
             if name is None:
                 self.form[key] = (name, stream)
             else:
                 self.form[key] = (unicode(name, self.encoding), stream)
-        # XXX can't we keep received_headers?
-        self._headers_in = Headers()
-        for k, v in req.received_headers.iteritems():
-            self._headers_in.addRawHeader(k, v)
 
     def http_method(self):
         """returns 'POST', 'GET', 'HEAD', etc."""
@@ -61,14 +58,6 @@
             path = path.split('?', 1)[0]
         return path
 
-    def get_header(self, header, default=None, raw=True):
-        """return the value associated with the given input header, raise
-        KeyError if the header is not set
-        """
-        if raw:
-            return self._headers_in.getRawHeaders(header, [default])[0]
-        return self._headers_in.getHeader(header, default)
-
     def _validate_cache(self):
         """raise a `DirectResponse` exception if a cached page along the way
         exists and is still usable
@@ -95,21 +84,3 @@
                 raise DirectResponse(response)
         # Expires header seems to be required by IE7
         self.add_header('Expires', 'Sat, 01 Jan 2000 00:00:00 GMT')
-
-    def header_accept_language(self):
-        """returns an ordered list of preferred languages"""
-        acceptedlangs = self.get_header('Accept-Language', raw=False) or {}
-        for lang, _ in sorted(acceptedlangs.iteritems(), key=lambda x: x[1],
-                              reverse=True):
-            lang = lang.split('-')[0]
-            yield lang
-
-    def header_if_modified_since(self):
-        """If the HTTP header If-modified-since is set, return the equivalent
-        date time value (GMT), else return None
-        """
-        mtime = self.get_header('If-modified-since', raw=False)
-        if mtime:
-            # :/ twisted is returned a localized time stamp
-            return datetime.fromtimestamp(mtime) + GMTOFFSET
-        return None
--- a/web/request.py	Thu Mar 15 17:54:40 2012 +0100
+++ b/web/request.py	Thu Mar 15 18:34:59 2012 +0100
@@ -25,7 +25,7 @@
 from hashlib import sha1 # pylint: disable=E0611
 from Cookie import SimpleCookie
 from calendar import timegm
-from datetime import date
+from datetime import date, datetime
 from urlparse import urlsplit
 from itertools import count
 from warnings import warn
@@ -86,7 +86,7 @@
     """
     ajax_request = False # to be set to True by ajax controllers
 
-    def __init__(self, vreg, https=False, form=None):
+    def __init__(self, vreg, https=False, form=None, headers={}):
         """
         :vreg: Vregistry,
         :https: boolean, s this a https request
@@ -107,6 +107,10 @@
             self.datadir_url = vreg.config.datadir_url
         #: raw html headers that can be added from any view
         self.html_headers = HTMLHead(self)
+        #: received headers
+        self._headers_in = Headers()
+        for k, v in headers.iteritems():
+            self._headers_in.addRawHeader(k, v)
         #: form parameters
         self.setup_params(form)
         #: dictionary that may be used to store request data that has to be
@@ -305,7 +309,6 @@
             form = self.form
         return list_form_param(form, param, pop)
 
-
     def reset_headers(self):
         """used by AutomaticWebTest to clear html headers between tests on
         the same resultset
@@ -778,12 +781,37 @@
         """
         raise NotImplementedError()
 
-    def get_header(self, header, default=None):
-        """return the value associated with the given input HTTP header,
-        raise KeyError if the header is not set
+    # http headers ############################################################
+
+    ### incoming headers
+
+    def get_header(self, header, default=None, raw=True):
+        """return the value associated with the given input header, raise
+        KeyError if the header is not set
         """
-        raise NotImplementedError()
+        if raw:
+            return self._headers_in.getRawHeaders(header, [default])[0]
+        return self._headers_in.getHeader(header, default)
+
+    def header_accept_language(self):
+        """returns an ordered list of preferred languages"""
+        acceptedlangs = self.get_header('Accept-Language', raw=False) or {}
+        for lang, _ in sorted(acceptedlangs.iteritems(), key=lambda x: x[1],
+                              reverse=True):
+            lang = lang.split('-')[0]
+            yield lang
 
+    def header_if_modified_since(self):
+        """If the HTTP header If-modified-since is set, return the equivalent
+        date time value (GMT), else return None
+        """
+        mtime = self.get_header('If-modified-since', raw=False)
+        if mtime:
+            # :/ twisted is returned a localized time stamp
+            return datetime.fromtimestamp(mtime) + GMTOFFSET
+        return None
+
+    ### outcoming headers
     def set_header(self, header, value, raw=True):
         """set an output HTTP header"""
         if raw:
@@ -831,12 +859,6 @@
         values = _parse_accept_header(accepteds, value_parser, value_sort_key)
         return (raw_value for (raw_value, parsed_value, score) in values)
 
-    def header_if_modified_since(self):
-        """If the HTTP header If-modified-since is set, return the equivalent
-        mx date time value (GMT), else return None
-        """
-        raise NotImplementedError()
-
     def demote_to_html(self):
         """helper method to dynamically set request content type to text/html
 
@@ -851,6 +873,8 @@
             self.set_content_type('text/html')
             self.main_stream.set_doctype(TRANSITIONAL_DOCTYPE_NOEXT)
 
+    # xml doctype #############################################################
+
     def set_doctype(self, doctype, reset_xmldecl=True):
         """helper method to dynamically change page doctype
 
--- a/wsgi/handler.py	Thu Mar 15 17:54:40 2012 +0100
+++ b/wsgi/handler.py	Thu Mar 15 18:34:59 2012 +0100
@@ -19,8 +19,12 @@
 
 """
 
+
+
 __docformat__ = "restructuredtext en"
 
+from itertools import chain, repeat, izip
+
 from cubicweb import AuthenticationError
 from cubicweb.web import Redirect, DirectResponse, StatusResponse, LogOut
 from cubicweb.web.application import CubicWebPublisher
@@ -71,7 +75,6 @@
     505: 'HTTP VERSION NOT SUPPORTED',
 }
 
-
 class WSGIResponse(object):
     """encapsulates the wsgi response parameters
     (code, headers and body if there is one)
@@ -79,7 +82,9 @@
     def __init__(self, code, req, body=None):
         text = STATUS_CODE_TEXT.get(code, 'UNKNOWN STATUS CODE')
         self.status =  '%s %s' % (code, text)
-        self.headers = [(str(k), str(v)) for k, v in req.headers_out.items()]
+        self.headers = list(chain(*[izip(repeat(k), v)
+                                    for k, v in req.headers_out.getAllRawHeaders()]))
+        self.headers = [(str(k), str(v)) for k, v in self.headers]
         if body:
             self.body = [body]
         else:
@@ -103,11 +108,8 @@
     def __init__(self, config, vreg=None):
         self.appli = CubicWebPublisher(config, vreg=vreg)
         self.config = config
-        self.base_url = None
-#         self.base_url = config['base-url'] or config.default_base_url()
-#         assert self.base_url[-1] == '/'
-#         self.https_url = config['https-url']
-#         assert not self.https_url or self.https_url[-1] == '/'
+        self.base_url = config['base-url']
+        self.https_url = config['https-url']
         self.url_rewriter = self.appli.vreg['components'].select_or_none('urlrewriter')
 
     def _render(self, req):
--- a/wsgi/request.py	Thu Mar 15 17:54:40 2012 +0100
+++ b/wsgi/request.py	Thu Mar 15 18:34:59 2012 +0100
@@ -32,7 +32,8 @@
 
 from cubicweb.web.request import CubicWebRequestBase
 from cubicweb.wsgi import (pformat, qs2dict, safe_copyfileobj, parse_file_upload,
-                        normalize_header)
+                           normalize_header)
+from cubicweb.web.http_headers import Headers
 
 
 
@@ -44,17 +45,19 @@
         self.environ = environ
         self.path = environ['PATH_INFO']
         self.method = environ['REQUEST_METHOD'].upper()
-        self._headers = dict([(normalize_header(k[5:]), v) for k, v in self.environ.items()
-                              if k.startswith('HTTP_')])
+
+        headers_in = dict((normalize_header(k[5:]), v) for k, v in self.environ.items()
+                          if k.startswith('HTTP_'))
         https = environ.get("HTTPS") in ('yes', 'on', '1')
         post, files = self.get_posted_data()
-        super(CubicWebWsgiRequest, self).__init__(vreg, https, post)
+
+        super(CubicWebWsgiRequest, self).__init__(vreg, https, post,
+                                                  headers= headers_in)
         if files is not None:
             for key, (name, _, stream) in files.iteritems():
-                name = unicode(name, self.encoding)
+                if name is not None:
+                    name = unicode(name, self.encoding)
                 self.form[key] = (name, stream)
-        # prepare output headers
-        self.headers_out = {}
 
     def __repr__(self):
         # Since this is called as part of error handling, we need to be very
@@ -87,31 +90,6 @@
 
         return path
 
-    def get_header(self, header, default=None):
-        """return the value associated with the given input HTTP header,
-        raise KeyError if the header is not set
-        """
-        return self._headers.get(normalize_header(header), default)
-
-    def set_header(self, header, value, raw=True):
-        """set an output HTTP header"""
-        assert raw, "don't know anything about non-raw headers for wsgi requests"
-        self.headers_out[header] = value
-
-    def add_header(self, header, value):
-        """add an output HTTP header"""
-        self.headers_out[header] = value
-
-    def remove_header(self, header):
-        """remove an output HTTP header"""
-        self.headers_out.pop(header, None)
-
-    def header_if_modified_since(self):
-        """If the HTTP header If-modified-since is set, return the equivalent
-        mx date time value (GMT), else return None
-        """
-        return None
-
     ## wsgi request helpers ###################################################
 
     def instance_uri(self):
@@ -142,6 +120,8 @@
             and self.environ['wsgi.url_scheme'] == 'https'
 
     def get_posted_data(self):
+        # The WSGI spec says 'QUERY_STRING' may be absent.
+        post = qs2dict(self.environ.get('QUERY_STRING', ''))
         files = None
         if self.method == 'POST':
             if self.environ.get('CONTENT_TYPE', '').startswith('multipart'):
@@ -149,12 +129,10 @@
                                    for k, v in self.environ.items()
                                    if k.startswith('HTTP_'))
                 header_dict['Content-Type'] = self.environ.get('CONTENT_TYPE', '')
-                post, files = parse_file_upload(header_dict, self.raw_post_data)
+                post_, files = parse_file_upload(header_dict, self.raw_post_data)
+                post.update(post_)
             else:
-                post = qs2dict(self.raw_post_data)
-        else:
-            # The WSGI spec says 'QUERY_STRING' may be absent.
-            post = qs2dict(self.environ.get('QUERY_STRING', ''))
+                post.update(qs2dict(self.raw_post_data))
         return post, files
 
     @property
@@ -177,11 +155,10 @@
         """raise a `DirectResponse` exception if a cached page along the way
         exists and is still usable
         """
-        # XXX
-#         if self.get_header('Cache-Control') in ('max-age=0', 'no-cache'):
-#             # Expires header seems to be required by IE7
-#             self.add_header('Expires', 'Sat, 01 Jan 2000 00:00:00 GMT')
-#             return
+        if self.get_header('Cache-Control') in ('max-age=0', 'no-cache'):
+            # Expires header seems to be required by IE7
+            self.add_header('Expires', 'Sat, 01 Jan 2000 00:00:00 GMT')
+            return
 #         try:
 #             http.checkPreconditions(self._twreq, _PreResponse(self))
 #         except http.HTTPError, ex: