[wsgi] avoid reading the entire request body in memory
authorJulien Cristau <julien.cristau@logilab.fr>
Fri, 16 May 2014 16:49:29 +0200
changeset 9735 b71158815bc8
parent 9729 1fe9dad662e5
child 9738 bcb90cc8edde
[wsgi] avoid reading the entire request body in memory Import POST form handling code from https://raw.github.com/defnull/multipart/master/multipart.py to avoid reading arbitrary amounts of data from the network in memory. NOTES: - In the twisted case we limit the max request content-length to 100MB (by default), which seems kind of arbitrary, but avoids this issue - werkzeug.formparser has suitable code as well, but I don't know if we want to add it as a dependency
multipart.py
wsgi/__init__.py
wsgi/request.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/multipart.py	Fri May 16 16:49:29 2014 +0200
@@ -0,0 +1,413 @@
+# -*- coding: utf-8 -*-
+'''
+Parser for multipart/form-data
+==============================
+
+This module provides a parser for the multipart/form-data format. It can read
+from a file, a socket or a WSGI environment. The parser can be used to replace
+cgi.FieldStorage (without the bugs) and works with Python 2.5+ and 3.x (2to3).
+
+Licence (MIT)
+-------------
+
+    Copyright (c) 2010, Marcel Hellkamp.
+    Inspired by the Werkzeug library: http://werkzeug.pocoo.org/
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in
+    all copies or substantial portions of the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+    THE SOFTWARE.
+
+'''
+
+__author__ = 'Marcel Hellkamp'
+__version__ = '0.1'
+__license__ = 'MIT'
+
+from tempfile import TemporaryFile
+from wsgiref.headers import Headers
+import re, sys
+try:
+    from urlparse import parse_qs
+except ImportError: # pragma: no cover (fallback for Python 2.5)
+    from cgi import parse_qs
+try:
+    from io import BytesIO
+except ImportError: # pragma: no cover (fallback for Python 2.5)
+    from StringIO import StringIO as BytesIO
+
+##############################################################################
+################################ Helper & Misc ################################
+##############################################################################
+# Some of these were copied from bottle: http://bottle.paws.de/
+
+try:
+    from collections import MutableMapping as DictMixin
+except ImportError: # pragma: no cover (fallback for Python 2.5)
+    from UserDict import DictMixin
+
+class MultiDict(DictMixin):
+    """ A dict that remembers old values for each key """
+    def __init__(self, *a, **k):
+        self.dict = dict()
+        for k, v in dict(*a, **k).iteritems():
+            self[k] = v
+
+    def __len__(self): return len(self.dict)
+    def __iter__(self): return iter(self.dict)
+    def __contains__(self, key): return key in self.dict
+    def __delitem__(self, key): del self.dict[key]
+    def keys(self): return self.dict.keys()
+    def __getitem__(self, key): return self.get(key, KeyError, -1)
+    def __setitem__(self, key, value): self.append(key, value)
+
+    def append(self, key, value): self.dict.setdefault(key, []).append(value)
+    def replace(self, key, value): self.dict[key] = [value]
+    def getall(self, key): return self.dict.get(key) or []
+
+    def get(self, key, default=None, index=-1):
+        if key not in self.dict and default != KeyError:
+            return [default][index]
+        return self.dict[key][index]
+
+    def iterallitems(self):
+        for key, values in self.dict.iteritems():
+            for value in values:
+                yield key, value
+
+def tob(data, enc='utf8'): # Convert strings to bytes (py2 and py3)
+    return data.encode(enc) if isinstance(data, unicode) else data
+
+def copy_file(stream, target, maxread=-1, buffer_size=2*16): # NOTE(review): 2*16 == 32 bytes; upstream likely meant 2**16 — confirm
+    ''' Read from :stream and write to :target until :maxread or EOF. '''
+    size, read = 0, stream.read
+    while 1:
+        to_read = buffer_size if maxread < 0 else min(buffer_size, maxread-size)
+        part = read(to_read)
+        if not part: return size
+        target.write(part)
+        size += len(part)
+
+##############################################################################
+################################ Header Parser ################################
+##############################################################################
+
+_special = re.escape('()<>@,;:\\"/[]?={} \t')
+_re_special = re.compile('[%s]' % _special)
+_qstr = '"(?:\\\\.|[^"])*"' # Quoted string
+_value = '(?:[^%s]+|%s)' % (_special, _qstr) # Safe or quoted string
+_option = '(?:;|^)\s*([^%s]+)\s*=\s*(%s)' % (_special, _value)
+_re_option = re.compile(_option) # key=value part of an Content-Type like header
+
+def header_quote(val):
+    if not _re_special.search(val):
+        return val
+    return '"' + val.replace('\\','\\\\').replace('"','\\"') + '"'
+
+def header_unquote(val, filename=False):
+    if val[0] == val[-1] == '"':
+        val = val[1:-1]
+        if val[1:3] == ':\\' or val[:2] == '\\\\': 
+            val = val.split('\\')[-1] # fix ie6 bug: full path --> filename
+        return val.replace('\\\\','\\').replace('\\"','"')
+    return val
+
+def parse_options_header(header, options=None):
+    if ';' not in header:
+        return header.lower().strip(), {}
+    ctype, tail = header.split(';', 1)
+    options = options or {}
+    for match in _re_option.finditer(tail):
+        key = match.group(1).lower()
+        value = header_unquote(match.group(2), key=='filename')
+        options[key] = value
+    return ctype, options
+
+##############################################################################
+################################## Multipart ##################################
+##############################################################################
+
+
+class MultipartError(ValueError): pass
+
+
+class MultipartParser(object):
+    
+    def __init__(self, stream, boundary, content_length=-1,
+                 disk_limit=2**30, mem_limit=2**20, memfile_limit=2**18,
+                 buffer_size=2**16, charset='latin1'):
+        ''' Parse a multipart/form-data byte stream. This object is an iterator
+            over the parts of the message.
+            
+            :param stream: A file-like stream. Must implement ``.read(size)``.
+            :param boundary: The multipart boundary as a byte string.
+            :param content_length: The maximum number of bytes to read.
+        '''
+        self.stream, self.boundary = stream, boundary
+        self.content_length = content_length
+        self.disk_limit = disk_limit
+        self.memfile_limit = memfile_limit
+        self.mem_limit = min(mem_limit, self.disk_limit)
+        self.buffer_size = min(buffer_size, self.mem_limit)
+        self.charset = charset
+        if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n"
+            raise MultipartError('Boundary does not fit into buffer_size.')
+        self._done = []
+        self._part_iter = None
+    
+    def __iter__(self):
+        ''' Iterate over the parts of the multipart message. '''
+        if not self._part_iter:
+            self._part_iter = self._iterparse()
+        for part in self._done:
+            yield part
+        for part in self._part_iter:
+            self._done.append(part)
+            yield part
+    
+    def parts(self):
+        ''' Returns a list with all parts of the multipart message. '''
+        return list(iter(self))
+    
+    def get(self, name, default=None):
+        ''' Return the first part with that name or a default value (None). '''
+        for part in self:
+            if name == part.name:
+                return part
+        return default
+
+    def get_all(self, name):
+        ''' Return a list of parts with that name. '''
+        return [p for p in self if p.name == name]
+
+    def _lineiter(self):
+        ''' Iterate over a binary file-like object line by line. Each line is
+            returned as a (line, line_ending) tuple. If the line does not fit
+            into self.buffer_size, line_ending is empty and the rest of the line
+            is returned with the next iteration.
+        '''
+        read = self.stream.read
+        maxread, maxbuf = self.content_length, self.buffer_size
+        _bcrnl = tob('\r\n')
+        _bcr = _bcrnl[:1]
+        _bnl = _bcrnl[1:]
+        _bempty = _bcrnl[:0] # b'\r\n'[:0] -> b''
+        buffer = _bempty # buffer for the last (partial) line
+        while 1:
+            data = read(maxbuf if maxread < 0 else min(maxbuf, maxread))
+            maxread -= len(data)
+            lines = (buffer+data).splitlines(True)
+            len_first_line = len(lines[0])
+            # be sure that the first line does not become too big
+            if len_first_line > self.buffer_size:
+                # at the same time don't split a '\r\n' accidentally
+                if (len_first_line == self.buffer_size+1 and
+                    lines[0].endswith(_bcrnl)):
+                    splitpos = self.buffer_size - 1
+                else:
+                    splitpos = self.buffer_size
+                lines[:1] = [lines[0][:splitpos],
+                             lines[0][splitpos:]]
+            if data:
+                buffer = lines[-1]
+                lines = lines[:-1]
+            for line in lines:
+                if line.endswith(_bcrnl): yield line[:-2], _bcrnl
+                elif line.endswith(_bnl): yield line[:-1], _bnl
+                elif line.endswith(_bcr): yield line[:-1], _bcr
+                else:                     yield line, _bempty
+            if not data:
+                break
+    
+    def _iterparse(self):
+        lines, line = self._lineiter(), ''
+        separator = tob('--') + tob(self.boundary)
+        terminator = tob('--') + tob(self.boundary) + tob('--')
+        # Consume first boundary. Ignore leading blank lines
+        for line, nl in lines:
+            if line: break
+        if line != separator:
+            raise MultipartError("Stream does not start with boundary")
+        # For each part in stream...
+        mem_used, disk_used = 0, 0 # Track used resources to prevent DoS
+        is_tail = False # True if the last line was incomplete (cut off)
+        opts = {'buffer_size': self.buffer_size,
+                'memfile_limit': self.memfile_limit,
+                'charset': self.charset}
+        part = MultipartPart(**opts)
+        for line, nl in lines:
+            if line == terminator and not is_tail:
+                part.file.seek(0)
+                yield part
+                break
+            elif line == separator and not is_tail:
+                if part.is_buffered(): mem_used  += part.size
+                else:                  disk_used += part.size
+                part.file.seek(0)
+                yield part
+                part = MultipartPart(**opts)
+            else:
+                is_tail = not nl # The next line continues this one
+                part.feed(line, nl)
+                if part.is_buffered():
+                    if part.size + mem_used > self.mem_limit:
+                        raise MultipartError("Memory limit reached.")
+                elif part.size + disk_used > self.disk_limit:
+                    raise MultipartError("Disk limit reached.")
+        if line != terminator:
+            raise MultipartError("Unexpected end of multipart stream.")
+            
+
+class MultipartPart(object):
+    
+    def __init__(self, buffer_size=2**16, memfile_limit=2**18, charset='latin1'):
+        self.headerlist = []
+        self.headers = None
+        self.file = False
+        self.size = 0
+        self._buf = tob('')
+        self.disposition, self.name, self.filename = None, None, None
+        self.content_type, self.charset = None, charset
+        self.memfile_limit = memfile_limit
+        self.buffer_size = buffer_size
+
+    def feed(self, line, nl=''):
+        if self.file:
+            return self.write_body(line, nl)
+        return self.write_header(line, nl)
+
+    def write_header(self, line, nl):
+        line = line.decode(self.charset or 'latin1')
+        if not nl: raise MultipartError('Unexpected end of line in header.')
+        if not line.strip(): # blank line -> end of header segment
+            self.finish_header()
+        elif line[0] in ' \t' and self.headerlist:
+            name, value = self.headerlist.pop()
+            self.headerlist.append((name, value+line.strip()))
+        else:
+            if ':' not in line:
+                raise MultipartError("Syntax error in header: No colon.")
+            name, value = line.split(':', 1)
+            self.headerlist.append((name.strip(), value.strip()))
+
+    def write_body(self, line, nl):
+        if not line and not nl: return # This does not even flush the buffer
+        self.size += len(line) + len(self._buf)
+        self.file.write(self._buf + line)
+        self._buf = nl
+        if self.content_length > 0 and self.size > self.content_length:
+            raise MultipartError('Size of body exceeds Content-Length header.')
+        if self.size > self.memfile_limit and isinstance(self.file, BytesIO):
+            # TODO: What about non-file uploads that exceed the memfile_limit?
+            self.file, old = TemporaryFile(mode='w+b'), self.file
+            old.seek(0)
+            copy_file(old, self.file, self.size, self.buffer_size)
+
+    def finish_header(self):
+        self.file = BytesIO()
+        self.headers = Headers(self.headerlist)
+        cdis = self.headers.get('Content-Disposition','')
+        ctype = self.headers.get('Content-Type','')
+        clen = self.headers.get('Content-Length','-1')
+        if not cdis:
+            raise MultipartError('Content-Disposition header is missing.')
+        self.disposition, self.options = parse_options_header(cdis)
+        self.name = self.options.get('name')
+        self.filename = self.options.get('filename')
+        self.content_type, options = parse_options_header(ctype)
+        self.charset = options.get('charset') or self.charset
+        self.content_length = int(self.headers.get('Content-Length','-1'))
+
+    def is_buffered(self):
+        ''' Return true if the data is fully buffered in memory.'''
+        return isinstance(self.file, BytesIO)
+
+    @property
+    def value(self):
+        ''' Data decoded with the specified charset '''
+        pos = self.file.tell()
+        self.file.seek(0)
+        val = self.file.read()
+        self.file.seek(pos)
+        return val.decode(self.charset)
+    
+    def save_as(self, path):
+        fp = open(path, 'wb')
+        pos = self.file.tell()
+        try:
+            self.file.seek(0)
+            size = copy_file(self.file, fp)
+        finally:
+            self.file.seek(pos)
+        return size
+
+##############################################################################
+#################################### WSGI ####################################
+##############################################################################
+
+def parse_form_data(environ, charset='utf8', strict=False, **kw):
+    ''' Parse form data from an environ dict and return a (forms, files) tuple.
+        Both tuple values are dictionaries with the form-field name as a key
+        (unicode) and lists as values (multiple values per key are possible).
+        The forms-dictionary contains form-field values as unicode strings.
+        The files-dictionary contains :class:`MultipartPart` instances, either
+        because the form-field was a file-upload or the value is too big to fit
+        into memory limits.
+        
+        :param environ: A WSGI environment dict.
+        :param charset: The charset to use if unsure. (default: utf8)
+        :param strict: If True, raise :exc:`MultipartError` on any parsing
+                       errors. These are silently ignored by default.
+    '''
+        
+    forms, files = MultiDict(), MultiDict()
+    try:
+        if environ.get('REQUEST_METHOD','GET').upper() not in ('POST', 'PUT'):
+            raise MultipartError("Request method other than POST or PUT.")
+        content_length = int(environ.get('CONTENT_LENGTH', '-1'))
+        content_type = environ.get('CONTENT_TYPE', '')
+        if not content_type:
+            raise MultipartError("Missing Content-Type header.")
+        content_type, options = parse_options_header(content_type)
+        stream = environ.get('wsgi.input') or BytesIO()
+        kw['charset'] = charset = options.get('charset', charset)
+        if content_type == 'multipart/form-data':
+            boundary = options.get('boundary','')
+            if not boundary:
+                raise MultipartError("No boundary for multipart/form-data.")
+            for part in MultipartParser(stream, boundary, content_length, **kw):
+                if part.filename or not part.is_buffered():
+                    files[part.name] = part
+                else: # TODO: Big form-fields are in the files dict. really?
+                    forms[part.name] = part.value
+        elif content_type in ('application/x-www-form-urlencoded',
+                              'application/x-url-encoded'):
+            mem_limit = kw.get('mem_limit', 2**20)
+            if content_length > mem_limit:
+                raise MultipartError("Request to big. Increase MAXMEM.")
+            data = stream.read(mem_limit).decode(charset)
+            if stream.read(1): # There is more data than fits into mem_limit
+                raise MultipartError("Request to big. Increase MAXMEM.")
+            data = parse_qs(data, keep_blank_values=True)
+            for key, values in data.iteritems():
+                for value in values:
+                    forms[key] = value
+        else:
+            raise MultipartError("Unsupported content type.")
+    except MultipartError:
+        if strict: raise
+    return forms, files
+
--- a/wsgi/__init__.py	Tue May 06 14:11:17 2014 +0200
+++ b/wsgi/__init__.py	Fri May 16 16:49:29 2014 +0200
@@ -29,7 +29,7 @@
 from email import message, message_from_string
 from Cookie import SimpleCookie
 from StringIO import StringIO
