multipart.py
changeset 11057 0b59724cb3f2
parent 11052 058bb3dc685f
child 11058 23eb30449fe5
equal deleted inserted replaced
11052:058bb3dc685f 11057:0b59724cb3f2
     1 # -*- coding: utf-8 -*-
       
     2 '''
       
     3 Parser for multipart/form-data
       
     4 ==============================
       
     5 
       
     6 This module provides a parser for the multipart/form-data format. It can read
       
     7 from a file, a socket or a WSGI environment. The parser can be used to replace
       
     8 cgi.FieldStorage (without the bugs) and works with Python 2.5+ and 3.x (2to3).
       
     9 
       
    10 Licence (MIT)
       
    11 -------------
       
    12 
       
    13     Copyright (c) 2010, Marcel Hellkamp.
       
    14     Inspired by the Werkzeug library: http://werkzeug.pocoo.org/
       
    15 
       
    16     Permission is hereby granted, free of charge, to any person obtaining a copy
       
    17     of this software and associated documentation files (the "Software"), to deal
       
    18     in the Software without restriction, including without limitation the rights
       
    19     to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       
    20     copies of the Software, and to permit persons to whom the Software is
       
    21     furnished to do so, subject to the following conditions:
       
    22 
       
    23     The above copyright notice and this permission notice shall be included in
       
    24     all copies or substantial portions of the Software.
       
    25 
       
    26     THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
       
    27     IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
       
    28     FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
       
    29     AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
       
    30     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
       
    31     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
       
    32     THE SOFTWARE.
       
    33 
       
    34 '''
       
    35 
       
    36 __author__ = 'Marcel Hellkamp'
       
    37 __version__ = '0.1'
       
    38 __license__ = 'MIT'
       
    39 
       
    40 from tempfile import TemporaryFile
       
    41 from wsgiref.headers import Headers
       
    42 import re, sys
       
    43 try:
       
    44     from io import BytesIO
       
    45 except ImportError: # pragma: no cover (fallback for Python 2.5)
       
    46     from StringIO import StringIO as BytesIO
       
    47 
       
    48 from six import PY3, text_type
       
    49 from six.moves.urllib.parse import parse_qs
       
    50 
       
    51 ##############################################################################
       
    52 ################################ Helper & Misc ################################
       
    53 ##############################################################################
       
    54 # Some of these were copied from bottle: http://bottle.paws.de/
       
    55 
       
    56 try:
       
    57     from collections import MutableMapping as DictMixin
       
    58 except ImportError: # pragma: no cover (fallback for Python 2.5)
       
    59     from UserDict import DictMixin
       
    60 
       
    61 class MultiDict(DictMixin):
       
    62     """ A dict that remembers old values for each key """
       
    63     def __init__(self, *a, **k):
       
    64         self.dict = dict()
       
    65         for k, v in dict(*a, **k).items():
       
    66             self[k] = v
       
    67 
       
    68     def __len__(self): return len(self.dict)
       
    69     def __iter__(self): return iter(self.dict)
       
    70     def __contains__(self, key): return key in self.dict
       
    71     def __delitem__(self, key): del self.dict[key]
       
    72     def keys(self): return self.dict.keys()
       
    73     def __getitem__(self, key): return self.get(key, KeyError, -1)
       
    74     def __setitem__(self, key, value): self.append(key, value)
       
    75 
       
    76     def append(self, key, value): self.dict.setdefault(key, []).append(value)
       
    77     def replace(self, key, value): self.dict[key] = [value]
       
    78     def getall(self, key): return self.dict.get(key) or []
       
    79 
       
    80     def get(self, key, default=None, index=-1):
       
    81         if key not in self.dict and default != KeyError:
       
    82             return [default][index]
       
    83         return self.dict[key][index]
       
    84 
       
    85     def iterallitems(self):
       
    86         for key, values in self.dict.items():
       
    87             for value in values:
       
    88                 yield key, value
       
    89 
       
    90 def tob(data, enc='utf8'): # Convert strings to bytes (py2 and py3)
       
    91     return data.encode(enc) if isinstance(data, text_type) else data
       
    92 
       
    93 def copy_file(stream, target, maxread=-1, buffer_size=2*16):
       
    94     ''' Read from :stream and write to :target until :maxread or EOF. '''
       
    95     size, read = 0, stream.read
       
    96     while 1:
       
    97         to_read = buffer_size if maxread < 0 else min(buffer_size, maxread-size)
       
    98         part = read(to_read)
       
    99         if not part: return size
       
   100         target.write(part)
       
   101         size += len(part)
       
   102 
       
   103 ##############################################################################
       
   104 ################################ Header Parser ################################
       
   105 ##############################################################################
       
   106 
       
   107 _special = re.escape('()<>@,;:\\"/[]?={} \t')
       
   108 _re_special = re.compile('[%s]' % _special)
       
   109 _qstr = '"(?:\\\\.|[^"])*"' # Quoted string
       
   110 _value = '(?:[^%s]+|%s)' % (_special, _qstr) # Save or quoted string
       
   111 _option = '(?:;|^)\s*([^%s]+)\s*=\s*(%s)' % (_special, _value)
       
   112 _re_option = re.compile(_option) # key=value part of an Content-Type like header
       
   113 
       
   114 def header_quote(val):
       
   115     if not _re_special.search(val):
       
   116         return val
       
   117     return '"' + val.replace('\\','\\\\').replace('"','\\"') + '"'
       
   118 
       
   119 def header_unquote(val, filename=False):
       
   120     if val[0] == val[-1] == '"':
       
   121         val = val[1:-1]
       
   122         if val[1:3] == ':\\' or val[:2] == '\\\\': 
       
   123             val = val.split('\\')[-1] # fix ie6 bug: full path --> filename
       
   124         return val.replace('\\\\','\\').replace('\\"','"')
       
   125     return val
       
   126 
       
   127 def parse_options_header(header, options=None):
       
   128     if ';' not in header:
       
   129         return header.lower().strip(), {}
       
   130     ctype, tail = header.split(';', 1)
       
   131     options = options or {}
       
   132     for match in _re_option.finditer(tail):
       
   133         key = match.group(1).lower()
       
   134         value = header_unquote(match.group(2), key=='filename')
       
   135         options[key] = value
       
   136     return ctype, options
       
   137 
       
   138 ##############################################################################
       
   139 ################################## Multipart ##################################
       
   140 ##############################################################################
       
   141 
       
   142 
       
   143 class MultipartError(ValueError): pass
       
   144 
       
   145 
       
   146 class MultipartParser(object):
       
   147     
       
   148     def __init__(self, stream, boundary, content_length=-1,
       
   149                  disk_limit=2**30, mem_limit=2**20, memfile_limit=2**18,
       
   150                  buffer_size=2**16, charset='latin1'):
       
   151         ''' Parse a multipart/form-data byte stream. This object is an iterator
       
   152             over the parts of the message.
       
   153             
       
   154             :param stream: A file-like stream. Must implement ``.read(size)``.
       
   155             :param boundary: The multipart boundary as a byte string.
       
   156             :param content_length: The maximum number of bytes to read.
       
   157         '''
       
   158         self.stream, self.boundary = stream, boundary
       
   159         self.content_length = content_length
       
   160         self.disk_limit = disk_limit
       
   161         self.memfile_limit = memfile_limit
       
   162         self.mem_limit = min(mem_limit, self.disk_limit)
       
   163         self.buffer_size = min(buffer_size, self.mem_limit)
       
   164         self.charset = charset
       
   165         if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n"
       
   166             raise MultipartError('Boundary does not fit into buffer_size.')
       
   167         self._done = []
       
   168         self._part_iter = None
       
   169     
       
   170     def __iter__(self):
       
   171         ''' Iterate over the parts of the multipart message. '''
       
   172         if not self._part_iter:
       
   173             self._part_iter = self._iterparse()
       
   174         for part in self._done:
       
   175             yield part
       
   176         for part in self._part_iter:
       
   177             self._done.append(part)
       
   178             yield part
       
   179     
       
   180     def parts(self):
       
   181         ''' Returns a list with all parts of the multipart message. '''
       
   182         return list(iter(self))
       
   183     
       
   184     def get(self, name, default=None):
       
   185         ''' Return the first part with that name or a default value (None). '''
       
   186         for part in self:
       
   187             if name == part.name:
       
   188                 return part
       
   189         return default
       
   190 
       
   191     def get_all(self, name):
       
   192         ''' Return a list of parts with that name. '''
       
   193         return [p for p in self if p.name == name]
       
   194 
       
   195     def _lineiter(self):
       
   196         ''' Iterate over a binary file-like object line by line. Each line is
       
   197             returned as a (line, line_ending) tuple. If the line does not fit
       
   198             into self.buffer_size, line_ending is empty and the rest of the line
       
   199             is returned with the next iteration.
       
   200         '''
       
   201         read = self.stream.read
       
   202         maxread, maxbuf = self.content_length, self.buffer_size
       
   203         _bcrnl = tob('\r\n')
       
   204         _bcr = _bcrnl[:1]
       
   205         _bnl = _bcrnl[1:]
       
   206         _bempty = _bcrnl[:0] # b'rn'[:0] -> b''
       
   207         buffer = _bempty # buffer for the last (partial) line
       
   208         while 1:
       
   209             data = read(maxbuf if maxread < 0 else min(maxbuf, maxread))
       
   210             maxread -= len(data)
       
   211             lines = (buffer+data).splitlines(True)
       
   212             len_first_line = len(lines[0])
       
   213             # be sure that the first line does not become too big
       
   214             if len_first_line > self.buffer_size:
       
   215                 # at the same time don't split a '\r\n' accidentally
       
   216                 if (len_first_line == self.buffer_size+1 and
       
   217                     lines[0].endswith(_bcrnl)):
       
   218                     splitpos = self.buffer_size - 1
       
   219                 else:
       
   220                     splitpos = self.buffer_size
       
   221                 lines[:1] = [lines[0][:splitpos],
       
   222                              lines[0][splitpos:]]
       
   223             if data:
       
   224                 buffer = lines[-1]
       
   225                 lines = lines[:-1]
       
   226             for line in lines:
       
   227                 if line.endswith(_bcrnl): yield line[:-2], _bcrnl
       
   228                 elif line.endswith(_bnl): yield line[:-1], _bnl
       
   229                 elif line.endswith(_bcr): yield line[:-1], _bcr
       
   230                 else:                     yield line, _bempty
       
   231             if not data:
       
   232                 break
       
   233     
       
   234     def _iterparse(self):
       
   235         lines, line = self._lineiter(), ''
       
   236         separator = tob('--') + tob(self.boundary)
       
   237         terminator = tob('--') + tob(self.boundary) + tob('--')
       
   238         # Consume first boundary. Ignore leading blank lines
       
   239         for line, nl in lines:
       
   240             if line: break
       
   241         if line != separator:
       
   242             raise MultipartError("Stream does not start with boundary")
       
   243         # For each part in stream...
       
   244         mem_used, disk_used = 0, 0 # Track used resources to prevent DoS
       
   245         is_tail = False # True if the last line was incomplete (cutted)
       
   246         opts = {'buffer_size': self.buffer_size,
       
   247                 'memfile_limit': self.memfile_limit,
       
   248                 'charset': self.charset}
       
   249         part = MultipartPart(**opts)
       
   250         for line, nl in lines:
       
   251             if line == terminator and not is_tail:
       
   252                 part.file.seek(0)
       
   253                 yield part
       
   254                 break
       
   255             elif line == separator and not is_tail:
       
   256                 if part.is_buffered(): mem_used  += part.size
       
   257                 else:                  disk_used += part.size
       
   258                 part.file.seek(0)
       
   259                 yield part
       
   260                 part = MultipartPart(**opts)
       
   261             else:
       
   262                 is_tail = not nl # The next line continues this one
       
   263                 part.feed(line, nl)
       
   264                 if part.is_buffered():
       
   265                     if part.size + mem_used > self.mem_limit:
       
   266                         raise MultipartError("Memory limit reached.")
       
   267                 elif part.size + disk_used > self.disk_limit:
       
   268                     raise MultipartError("Disk limit reached.")
       
   269         if line != terminator:
       
   270             raise MultipartError("Unexpected end of multipart stream.")
       
   271             
       
   272 
       
   273 class MultipartPart(object):
       
   274     
       
   275     def __init__(self, buffer_size=2**16, memfile_limit=2**18, charset='latin1'):
       
   276         self.headerlist = []
       
   277         self.headers = None
       
   278         self.file = False
       
   279         self.size = 0
       
   280         self._buf = tob('')
       
   281         self.disposition, self.name, self.filename = None, None, None
       
   282         self.content_type, self.charset = None, charset
       
   283         self.memfile_limit = memfile_limit
       
   284         self.buffer_size = buffer_size
       
   285 
       
   286     def feed(self, line, nl=''):
       
   287         if self.file:
       
   288             return self.write_body(line, nl)
       
   289         return self.write_header(line, nl)
       
   290 
       
   291     def write_header(self, line, nl):
       
   292         line = line.decode(self.charset or 'latin1')
       
   293         if not nl: raise MultipartError('Unexpected end of line in header.')
       
   294         if not line.strip(): # blank line -> end of header segment
       
   295             self.finish_header()
       
   296         elif line[0] in ' \t' and self.headerlist:
       
   297             name, value = self.headerlist.pop()
       
   298             self.headerlist.append((name, value+line.strip()))
       
   299         else:
       
   300             if ':' not in line:
       
   301                 raise MultipartError("Syntax error in header: No colon.")
       
   302             name, value = line.split(':', 1)
       
   303             self.headerlist.append((name.strip(), value.strip()))
       
   304 
       
   305     def write_body(self, line, nl):
       
   306         if not line and not nl: return # This does not even flush the buffer
       
   307         self.size += len(line) + len(self._buf)
       
   308         self.file.write(self._buf + line)
       
   309         self._buf = nl
       
   310         if self.content_length > 0 and self.size > self.content_length:
       
   311             raise MultipartError('Size of body exceeds Content-Length header.')
       
   312         if self.size > self.memfile_limit and isinstance(self.file, BytesIO):
       
   313             # TODO: What about non-file uploads that exceed the memfile_limit?
       
   314             self.file, old = TemporaryFile(mode='w+b'), self.file
       
   315             old.seek(0)
       
   316             copy_file(old, self.file, self.size, self.buffer_size)
       
   317 
       
   318     def finish_header(self):
       
   319         self.file = BytesIO()
       
   320         self.headers = Headers(self.headerlist)
       
   321         cdis = self.headers.get('Content-Disposition','')
       
   322         ctype = self.headers.get('Content-Type','')
       
   323         clen = self.headers.get('Content-Length','-1')
       
   324         if not cdis:
       
   325             raise MultipartError('Content-Disposition header is missing.')
       
   326         self.disposition, self.options = parse_options_header(cdis)
       
   327         self.name = self.options.get('name')
       
   328         self.filename = self.options.get('filename')
       
   329         self.content_type, options = parse_options_header(ctype)
       
   330         self.charset = options.get('charset') or self.charset
       
   331         self.content_length = int(self.headers.get('Content-Length','-1'))
       
   332 
       
   333     def is_buffered(self):
       
   334         ''' Return true if the data is fully buffered in memory.'''
       
   335         return isinstance(self.file, BytesIO)
       
   336 
       
   337     @property
       
   338     def value(self):
       
   339         ''' Data decoded with the specified charset '''
       
   340         pos = self.file.tell()
       
   341         self.file.seek(0)
       
   342         val = self.file.read()
       
   343         self.file.seek(pos)
       
   344         return val.decode(self.charset)
       
   345     
       
   346     def save_as(self, path):
       
   347         fp = open(path, 'wb')
       
   348         pos = self.file.tell()
       
   349         try:
       
   350             self.file.seek(0)
       
   351             size = copy_file(self.file, fp)
       
   352         finally:
       
   353             self.file.seek(pos)
       
   354         return size
       
   355 
       
   356 ##############################################################################
       
   357 #################################### WSGI ####################################
       
   358 ##############################################################################
       
   359 
       
   360 def parse_form_data(environ, charset='utf8', strict=False, **kw):
       
   361     ''' Parse form data from an environ dict and return a (forms, files) tuple.
       
   362         Both tuple values are dictionaries with the form-field name as a key
       
   363         (unicode) and lists as values (multiple values per key are possible).
       
   364         The forms-dictionary contains form-field values as unicode strings.
       
   365         The files-dictionary contains :class:`MultipartPart` instances, either
       
   366         because the form-field was a file-upload or the value is to big to fit
       
   367         into memory limits.
       
   368         
       
   369         :param environ: An WSGI environment dict.
       
   370         :param charset: The charset to use if unsure. (default: utf8)
       
   371         :param strict: If True, raise :exc:`MultipartError` on any parsing
       
   372                        errors. These are silently ignored by default.
       
   373     '''
       
   374         
       
   375     forms, files = MultiDict(), MultiDict()
       
   376     try:
       
   377         if environ.get('REQUEST_METHOD','GET').upper() not in ('POST', 'PUT'):
       
   378             raise MultipartError("Request method other than POST or PUT.")
       
   379         content_length = int(environ.get('CONTENT_LENGTH', '-1'))
       
   380         content_type = environ.get('CONTENT_TYPE', '')
       
   381         if not content_type:
       
   382             raise MultipartError("Missing Content-Type header.")
       
   383         content_type, options = parse_options_header(content_type)
       
   384         stream = environ.get('wsgi.input') or BytesIO()
       
   385         kw['charset'] = charset = options.get('charset', charset)
       
   386         if content_type == 'multipart/form-data':
       
   387             boundary = options.get('boundary','')
       
   388             if not boundary:
       
   389                 raise MultipartError("No boundary for multipart/form-data.")
       
   390             for part in MultipartParser(stream, boundary, content_length, **kw):
       
   391                 if part.filename or not part.is_buffered():
       
   392                     files[part.name] = part
       
   393                 else: # TODO: Big form-fields are in the files dict. really?
       
   394                     forms[part.name] = part.value
       
   395         elif content_type in ('application/x-www-form-urlencoded',
       
   396                               'application/x-url-encoded'):
       
   397             mem_limit = kw.get('mem_limit', 2**20)
       
   398             if content_length > mem_limit:
       
   399                 raise MultipartError("Request too big. Increase MAXMEM.")
       
   400             data = stream.read(mem_limit)
       
   401             if stream.read(1): # These is more that does not fit mem_limit
       
   402                 raise MultipartError("Request too big. Increase MAXMEM.")
       
   403             if PY3:
       
   404                 data = data.decode('ascii')
       
   405             data = parse_qs(data, keep_blank_values=True)
       
   406             for key, values in data.items():
       
   407                 for value in values:
       
   408                     if PY3:
       
   409                         forms[key] = value
       
   410                     else:
       
   411                         forms[key.decode(charset)] = value.decode(charset)
       
   412         else:
       
   413             raise MultipartError("Unsupported content type.")
       
   414     except MultipartError:
       
   415         if strict: raise
       
   416     return forms, files