|
1 """SQL utilities functions and classes. |
|
2 |
|
3 :organization: Logilab |
|
4 :copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved. |
|
5 :contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr |
|
6 """ |
|
7 __docformat__ = "restructuredtext en" |
|
8 |
|
9 from logilab.common.shellutils import ProgressBar |
|
10 from logilab.common.db import get_dbapi_compliant_module |
|
11 from logilab.common.adbh import get_adv_func_helper |
|
12 from logilab.common.sqlgen import SQLGenerator |
|
13 |
|
14 from indexer import get_indexer |
|
15 |
|
16 from cubicweb import Binary, ConfigurationError |
|
17 from cubicweb.common.uilib import remove_html_tags |
|
18 from cubicweb.server import SQL_CONNECT_HOOKS |
|
19 from cubicweb.server.utils import crypt_password, cartesian_product |
|
20 |
|
21 |
|
def sqlexec(sqlstmts, cursor_or_execute, withpb=True, delimiter=';'):
    """Execute the `delimiter`-separated SQL statements found in `sqlstmts`.

    :param sqlstmts: a string of SQL statements separated by `delimiter`.
      The string is naively split on `delimiter`, so the delimiter must not
      appear inside a statement (e.g. in a string literal).
    :param cursor_or_execute: either an object with an ``execute`` method
      (e.g. a DB-API cursor) or a callable taking a SQL string.
    :param withpb: display a progress bar while executing.
    :param delimiter: statement separator, ';' by default.

    Blank statements are skipped. Errors raised while executing a statement
    are not caught and no commit is done: both are left to the caller.
    """
    if hasattr(cursor_or_execute, 'execute'):
        execute = cursor_or_execute.execute
    else:
        execute = cursor_or_execute
    sqlstmts = sqlstmts.split(delimiter)
    if withpb:
        pb = ProgressBar(len(sqlstmts))
    for sql in sqlstmts:
        sql = sql.strip()
        if withpb:
            pb.update()
        if not sql:
            continue
        # some dbapi modules doesn't accept unicode for sql string
        execute(str(sql))
    if withpb:
        # terminate the progress bar line. A bare ``print`` statement did
        # this on Python 2 only: on Python 3 it is a no-op expression.
        import sys
        sys.stdout.write('\n')
|
43 |
|
44 |
|
def sqlgrants(schema, driver, user,
              text_index=True, set_owner=True,
              skip_relations=(), skip_entities=()):
    """Build the SQL granting `user` all access privileges on the system
    schema.

    The result is made of, separated by empty lines:

    * the native source grants,
    * the full-text indexer grants for `driver` (only if `text_index`),
    * the yams schema grants, skipping `skip_entities`.

    Note: `skip_relations` is accepted (sqlschema passes it along) but is
    not forwarded to yams' grant_schema.
    """
    from yams.schema2sql import grant_schema
    from cubicweb.server.sources import native
    sections = [native.grant_schema(user, set_owner), '']
    if text_index:
        sections.extend((get_indexer(driver).sql_grant_user(user), ''))
    sections.append(grant_schema(schema, user, set_owner,
                                 skip_entities=skip_entities))
    return '\n'.join(sections)
|
63 |
|
64 |
|
def sqlschema(schema, driver, text_index=True,
              user=None, set_owner=False,
              skip_relations=('has_text', 'identity'), skip_entities=()):
    """Return the system SQL schema, according to the given parameters.

    The result is made of, separated by empty lines:

    * the native source schema for `driver`,
    * the full-text index initialisation (only if `text_index`),
    * the yams schema tables, skipping `skip_entities` / `skip_relations`,
    * when the database helper supports users and `user` is given, the
      grants for `user` (see sqlgrants).

    :raise ValueError: if `set_owner` is true but no `user` is given.
    """
    # explicit check rather than ``assert``, which is stripped under -O;
    # done first so that no work is done on invalid arguments
    if set_owner and not user:
        raise ValueError('user argument is required when set_owner is true')
    from yams.schema2sql import schema2sql
    from cubicweb.server.sources import native
    output = []
    w = output.append
    w(native.sql_schema(driver))
    w('')
    if text_index:
        indexer = get_indexer(driver)
        w(indexer.sql_init_fti())
        w('')
    dbhelper = get_adv_func_helper(driver)
    w(schema2sql(dbhelper, schema,
                 skip_entities=skip_entities, skip_relations=skip_relations))
    if dbhelper.users_support and user:
        w('')
        w(sqlgrants(schema, driver, user, text_index, set_owner,
                    skip_relations, skip_entities))
    return '\n'.join(output)
