[dataimport] _create_copyfrom_buffer: put converters into a list
author Alain Leufroy <alain.leufroy@logilab.fr>
Mon, 02 Jun 2014 13:50:15 +0200
changeset 9898 70056633085c
parent 9897 fa44db7da2dc
child 9899 2918ef1e3199
[dataimport] _create_copyfrom_buffer: put converters into a list Cleans up the code to avoid a succession of ifs. Related to #3845572
dataimport.py
--- a/dataimport.py	Fri Jul 18 17:35:25 2014 +0200
+++ b/dataimport.py	Mon Jun 02 13:50:15 2014 +0200
@@ -425,10 +425,67 @@
         cnx.commit()
         cu.close()
 
-def _create_copyfrom_buffer(data, columns, encoding='utf-8', replace_sep=None):
+
+def _copyfrom_buffer_convert_None(value, **opts):
+    """Return the "NULL" marker; ``value`` and ``opts`` are not used."""
+    return "NULL"
+
+def _copyfrom_buffer_convert_number(value, **opts):
+    '''Return ``str(value)`` for a number; ``opts`` are accepted but unused.'''
+    return str(value)
+
+def _copyfrom_buffer_convert_string(value, **opts):
+    '''Convert a string value; return None when it cannot be converted.
+
+    Recognized keywords:
+    :encoding: encoding applied to unicode values (default: utf-8)
+    :replace_sep: replacement for tab, CR and LF found in the value;
+                  when unset, such a value yields None.
+    '''
+    encoding = opts.get('encoding', 'utf-8')
+    replace_sep = opts.get('replace_sep', None)
+    # Remove separators used in string formatting
+    for _char in (u'\t', u'\r', u'\n'):
+        if _char in value:
+            if replace_sep:
+                value = value.replace(_char, replace_sep)
+            else:
+                # No replacement available: signal failure to the caller
+                return None
+    # Escape backslashes once, after the loop; inside the loop each of
+    # the three passes would double the backslashes again.
+    value = value.replace('\\', r'\\')
+    if isinstance(value, unicode):
+        value = value.encode(encoding)
+    return value
+
+def _copyfrom_buffer_convert_datetime(value, **opts):
+    '''Convert date into "YYYY-MM-DD" (plus " HH:MM:SS" for a datetime)'''
+    # Do not use strftime, as it yields issue with date < 1900.
+    # ``value`` is left untouched so that the datetime check below sees
+    # the original object rather than the already formatted string.
+    result = '%04d-%02d-%02d' % (value.year, value.month, value.day)
+    if isinstance(value, datetime):
+        result += ' %02d:%02d:%02d' % (value.hour, value.minute,
+                                       value.second)
+    return result
+
+# (types, converter) list; the first isinstance() match, in order, is used.
+_COPYFROM_BUFFER_CONVERTERS = [
+    (type(None), _copyfrom_buffer_convert_None),
+    ((long, int, float), _copyfrom_buffer_convert_number),
+    (basestring, _copyfrom_buffer_convert_string),
+    ((date, datetime), _copyfrom_buffer_convert_datetime)
+]
+
+def _create_copyfrom_buffer(data, columns, **convert_opts):
     """
     Create a StringIO buffer for 'COPY FROM' command.
-    Deals with Unicode, Int, Float, Date...
+    Deals with Unicode, Int, Float, Date... (see ``_COPYFROM_BUFFER_CONVERTERS``)
+
+    :data: a sequence/dict of tuples
+    :columns: list of columns to consider (default to all columns)
+    :convert_opts: keyword arguments given to converters
     """
     # Create a list rather than directly create a StringIO
     # to correctly write lines separated by '\n' in a single step
@@ -444,41 +501,19 @@
             try:
                 value = row[col]
             except KeyError:
-                warnings.warn(u"Column %s is not accessible in row %s" 
+                warnings.warn(u"Column %s is not accessible in row %s"
                               % (col, row), RuntimeWarning)
-                # XXX 'value' set to None so that the import does not end in 
-                # error. 
-                # Instead, the extra keys are set to NULL from the 
+                # XXX 'value' set to None so that the import does not end in
+                # error.
+                # Instead, the extra keys are set to NULL from the
                 # database point of view.
                 value = None
-            if value is None:
-                value = 'NULL'
-            elif isinstance(value, (long, int, float)):
-                value = str(value)
-            elif isinstance(value, (str, unicode)):
-                # Remove separators used in string formatting
-                for _char in (u'\t', u'\r', u'\n'):
-                    if _char in value:
-                        # If a replace_sep is given, replace
-                        # the separator instead of returning None
-                        # (and thus avoid empty buffer)
-                        if replace_sep:
-                            value = value.replace(_char, replace_sep)
-                        else:
-                            return
-                value = value.replace('\\', r'\\')
-                if value is None:
-                    return
-                if isinstance(value, unicode):
-                    value = value.encode(encoding)
-            elif isinstance(value, (date, datetime)):
-                value = '%04d-%02d-%02d' % (value.year,
-                                            value.month,
-                                            value.day)
-                if isinstance(value, datetime):
-                    value += ' %02d:%02d:%02d' % (value.hour,
-                                                  value.minutes,
-                                                  value.second)
+            for types, converter in _COPYFROM_BUFFER_CONVERTERS:
+                if isinstance(value, types):
+                    value = converter(value, **convert_opts)
+                    if value is None:
+                        return None
+                    break
             else:
                 return None
             # We push the value to the new formatted row