multipart.py
changeset 9735 b71158815bc8
child 9946 ec88c1a1904a
equal deleted inserted replaced
9729:1fe9dad662e5 9735:b71158815bc8
       
     1 # -*- coding: utf-8 -*-
       
     2 '''
       
     3 Parser for multipart/form-data
       
     4 ==============================
       
     5 
       
     6 This module provides a parser for the multipart/form-data format. It can read
       
     7 from a file, a socket or a WSGI environment. The parser can be used to replace
       
     8 cgi.FieldStorage (without the bugs) and works with Python 2.5+ and 3.x (2to3).
       
     9 
       
    10 Licence (MIT)
       
    11 -------------
       
    12 
       
    13     Copyright (c) 2010, Marcel Hellkamp.
       
    14     Inspired by the Werkzeug library: http://werkzeug.pocoo.org/
       
    15 
       
    16     Permission is hereby granted, free of charge, to any person obtaining a copy
       
    17     of this software and associated documentation files (the "Software"), to deal
       
    18     in the Software without restriction, including without limitation the rights
       
    19     to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       
    20     copies of the Software, and to permit persons to whom the Software is
       
    21     furnished to do so, subject to the following conditions:
       
    22 
       
    23     The above copyright notice and this permission notice shall be included in
       
    24     all copies or substantial portions of the Software.
       
    25 
       
    26     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
       
    27     IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
       
    28     FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
       
    29     AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
       
    30     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
       
    31     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
       
    32     THE SOFTWARE.
       
    33 
       
    34 '''
       
    35 
       
    36 __author__ = 'Marcel Hellkamp'
       
    37 __version__ = '0.1'
       
    38 __license__ = 'MIT'
       
    39 
       
    40 from tempfile import TemporaryFile
       
    41 from wsgiref.headers import Headers
       
    42 import re, sys
       
    43 try:
       
    44     from urlparse import parse_qs
       
    45 except ImportError: # pragma: no cover (fallback for Python 2.5)
       
    46     from cgi import parse_qs
       
    47 try:
       
    48     from io import BytesIO
       
    49 except ImportError: # pragma: no cover (fallback for Python 2.5)
       
    50     from StringIO import StringIO as BytesIO
       
    51 
       
    52 ##############################################################################
       
    53 ################################ Helper & Misc ################################
       
    54 ##############################################################################
       
    55 # Some of these were copied from bottle: http://bottle.paws.de/
       
    56 
       
    57 try:
       
    58     from collections import MutableMapping as DictMixin
       
    59 except ImportError: # pragma: no cover (fallback for Python 2.5)
       
    60     from UserDict import DictMixin
       
    61 
       
    62 class MultiDict(DictMixin):
       
    63     """ A dict that remembers old values for each key """
       
    64     def __init__(self, *a, **k):
       
    65         self.dict = dict()
       
    66         for k, v in dict(*a, **k).iteritems():
       
    67             self[k] = v
       
    68 
       
    69     def __len__(self): return len(self.dict)
       
    70     def __iter__(self): return iter(self.dict)
       
    71     def __contains__(self, key): return key in self.dict
       
    72     def __delitem__(self, key): del self.dict[key]
       
    73     def keys(self): return self.dict.keys()
       
    74     def __getitem__(self, key): return self.get(key, KeyError, -1)
       
    75     def __setitem__(self, key, value): self.append(key, value)
       
    76 
       
    77     def append(self, key, value): self.dict.setdefault(key, []).append(value)
       
    78     def replace(self, key, value): self.dict[key] = [value]
       
    79     def getall(self, key): return self.dict.get(key) or []
       
    80 
       
    81     def get(self, key, default=None, index=-1):
       
    82         if key not in self.dict and default != KeyError:
       
    83             return [default][index]
       
    84         return self.dict[key][index]
       
    85 
       
    86     def iterallitems(self):
       
    87         for key, values in self.dict.iteritems():
       
    88             for value in values:
       
    89                 yield key, value
       
    90 
       
    91 def tob(data, enc='utf8'): # Convert strings to bytes (py2 and py3)
       
    92     return data.encode(enc) if isinstance(data, unicode) else data
       
    93 
       
    94 def copy_file(stream, target, maxread=-1, buffer_size=2*16):
       
    95     ''' Read from :stream and write to :target until :maxread or EOF. '''
       
    96     size, read = 0, stream.read
       
    97     while 1:
       
    98         to_read = buffer_size if maxread < 0 else min(buffer_size, maxread-size)
       
    99         part = read(to_read)
       
   100         if not part: return size
       
   101         target.write(part)
       
   102         size += len(part)
       
   103 
       
   104 ##############################################################################
       
   105 ################################ Header Parser ################################
       
   106 ##############################################################################
       
   107 
       
   108 _special = re.escape('()<>@,;:\\"/[]?={} \t')
       
   109 _re_special = re.compile('[%s]' % _special)
       
   110 _qstr = '"(?:\\\\.|[^"])*"' # Quoted string
       
   111 _value = '(?:[^%s]+|%s)' % (_special, _qstr) # Save or quoted string
       
   112 _option = '(?:;|^)\s*([^%s]+)\s*=\s*(%s)' % (_special, _value)
       
   113 _re_option = re.compile(_option) # key=value part of an Content-Type like header
       
   114 
       
   115 def header_quote(val):
       
   116     if not _re_special.search(val):
       
   117         return val
       
   118     return '"' + val.replace('\\','\\\\').replace('"','\\"') + '"'
       
   119 
       
   120 def header_unquote(val, filename=False):
       
   121     if val[0] == val[-1] == '"':
       
   122         val = val[1:-1]
       
   123         if val[1:3] == ':\\' or val[:2] == '\\\\': 
       
   124             val = val.split('\\')[-1] # fix ie6 bug: full path --> filename
       
   125         return val.replace('\\\\','\\').replace('\\"','"')
       
   126     return val
       
   127 
       
   128 def parse_options_header(header, options=None):
       
   129     if ';' not in header:
       
   130         return header.lower().strip(), {}
       
   131     ctype, tail = header.split(';', 1)
       
   132     options = options or {}
       
   133     for match in _re_option.finditer(tail):
       
   134         key = match.group(1).lower()
       
   135         value = header_unquote(match.group(2), key=='filename')
       
   136         options[key] = value
       
   137     return ctype, options
       
   138 
       
   139 ##############################################################################
       
   140 ################################## Multipart ##################################
       
   141 ##############################################################################
       
   142 
       
   143 
       
   144 class MultipartError(ValueError): pass
       
   145 
       
   146 
       
   147 class MultipartParser(object):
       
   148     
       
   149     def __init__(self, stream, boundary, content_length=-1,
       
   150                  disk_limit=2**30, mem_limit=2**20, memfile_limit=2**18,
       
   151                  buffer_size=2**16, charset='latin1'):
       
   152         ''' Parse a multipart/form-data byte stream. This object is an iterator
       
   153             over the parts of the message.
       
   154             
       
   155             :param stream: A file-like stream. Must implement ``.read(size)``.
       
   156             :param boundary: The multipart boundary as a byte string.
       
   157             :param content_length: The maximum number of bytes to read.
       
   158         '''
       
   159         self.stream, self.boundary = stream, boundary
       
   160         self.content_length = content_length
       
   161         self.disk_limit = disk_limit
       
   162         self.memfile_limit = memfile_limit
       
   163         self.mem_limit = min(mem_limit, self.disk_limit)
       
   164         self.buffer_size = min(buffer_size, self.mem_limit)
       
   165         self.charset = charset
       
   166         if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n"
       
   167             raise MultipartError('Boundary does not fit into buffer_size.')
       
   168         self._done = []
       
   169         self._part_iter = None
       
   170     
       
   171     def __iter__(self):
       
   172         ''' Iterate over the parts of the multipart message. '''
       
   173         if not self._part_iter:
       
   174             self._part_iter = self._iterparse()
       
   175         for part in self._done:
       
   176             yield part
       
   177         for part in self._part_iter:
       
   178             self._done.append(part)
       
   179             yield part
       
   180     
       
   181     def parts(self):
       
   182         ''' Returns a list with all parts of the multipart message. '''
       
   183         return list(iter(self))
       
   184     
       
   185     def get(self, name, default=None):
       
   186         ''' Return the first part with that name or a default value (None). '''
       
   187         for part in self:
       
   188             if name == part.name:
       
   189                 return part
       
   190         return default
       
   191 
       
   192     def get_all(self, name):
       
   193         ''' Return a list of parts with that name. '''
       
   194         return [p for p in self if p.name == name]
       
   195 
       
   196     def _lineiter(self):
       
   197         ''' Iterate over a binary file-like object line by line. Each line is
       
   198             returned as a (line, line_ending) tuple. If the line does not fit
       
   199             into self.buffer_size, line_ending is empty and the rest of the line
       
   200             is returned with the next iteration.
       
   201         '''
       
   202         read = self.stream.read
       
   203         maxread, maxbuf = self.content_length, self.buffer_size
       
   204         _bcrnl = tob('\r\n')
       
   205         _bcr = _bcrnl[:1]
       
   206         _bnl = _bcrnl[1:]
       
   207         _bempty = _bcrnl[:0] # b'rn'[:0] -> b''
       
   208         buffer = _bempty # buffer for the last (partial) line
       
   209         while 1:
       
   210             data = read(maxbuf if maxread < 0 else min(maxbuf, maxread))
       
   211             maxread -= len(data)
       
   212             lines = (buffer+data).splitlines(True)
       
   213             len_first_line = len(lines[0])
       
   214             # be sure that the first line does not become too big
       
   215             if len_first_line > self.buffer_size:
       
   216                 # at the same time don't split a '\r\n' accidentally
       
   217                 if (len_first_line == self.buffer_size+1 and
       
   218                     lines[0].endswith(_bcrnl)):
       
   219                     splitpos = self.buffer_size - 1
       
   220                 else:
       
   221                     splitpos = self.buffer_size
       
   222                 lines[:1] = [lines[0][:splitpos],
       
   223                              lines[0][splitpos:]]
       
   224             if data:
       
   225                 buffer = lines[-1]
       
   226                 lines = lines[:-1]
       
   227             for line in lines:
       
   228                 if line.endswith(_bcrnl): yield line[:-2], _bcrnl
       
   229                 elif line.endswith(_bnl): yield line[:-1], _bnl
       
   230                 elif line.endswith(_bcr): yield line[:-1], _bcr
       
   231                 else:                     yield line, _bempty
       
   232             if not data:
       
   233                 break
       
   234     
       
   235     def _iterparse(self):
       
   236         lines, line = self._lineiter(), ''
       
   237         separator = tob('--') + tob(self.boundary)
       
   238         terminator = tob('--') + tob(self.boundary) + tob('--')
       
   239         # Consume first boundary. Ignore leading blank lines
       
   240         for line, nl in lines:
       
   241             if line: break
       
   242         if line != separator:
       
   243             raise MultipartError("Stream does not start with boundary")
       
   244         # For each part in stream...
       
   245         mem_used, disk_used = 0, 0 # Track used resources to prevent DoS
       
   246         is_tail = False # True if the last line was incomplete (cutted)
       
   247         opts = {'buffer_size': self.buffer_size,
       
   248                 'memfile_limit': self.memfile_limit,
       
   249                 'charset': self.charset}
       
   250         part = MultipartPart(**opts)
       
   251         for line, nl in lines:
       
   252             if line == terminator and not is_tail:
       
   253                 part.file.seek(0)
       
   254                 yield part
       
   255                 break
       
   256             elif line == separator and not is_tail:
       
   257                 if part.is_buffered(): mem_used  += part.size
       
   258                 else:                  disk_used += part.size
       
   259                 part.file.seek(0)
       
   260                 yield part
       
   261                 part = MultipartPart(**opts)
       
   262             else:
       
   263                 is_tail = not nl # The next line continues this one
       
   264                 part.feed(line, nl)
       
   265                 if part.is_buffered():
       
   266                     if part.size + mem_used > self.mem_limit:
       
   267                         raise MultipartError("Memory limit reached.")
       
   268                 elif part.size + disk_used > self.disk_limit:
       
   269                     raise MultipartError("Disk limit reached.")
       
   270         if line != terminator:
       
   271             raise MultipartError("Unexpected end of multipart stream.")
       
   272             
       
   273 
       
   274 class MultipartPart(object):
       
   275     
       
   276     def __init__(self, buffer_size=2**16, memfile_limit=2**18, charset='latin1'):
       
   277         self.headerlist = []
       
   278         self.headers = None
       
   279         self.file = False
       
   280         self.size = 0
       
   281         self._buf = tob('')
       
   282         self.disposition, self.name, self.filename = None, None, None
       
   283         self.content_type, self.charset = None, charset
       
   284         self.memfile_limit = memfile_limit
       
   285         self.buffer_size = buffer_size
       
   286 
       
   287     def feed(self, line, nl=''):
       
   288         if self.file:
       
   289             return self.write_body(line, nl)
       
   290         return self.write_header(line, nl)
       
   291 
       
   292     def write_header(self, line, nl):
       
   293         line = line.decode(self.charset or 'latin1')
       
   294         if not nl: raise MultipartError('Unexpected end of line in header.')
       
   295         if not line.strip(): # blank line -> end of header segment
       
   296             self.finish_header()
       
   297         elif line[0] in ' \t' and self.headerlist:
       
   298             name, value = self.headerlist.pop()
       
   299             self.headerlist.append((name, value+line.strip()))
       
   300         else:
       
   301             if ':' not in line:
       
   302                 raise MultipartError("Syntax error in header: No colon.")
       
   303             name, value = line.split(':', 1)
       
   304             self.headerlist.append((name.strip(), value.strip()))
       
   305 
       
   306     def write_body(self, line, nl):
       
   307         if not line and not nl: return # This does not even flush the buffer
       
   308         self.size += len(line) + len(self._buf)
       
   309         self.file.write(self._buf + line)
       
   310         self._buf = nl
       
   311         if self.content_length > 0 and self.size > self.content_length:
       
   312             raise MultipartError('Size of body exceeds Content-Length header.')
       
   313         if self.size > self.memfile_limit and isinstance(self.file, BytesIO):
       
   314             # TODO: What about non-file uploads that exceed the memfile_limit?
       
   315             self.file, old = TemporaryFile(mode='w+b'), self.file
       
   316             old.seek(0)
       
   317             copy_file(old, self.file, self.size, self.buffer_size)
       
   318 
       
   319     def finish_header(self):
       
   320         self.file = BytesIO()
       
   321         self.headers = Headers(self.headerlist)
       
   322         cdis = self.headers.get('Content-Disposition','')
       
   323         ctype = self.headers.get('Content-Type','')
       
   324         clen = self.headers.get('Content-Length','-1')
       
   325         if not cdis:
       
   326             raise MultipartError('Content-Disposition header is missing.')
       
   327         self.disposition, self.options = parse_options_header(cdis)
       
   328         self.name = self.options.get('name')
       
   329         self.filename = self.options.get('filename')
       
   330         self.content_type, options = parse_options_header(ctype)
       
   331         self.charset = options.get('charset') or self.charset
       
   332         self.content_length = int(self.headers.get('Content-Length','-1'))
       
   333 
       
   334     def is_buffered(self):
       
   335         ''' Return true if the data is fully buffered in memory.'''
       
   336         return isinstance(self.file, BytesIO)
       
   337 
       
   338     @property
       
   339     def value(self):
       
   340         ''' Data decoded with the specified charset '''
       
   341         pos = self.file.tell()
       
   342         self.file.seek(0)
       
   343         val = self.file.read()
       
   344         self.file.seek(pos)
       
   345         return val.decode(self.charset)
       
   346     
       
   347     def save_as(self, path):
       
   348         fp = open(path, 'wb')
       
   349         pos = self.file.tell()
       
   350         try:
       
   351             self.file.seek(0)
       
   352             size = copy_file(self.file, fp)
       
   353         finally:
       
   354             self.file.seek(pos)
       
   355         return size
       
   356 
       
   357 ##############################################################################
       
   358 #################################### WSGI ####################################
       
   359 ##############################################################################
       
   360 
       
   361 def parse_form_data(environ, charset='utf8', strict=False, **kw):
       
   362     ''' Parse form data from an environ dict and return a (forms, files) tuple.
       
   363         Both tuple values are dictionaries with the form-field name as a key
       
   364         (unicode) and lists as values (multiple values per key are possible).
       
   365         The forms-dictionary contains form-field values as unicode strings.
       
   366         The files-dictionary contains :class:`MultipartPart` instances, either
       
   367         because the form-field was a file-upload or the value is to big to fit
       
   368         into memory limits.
       
   369         
       
   370         :param environ: An WSGI environment dict.
       
   371         :param charset: The charset to use if unsure. (default: utf8)
       
   372         :param strict: If True, raise :exc:`MultipartError` on any parsing
       
   373                        errors. These are silently ignored by default.
       
   374     '''
       
   375         
       
   376     forms, files = MultiDict(), MultiDict()
       
   377     try:
       
   378         if environ.get('REQUEST_METHOD','GET').upper() not in ('POST', 'PUT'):
       
   379             raise MultipartError("Request method other than POST or PUT.")
       
   380         content_length = int(environ.get('CONTENT_LENGTH', '-1'))
       
   381         content_type = environ.get('CONTENT_TYPE', '')
       
   382         if not content_type:
       
   383             raise MultipartError("Missing Content-Type header.")
       
   384         content_type, options = parse_options_header(content_type)
       
   385         stream = environ.get('wsgi.input') or BytesIO()
       
   386         kw['charset'] = charset = options.get('charset', charset)
       
   387         if content_type == 'multipart/form-data':
       
   388             boundary = options.get('boundary','')
       
   389             if not boundary:
       
   390                 raise MultipartError("No boundary for multipart/form-data.")
       
   391             for part in MultipartParser(stream, boundary, content_length, **kw):
       
   392                 if part.filename or not part.is_buffered():
       
   393                     files[part.name] = part
       
   394                 else: # TODO: Big form-fields are in the files dict. really?
       
   395                     forms[part.name] = part.value
       
   396         elif content_type in ('application/x-www-form-urlencoded',
       
   397                               'application/x-url-encoded'):
       
   398             mem_limit = kw.get('mem_limit', 2**20)
       
   399             if content_length > mem_limit:
       
   400                 raise MultipartError("Request to big. Increase MAXMEM.")
       
   401             data = stream.read(mem_limit).decode(charset)
       
   402             if stream.read(1): # These is more that does not fit mem_limit
       
   403                 raise MultipartError("Request to big. Increase MAXMEM.")
       
   404             data = parse_qs(data, keep_blank_values=True)
       
   405             for key, values in data.iteritems():
       
   406                 for value in values:
       
   407                     forms[key] = value
       
   408         else:
       
   409             raise MultipartError("Unsupported content type.")
       
   410     except MultipartError:
       
   411         if strict: raise
       
   412     return forms, files
       
   413