|
89 |
|
90 |
|
def sqldropschema(schema, driver, text_index=True,
                  skip_relations=('has_text', 'identity'), skip_entities=()):
    """Build the SQL script dropping the system schema.

    The script contains, separated by empty lines: the native source drop
    statements for `driver`, the full-text index drop statements (only if
    `text_index` is true) and the yams schema drop statements, honouring
    `skip_entities` and `skip_relations`.
    """
    from yams.schema2sql import dropschema2sql
    from cubicweb.server.sources import native
    chunks = [native.sql_drop_schema(driver), '']
    if text_index:
        chunks.extend((get_indexer(driver).sql_drop_fti(), ''))
    chunks.append(dropschema2sql(schema, skip_entities=skip_entities,
                                 skip_relations=skip_relations))
    return '\n'.join(chunks)
|
107 |
|
108 |
|
109 |
|
class SQLAdapterMixIn(object):
    """Mixin for SQL data sources.

    It reads the database connection parameters from a source configuration
    dictionary, opens DB-API connections and converts values between
    CubicWeb and the database.

    ``self.info`` (used by get_connection) is not defined here; it is
    expected to be installed on the class by ``set_log_methods``.
    """

    def __init__(self, source_config):
        """Read connection settings from the `source_config` dictionary.

        Required keys: 'db-driver' (lowercased) and 'db-name'. Optional keys:
        'db-host', 'db-port', 'db-user', 'db-password' and 'db-encoding'
        (default 'UTF-8').

        :raise ConfigurationError: if a required key is missing.
        """
        try:
            self.dbdriver = source_config['db-driver'].lower()
            self.dbname = source_config['db-name']
        except KeyError:
            raise ConfigurationError('missing some expected entries in sources file')
        self.dbhost = source_config.get('db-host')
        port = source_config.get('db-port')
        # a missing / empty (or zero) port is normalised to None
        self.dbport = port and int(port) or None
        self.dbuser = source_config.get('db-user')
        self.dbpasswd = source_config.get('db-password')
        self.encoding = source_config.get('db-encoding', 'UTF-8')
        self.dbapi_module = get_dbapi_compliant_module(self.dbdriver)
        self.binary = self.dbapi_module.Binary
        self.dbhelper = self.dbapi_module.adv_func_helper
        self.sqlgen = SQLGenerator()

    def get_connection(self, user=None, password=None):
        """Open and return a new connection to the database.

        `user` and `password` override the configured credentials when
        given. The connection hooks registered for the driver are run on the
        connection (see init_cnx) before it is returned.
        """
        user = user or self.dbuser
        host = self.dbhost or 'localhost'
        if user:
            self.info('connecting to %s@%s for user %s', self.dbname,
                      host, user)
        else:
            self.info('connecting to %s@%s', self.dbname, host)
        cnx = self.dbapi_module.connect(self.dbhost, self.dbname, user,
                                        password or self.dbpasswd,
                                        port=self.dbport)
        init_cnx(self.dbdriver, cnx)
        return cnx

    def merge_args(self, args, query_args):
        """Return the parameters mapping to give to cursor.execute.

        `args` (may be None) is copied, and the CubicWeb Binary values in the
        copy are converted to the DB-API module's Binary type. `query_args`
        is then merged in. When `args` is None, `query_args` itself (not a
        copy) is returned.
        """
        if args is None:
            return query_args
        merged = dict(args)
        # iterate over a snapshot since values are replaced in the loop
        for key, val in list(merged.items()):
            # convert cubicweb binary into db binary
            if isinstance(val, Binary):
                merged[key] = self.binary(val.getvalue())
        # keys are not expected to collide; if they do, query_args wins
        merged.update(query_args)
        return merged

    def process_result(self, cursor):
        """Return the rows fetched from `cursor` as a list of lists of
        CubicWeb compliant values.

        Non-NULL values go through the DB-API module's ``process_value``,
        given the column description, the source encoding and CubicWeb's
        Binary class. NULL values are kept as None.
        """
        descr = cursor.description
        encoding = self.encoding
        process_value = self.dbapi_module.process_value
        results = []
        # PEP 249 only guarantees that fetchall() returns a *sequence* of
        # rows (possibly a tuple), so build a new list rather than assigning
        # items into it
        for row in cursor.fetchall():
            result = []
            for col, value in enumerate(row):
                if value is not None:
                    value = process_value(value, descr[col], encoding, Binary)
                result.append(value)
            results.append(result)
        return results

    def preprocess_entity(self, entity):
        """Return a dictionary to use as extra argument to cursor.execute
        to insert/update `entity`.

        Values of final relations are converted:

        * Boolean values go through the db helper's ``boolean_value``;
        * Password values are encrypted, unless they already are a Binary
          (i.e. they come from a query result and are already encrypted);
        * other Binary values are converted to the DB-API module's Binary.

        Every item of the entity ends up in the result, keyed by ``str(attr)``.
        """
        attrs = {}
        eschema = entity.e_schema
        for attr, value in entity.items():
            rschema = eschema.subject_relation(attr)
            if rschema.is_final():
                atype = str(eschema.destination(attr))
                if atype == 'Boolean':
                    value = self.dbhelper.boolean_value(value)
                elif atype == 'Password':
                    # if value is a Binary instance, this mean we got it
                    # from a query result and so it is already encrypted
                    if isinstance(value, Binary):
                        value = value.getvalue()
                    else:
                        value = crypt_password(value)
                elif isinstance(value, Binary):
                    value = self.binary(value.getvalue())
            attrs[str(attr)] = value
        return attrs
