[dataimport] Properly escape strings sent to COPY FROM (closes #5278743)
authorRémi Cardona <remi.cardona@logilab.fr>
Fri, 07 Nov 2014 15:33:30 +0100
changeset 10349 efbbf1e93a04
parent 10347 52a976c5d27a
child 10350 31327bd26931
[dataimport] Properly escape strings sent to COPY FROM (closes #5278743) See http://www.postgresql.org/docs/9.1/static/sql-copy.html#AEN64296 for escaping codes.
dataimport.py
test/unittest_dataimport.py
--- a/dataimport.py	Tue Jul 01 13:19:35 2014 +0200
+++ b/dataimport.py	Fri Nov 07 15:33:30 2014 +0100
@@ -449,22 +449,12 @@
 
     Recognized keywords:
     :encoding: resulting string encoding (default: utf-8)
-    :replace_sep: character used when input contains characters
-                  that conflict with the column separator.
     '''
     encoding = opts.get('encoding','utf-8')
-    replace_sep = opts.get('replace_sep', None)
-    # Remove separators used in string formatting
-    for _char in (u'\t', u'\r', u'\n'):
-        if _char in value:
-            # If a replace_sep is given, replace
-            # the separator
-            # (and thus avoid empty buffer)
-            if replace_sep is None:
-                raise ValueError('conflicting separator: '
-                                 'you must provide the replace_sep option')
-            value = value.replace(_char, replace_sep)
-        value = value.replace('\\', r'\\')
+    escape_chars = ((u'\\', ur'\\'), (u'\t', u'\\t'), (u'\r', u'\\r'),
+                    (u'\n', u'\\n'))
+    for char, replace in escape_chars:
+        value = value.replace(char, replace)
     if isinstance(value, unicode):
         value = value.encode(encoding)
     return value
--- a/test/unittest_dataimport.py	Tue Jul 01 13:19:35 2014 +0200
+++ b/test/unittest_dataimport.py	Fri Nov 07 15:33:30 2014 +0100
@@ -49,8 +49,9 @@
         # unicode
         self.assertEqual('\xc3\xa9l\xc3\xa9phant', cnvt(u'éléphant'))
         self.assertEqual('\xe9l\xe9phant', cnvt(u'éléphant', encoding='latin1'))
-        self.assertEqual('babar#', cnvt('babar\t', replace_sep='#'))
-        self.assertRaises(ValueError, cnvt, 'babar\t')
+        # escaping
+        self.assertEqual('babar\\tceleste\\n', cnvt('babar\tceleste\n'))
+        self.assertEqual(r'C:\\new\tC:\\test', cnvt('C:\\new\tC:\\test'))
 
     def test_convert_date(self):
         cnvt = dataimport._copyfrom_buffer_convert_date