-from cgi import parse_header, parse_qsl
+from cgi import parse_header
 from pprint import pformat as _pformat
 
 
@@ -40,13 +40,6 @@
     except Exception:
         return u'<could not parse>'
 
-def qs2dict(qs):
-    """transforms a query string into a regular python dict"""
-    result = {}
-    for key, value in parse_qsl(qs, True):
-        result.setdefault(key, []).append(value)
-    return result
-
 def normalize_header(header):
     """returns a normalized header name
 
@@ -70,31 +63,3 @@
             break
         fdst.write(buf)
         size -= len(buf)
-
-def parse_file_upload(header_dict, post_data):
-    """This is adapted FROM DJANGO"""
-    raw_message = '\r\n'.join('%s:%s' % pair for pair in header_dict.iteritems())
-    raw_message += '\r\n\r\n' + post_data
-    msg = message_from_string(raw_message)
-    post, files = {}, {}
-    for submessage in msg.get_payload():
-        name_dict = parse_header(submessage['Content-Disposition'])[1]
-        key = name_dict['name']
-        # name_dict is something like {'name': 'file', 'filename': 'test.txt'} for file uploads
-        # or {'name': 'blah'} for POST fields
-        # We assume all uploaded files have a 'filename' set.
-        if 'filename' in name_dict:
-            assert type([]) != type(submessage.get_payload()), "Nested MIME messages are not supported"
-            if not name_dict['filename'].strip():
-                continue
-            # IE submits the full path, so trim everything but the basename.
-            # (We can't use os.path.basename because that uses the server's
-            # directory separator, which may not be the same as the
-            # client's one.)
-            filename = name_dict['filename'][name_dict['filename'].rfind("\\")+1:]
-            mimetype = 'Content-Type' in submessage and submessage['Content-Type'] or None
-            content = StringIO(submessage.get_payload())
-            files[key] = [filename, mimetype, content]
-        else:
-            post.setdefault(key, []).append(submessage.get_payload())
-    return post, files
--- a/wsgi/request.py	Tue May 06 14:11:17 2014 +0200
+++ b/wsgi/request.py	Fri May 16 16:49:29 2014 +0200
@@ -27,14 +27,11 @@
 
 from StringIO import StringIO
 from urllib import quote
+from urlparse import parse_qs
 
-from logilab.common.decorators import cached
-
+from cubicweb.multipart import copy_file, parse_form_data
 from cubicweb.web.request import CubicWebRequestBase
-from cubicweb.wsgi import (pformat, qs2dict, safe_copyfileobj, parse_file_upload,
-                           normalize_header)
-from cubicweb.web.http_headers import Headers
-
+from cubicweb.wsgi import pformat, normalize_header
 
 
 class CubicWebWsgiRequest(CubicWebRequestBase):
@@ -45,6 +42,8 @@
         self.environ = environ
         self.path = environ['PATH_INFO']
         self.method = environ['REQUEST_METHOD'].upper()
+
+        # content_length "may be empty or absent"
         try:
             length = int(environ['CONTENT_LENGTH'])
         except (KeyError, ValueError):
@@ -54,8 +53,9 @@
             self.content = StringIO()
         else:
             self.content = tempfile.TemporaryFile()
-        safe_copyfileobj(environ['wsgi.input'], self.content, size=length)
+        copy_file(environ['wsgi.input'], self.content, maxread=length)
         self.content.seek(0, 0)
+        environ['wsgi.input'] = self.content
 
         headers_in = dict((normalize_header(k[5:]), v) for k, v in self.environ.items()
                           if k.startswith('HTTP_'))
@@ -65,10 +65,11 @@
         super(CubicWebWsgiRequest, self).__init__(vreg, https, post,
                                                   headers= headers_in)
         if files is not None:
-            for key, (name, _, stream) in files.iteritems():
-                if name is not None:
-                    name = unicode(name, self.encoding)
-                self.form[key] = (name, stream)
+            for key, part in files.iteritems():
+                name = None
+                if part.filename is not None:
+                    name = unicode(part.filename, self.encoding)
+                self.form[key] = (name, part.file)
 
     def __repr__(self):
         # Since this is called as part of error handling, we need to be very
@@ -132,23 +133,11 @@
 
     def get_posted_data(self):
         # The WSGI spec says 'QUERY_STRING' may be absent.
-        post = qs2dict(self.environ.get('QUERY_STRING', ''))
+        post = parse_qs(self.environ.get('QUERY_STRING', ''))
         files = None
         if self.method == 'POST':
-            if self.environ.get('CONTENT_TYPE', '').startswith('multipart'):
-                header_dict = dict((normalize_header(k[5:]), v)
-                                   for k, v in self.environ.items()
-                                   if k.startswith('HTTP_'))
-                header_dict['Content-Type'] = self.environ.get('CONTENT_TYPE', '')
-                post_, files = parse_file_upload(header_dict, self.raw_post_data)
-                post.update(post_)
-            else:
-                post.update(qs2dict(self.raw_post_data))
+            forms, files = parse_form_data(self.environ, strict=True,
+                                           mem_limit=self.vreg.config['max-post-length'])
+            post.update(forms)
+        self.content.seek(0, 0)
         return post, files
-
-    @property
-    @cached
-    def raw_post_data(self):
-        postdata = self.content.read()
-        self.content.seek(0, 0)
-        return postdata