9 |
9 |
10 import os |
10 import os |
11 import subprocess |
11 import subprocess |
12 from datetime import datetime, date |
12 from datetime import datetime, date |
13 |
13 |
14 import logilab.common as lgc |
14 from logilab import db, common as lgc |
15 from logilab.common import db |
|
16 from logilab.common.shellutils import ProgressBar |
15 from logilab.common.shellutils import ProgressBar |
17 from logilab.common.adbh import get_adv_func_helper |
|
18 from logilab.common.sqlgen import SQLGenerator |
|
19 from logilab.common.date import todate, todatetime |
16 from logilab.common.date import todate, todatetime |
20 |
17 from logilab.db.sqlgen import SQLGenerator |
21 from indexer import get_indexer |
|
22 |
18 |
23 from cubicweb import Binary, ConfigurationError |
19 from cubicweb import Binary, ConfigurationError |
24 from cubicweb.uilib import remove_html_tags |
20 from cubicweb.uilib import remove_html_tags |
25 from cubicweb.schema import PURE_VIRTUAL_RTYPES |
21 from cubicweb.schema import PURE_VIRTUAL_RTYPES |
26 from cubicweb.server import SQL_CONNECT_HOOKS |
22 from cubicweb.server import SQL_CONNECT_HOOKS |
27 from cubicweb.server.utils import crypt_password |
23 from cubicweb.server.utils import crypt_password |
28 |
24 from rql.utils import RQL_FUNCTIONS_REGISTRY |
29 |
25 |
30 lgc.USE_MX_DATETIME = False |
26 lgc.USE_MX_DATETIME = False |
31 SQL_PREFIX = 'cw_' |
27 SQL_PREFIX = 'cw_' |
32 |
28 |
33 def _run_command(cmd): |
29 def _run_command(cmd): |
94 assert user, 'user is argument required when set_owner is true' |
90 assert user, 'user is argument required when set_owner is true' |
95 output = [] |
91 output = [] |
96 w = output.append |
92 w = output.append |
97 w(native.sql_schema(driver)) |
93 w(native.sql_schema(driver)) |
98 w('') |
94 w('') |
|
95 dbhelper = db.get_db_helper(driver) |
99 if text_index: |
96 if text_index: |
100 indexer = get_indexer(driver) |
97 w(dbhelper.sql_init_fti()) |
101 w(indexer.sql_init_fti()) |
|
102 w('') |
98 w('') |
103 dbhelper = get_adv_func_helper(driver) |
|
104 w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX, |
99 w(schema2sql(dbhelper, schema, prefix=SQL_PREFIX, |
105 skip_entities=skip_entities, skip_relations=skip_relations)) |
100 skip_entities=skip_entities, skip_relations=skip_relations)) |
106 if dbhelper.users_support and user: |
101 if dbhelper.users_support and user: |
107 w('') |
102 w('') |
108 w(sqlgrants(schema, driver, user, text_index, set_owner, |
103 w(sqlgrants(schema, driver, user, text_index, set_owner, |
135 """ |
130 """ |
136 |
131 |
137 def __init__(self, source_config): |
132 def __init__(self, source_config): |
138 try: |
133 try: |
139 self.dbdriver = source_config['db-driver'].lower() |
134 self.dbdriver = source_config['db-driver'].lower() |
140 self.dbname = source_config['db-name'] |
135 dbname = source_config['db-name'] |
141 except KeyError: |
136 except KeyError: |
142 raise ConfigurationError('missing some expected entries in sources file') |
137 raise ConfigurationError('missing some expected entries in sources file') |
143 self.dbhost = source_config.get('db-host') |
138 dbhost = source_config.get('db-host') |
144 port = source_config.get('db-port') |
139 port = source_config.get('db-port') |
145 self.dbport = port and int(port) or None |
140 dbport = port and int(port) or None |
146 self.dbuser = source_config.get('db-user') |
141 dbuser = source_config.get('db-user') |
147 self.dbpasswd = source_config.get('db-password') |
142 dbpassword = source_config.get('db-password') |
148 self.encoding = source_config.get('db-encoding', 'UTF-8') |
143 dbencoding = source_config.get('db-encoding', 'UTF-8') |
149 self.dbapi_module = db.get_dbapi_compliant_module(self.dbdriver) |
144 dbextraargs = source_config.get('db-extra-arguments') |
150 self.dbdriver_extra_args = source_config.get('db-extra-arguments') |
145 self.dbhelper = db.get_db_helper(self.dbdriver) |
151 self.binary = self.dbapi_module.Binary |
146 self.dbhelper.record_connection_info(dbname, dbhost, dbport, dbuser, |
152 self.dbhelper = self.dbapi_module.adv_func_helper |
147 dbpassword, dbextraargs, |
|
148 dbencoding) |
153 self.sqlgen = SQLGenerator() |
149 self.sqlgen = SQLGenerator() |
154 |
150 # copy back some commonly accessed attributes |
155 def get_connection(self, user=None, password=None): |
151 dbapi_module = self.dbhelper.dbapi_module |
|
152 self.OperationalError = dbapi_module.OperationalError |
|
153 self.InterfaceError = dbapi_module.InterfaceError |
|
154 self._binary = dbapi_module.Binary |
|
155 self._process_value = dbapi_module.process_value |
|
156 self._dbencoding = dbencoding |
|
157 |
|
158 def get_connection(self): |
156 """open and return a connection to the database""" |
159 """open and return a connection to the database""" |
157 if user or self.dbuser: |
160 return self.dbhelper.get_connection() |
158 self.info('connecting to %s@%s for user %s', self.dbname, |
|
159 self.dbhost or 'localhost', user or self.dbuser) |
|
160 else: |
|
161 self.info('connecting to %s@%s', self.dbname, |
|
162 self.dbhost or 'localhost') |
|
163 extra = {} |
|
164 if self.dbdriver_extra_args: |
|
165 extra = {'extra_args': self.dbdriver_extra_args} |
|
166 cnx = self.dbapi_module.connect(self.dbhost, self.dbname, |
|
167 user or self.dbuser, |
|
168 password or self.dbpasswd, |
|
169 port=self.dbport, |
|
170 **extra) |
|
171 init_cnx(self.dbdriver, cnx) |
|
172 #self.dbapi_module.type_code_test(cnx.cursor()) |
|
173 return cnx |
|
174 |
161 |
175 def backup_to_file(self, backupfile): |
162 def backup_to_file(self, backupfile): |
176 for cmd in self.dbhelper.backup_commands(self.dbname, self.dbhost, |
163 for cmd in self.dbhelper.backup_commands(backupfile, |
177 self.dbuser, backupfile, |
|
178 dbport=self.dbport, |
|
179 keepownership=False): |
164 keepownership=False): |
180 if _run_command(cmd): |
165 if _run_command(cmd): |
181 if not confirm(' [Failed] Continue anyway?', default='n'): |
166 if not confirm(' [Failed] Continue anyway?', default='n'): |
182 raise Exception('Failed command: %s' % cmd) |
167 raise Exception('Failed command: %s' % cmd) |
183 |
168 |
184 def restore_from_file(self, backupfile, confirm, drop=True): |
169 def restore_from_file(self, backupfile, confirm, drop=True): |
185 for cmd in self.dbhelper.restore_commands(self.dbname, self.dbhost, |
170 for cmd in self.dbhelper.restore_commands(backupfile, |
186 self.dbuser, backupfile, |
|
187 self.encoding, |
|
188 dbport=self.dbport, |
|
189 keepownership=False, |
171 keepownership=False, |
190 drop=drop): |
172 drop=drop): |
191 if _run_command(cmd): |
173 if _run_command(cmd): |
192 if not confirm(' [Failed] Continue anyway?', default='n'): |
174 if not confirm(' [Failed] Continue anyway?', default='n'): |
193 raise Exception('Failed command: %s' % cmd) |
175 raise Exception('Failed command: %s' % cmd) |
196 if args is not None: |
178 if args is not None: |
197 newargs = {} |
179 newargs = {} |
198 for key, val in args.iteritems(): |
180 for key, val in args.iteritems(): |
199 # convert cubicweb binary into db binary |
181 # convert cubicweb binary into db binary |
200 if isinstance(val, Binary): |
182 if isinstance(val, Binary): |
201 val = self.binary(val.getvalue()) |
183 val = self._binary(val.getvalue()) |
202 newargs[key] = val |
184 newargs[key] = val |
203 # should not collide |
185 # should not collide |
204 newargs.update(query_args) |
186 newargs.update(query_args) |
205 return newargs |
187 return newargs |
206 return query_args |
188 return query_args |
207 |
189 |
208 def process_result(self, cursor): |
190 def process_result(self, cursor): |
209 """return a list of CubicWeb compliant values from data in the given cursor |
191 """return a list of CubicWeb compliant values from data in the given cursor |
210 """ |
192 """ |
|
193 # begin bind to locals for optimization |
211 descr = cursor.description |
194 descr = cursor.description |
212 encoding = self.encoding |
195 encoding = self._dbencoding |
213 process_value = self.dbapi_module.process_value |
196 process_value = self._process_value |
214 binary = Binary |
197 binary = Binary |
|
198 # /end |
215 results = cursor.fetchall() |
199 results = cursor.fetchall() |
216 for i, line in enumerate(results): |
200 for i, line in enumerate(results): |
217 result = [] |
201 result = [] |
218 for col, value in enumerate(line): |
202 for col, value in enumerate(line): |
219 if value is None: |
203 if value is None: |
240 # from a query result and so it is already encrypted |
224 # from a query result and so it is already encrypted |
241 if isinstance(value, Binary): |
225 if isinstance(value, Binary): |
242 value = value.getvalue() |
226 value = value.getvalue() |
243 else: |
227 else: |
244 value = crypt_password(value) |
228 value = crypt_password(value) |
245 value = self.binary(value) |
229 value = self._binary(value) |
246 # XXX needed for sqlite but I don't think it is for other backends |
230 # XXX needed for sqlite but I don't think it is for other backends |
247 elif atype == 'Datetime' and isinstance(value, date): |
231 elif atype == 'Datetime' and isinstance(value, date): |
248 value = todatetime(value) |
232 value = todatetime(value) |
249 elif atype == 'Date' and isinstance(value, datetime): |
233 elif atype == 'Date' and isinstance(value, datetime): |
250 value = todate(value) |
234 value = todate(value) |
251 elif isinstance(value, Binary): |
235 elif isinstance(value, Binary): |
252 value = self.binary(value.getvalue()) |
236 value = self._binary(value.getvalue()) |
253 attrs[SQL_PREFIX+str(attr)] = value |
237 attrs[SQL_PREFIX+str(attr)] = value |
254 return attrs |
238 return attrs |
255 |
239 |
256 |
240 |
257 from logging import getLogger |
241 from logging import getLogger |
258 from cubicweb import set_log_methods |
242 from cubicweb import set_log_methods |
259 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter')) |
243 set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter')) |
260 |
244 |
261 def init_sqlite_connexion(cnx): |
245 def init_sqlite_connexion(cnx): |
262 # XXX should not be publicly exposed |
246 |
263 #def comma_join(strings): |
247 class group_concat(object): |
264 # return ', '.join(strings) |
|
265 #cnx.create_function("COMMA_JOIN", 1, comma_join) |
|
266 |
|
267 class concat_strings(object): |
|
268 def __init__(self): |
248 def __init__(self): |
269 self.values = [] |
249 self.values = [] |
270 def step(self, value): |
250 def step(self, value): |
271 if value is not None: |
251 if value is not None: |
272 self.values.append(value) |
252 self.values.append(value) |
273 def finalize(self): |
253 def finalize(self): |
274 return ', '.join(self.values) |
254 return ', '.join(self.values) |
275 # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for |
255 cnx.create_aggregate("GROUP_CONCAT", 1, group_concat) |
276 # some time |
|
277 cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings) |
|
278 cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings) |
|
279 |
256 |
280 def _limit_size(text, maxsize, format='text/plain'): |
257 def _limit_size(text, maxsize, format='text/plain'): |
281 if len(text) < maxsize: |
258 if len(text) < maxsize: |
282 return text |
259 return text |
283 if format in ('text/html', 'text/xhtml', 'text/xml'): |
260 if format in ('text/html', 'text/xhtml', 'text/xml'): |
291 cnx.create_function("LIMIT_SIZE", 3, limit_size3) |
268 cnx.create_function("LIMIT_SIZE", 3, limit_size3) |
292 |
269 |
293 def limit_size2(text, maxsize): |
270 def limit_size2(text, maxsize): |
294 return _limit_size(text, maxsize) |
271 return _limit_size(text, maxsize) |
295 cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2) |
272 cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2) |
|
273 |
296 import yams.constraints |
274 import yams.constraints |
297 if hasattr(yams.constraints, 'patch_sqlite_decimal'): |
275 yams.constraints.patch_sqlite_decimal() |
298 yams.constraints.patch_sqlite_decimal() |
|
299 |
276 |
300 def fspath(eid, etype, attr): |
277 def fspath(eid, etype, attr): |
301 try: |
278 try: |
302 cu = cnx.cursor() |
279 cu = cnx.cursor() |
303 cu.execute('SELECT X.cw_%s FROM cw_%s as X ' |
280 cu.execute('SELECT X.cw_%s FROM cw_%s as X ' |