|
203 |
|
204 |
|
# Attach logging methods to SQLAdapterMixIn, bound to the
# 'cubicweb.sqladapter' logger. Presumably this is what provides the
# self.info() used by SQLAdapterMixIn.get_connection, which the class itself
# does not define.
from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(SQLAdapterMixIn, getLogger('cubicweb.sqladapter'))
|
208 |
|
def init_sqlite_connexion(cnx):
    """Register CubicWeb's SQL helper functions on the SQLite connection
    `cnx`.

    Registered functions:

    * ``GROUP_CONCAT(x)`` aggregate (also registered under its former name
      ``CONCAT_STRINGS``): joins the non-NULL values of the group with ', ';
    * ``LIMIT_SIZE(text, format, maxsize)``: returns `text` unchanged if it
      is at most `maxsize` characters long. Otherwise it strips HTML tags
      (for HTML/XML formats) and, if the text is still too long, truncates it
      to `maxsize` characters followed by '...';
    * ``TEXT_LIMIT_SIZE(text, maxsize)``: LIMIT_SIZE for plain text.

    Both size limiting functions return NULL for a NULL `text`.

    Finally, yams' ``patch_sqlite_decimal`` is applied when the installed
    yams version provides it.
    """
    # XXX should not be publicly exposed

    class concat_strings(object):
        """GROUP_CONCAT aggregate: join non-NULL values with ', '."""
        def __init__(self):
            self.values = []

        def step(self, value):
            if value is not None:
                self.values.append(value)

        def finalize(self):
            return ', '.join(self.values)
    # renamed to GROUP_CONCAT in cubicweb 2.45, keep old name for bw compat for
    # some time
    cnx.create_aggregate("CONCAT_STRINGS", 1, concat_strings)
    cnx.create_aggregate("GROUP_CONCAT", 1, concat_strings)

    def _limit_size(text, maxsize, fmt='text/plain'):
        # SQL NULL in, SQL NULL out (len(None) would make the query fail)
        if text is None:
            return None
        # text that already fits is returned untouched; exactly `maxsize`
        # characters counts as fitting, consistently with the check below
        if len(text) <= maxsize:
            return text
        if fmt in ('text/html', 'text/xhtml', 'text/xml'):
            text = remove_html_tags(text)
        if len(text) > maxsize:
            text = text[:maxsize] + '...'
        return text

    def limit_size3(text, fmt, maxsize):
        return _limit_size(text, maxsize, fmt)
    cnx.create_function("LIMIT_SIZE", 3, limit_size3)

    def limit_size2(text, maxsize):
        return _limit_size(text, maxsize)
    cnx.create_function("TEXT_LIMIT_SIZE", 2, limit_size2)

    import yams.constraints
    if hasattr(yams.constraints, 'patch_sqlite_decimal'):
        yams.constraints.patch_sqlite_decimal()
|
247 |
|
248 |
|
# Register init_sqlite_connexion as a connection hook for the 'sqlite'
# driver: init_cnx runs every hook listed in SQL_CONNECT_HOOKS[driver] on the
# connections opened by SQLAdapterMixIn.get_connection.
sqlite_hooks = SQL_CONNECT_HOOKS.setdefault('sqlite', [])
sqlite_hooks.append(init_sqlite_connexion)
|
251 |
|
def init_cnx(driver, cnx):
    """Call, in list order, each hook registered in SQL_CONNECT_HOOKS for
    `driver`, passing it the connection `cnx`. A driver without registered
    hooks is a no-op."""
    registered = SQL_CONNECT_HOOKS.get(driver, ())
    for run_hook in registered:
        run_hook(cnx)