ability to map attributes to something else than usual cw mapping on sql generation stable
authorSylvain Thénault <sylvain.thenault@logilab.fr>
Thu, 09 Jul 2009 16:15:22 +0200
branchstable
changeset 2354 9b4bac626977
parent 2353 b11f1068a0d3
child 2359 d78cf4586707
ability to map attributes to something else than usual cw mapping on sql generation
server/sources/native.py
server/sources/rql2sql.py
server/test/unittest_rql2sql.py
--- a/server/sources/native.py	Thu Jul 09 16:14:22 2009 +0200
+++ b/server/sources/native.py	Thu Jul 09 16:15:22 2009 +0200
@@ -31,6 +31,7 @@
 from cubicweb.server.sources.rql2sql import SQLGenerator
 
 
+ATTR_MAP = {}
 NONSYSTEM_ETYPES = set()
 NONSYSTEM_RELATIONS = set()
 
@@ -90,6 +91,7 @@
 class NativeSQLSource(SQLAdapterMixIn, AbstractSource):
     """adapter for source using the native cubicweb schema (see below)
     """
+    sqlgen_class = SQLGenerator
     # need default value on class since migration doesn't call init method
     has_deleted_entitites_table = True
 
@@ -141,8 +143,8 @@
         AbstractSource.__init__(self, repo, appschema, source_config,
                                 *args, **kwargs)
         # sql generator
-        self._rql_sqlgen = SQLGenerator(appschema, self.dbhelper,
-                                        self.encoding)
+        self._rql_sqlgen = self.sqlgen_class(appschema, self.dbhelper,
+                                             self.encoding, ATTR_MAP.copy())
         # full text index helper
         self.indexer = get_indexer(self.dbdriver, self.encoding)
         # advanced functionality helper
@@ -209,6 +211,9 @@
         pool.pool_reset()
         self.repo._free_pool(pool)
 
+    def map_attribute(self, etype, attr, cb):
+        self._rql_sqlgen.attr_map['%s.%s' % (etype, attr)] = cb
+        
     # ISource interface #######################################################
 
     def compile_rql(self, rql):
--- a/server/sources/rql2sql.py	Thu Jul 09 16:14:22 2009 +0200
+++ b/server/sources/rql2sql.py	Thu Jul 09 16:15:22 2009 +0200
@@ -303,7 +303,7 @@
     protected by a lock
     """
 
-    def __init__(self, schema, dbms_helper, dbencoding='UTF-8'):
+    def __init__(self, schema, dbms_helper, dbencoding='UTF-8', attrmap=None):
         self.schema = schema
         self.dbms_helper = dbms_helper
         self.dbencoding = dbencoding
@@ -313,6 +313,9 @@
         if not self.dbms_helper.union_parentheses_support:
             self.union_sql = self.noparen_union_sql
         self._lock = threading.Lock()
+        if attrmap is None:
+            attrmap = {}
+        self.attr_map = attrmap
 
     def generate(self, union, args=None, varmap=None):
         """return SQL queries and a variable dictionnary from a RQL syntax tree
@@ -853,27 +856,30 @@
                                        relation.r_type)
         return '%s%s' % (lhssql, relation.children[1].accept(self, contextrels))
 
-    def _visit_attribute_relation(self, relation):
+    def _visit_attribute_relation(self, rel):
         """generate SQL for an attribute relation"""
-        lhs, rhs = relation.get_parts()
+        lhs, rhs = rel.get_parts()
         rhssql = rhs.accept(self)
         table = self._var_table(lhs.variable)
         if table is None:
-            assert relation.r_type == 'eid'
+            assert rel.r_type == 'eid'
             lhssql = lhs.accept(self)
         else:
             try:
-                lhssql = self._varmap['%s.%s' % (lhs.name, relation.r_type)]
+                lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)]
             except KeyError:
-                if relation.r_type == 'eid':
+                mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
+                if mapkey in self.attr_map:
+                    lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+                elif rel.r_type == 'eid':
                     lhssql = lhs.variable._q_sql
                 else:
-                    lhssql = '%s.%s%s' % (table, SQL_PREFIX, relation.r_type)
+                    lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type)
         try:
-            if relation._q_needcast == 'TODAY':
+            if rel._q_needcast == 'TODAY':
                 sql = 'DATE(%s)%s' % (lhssql, rhssql)
             # XXX which cast function should be used
-            #elif relation._q_needcast == 'NOW':
+            #elif rel._q_needcast == 'NOW':
             #    sql = 'TIMESTAMP(%s)%s' % (lhssql, rhssql)
             else:
                 sql = '%s%s' % (lhssql, rhssql)
@@ -884,15 +890,15 @@
         else:
             return sql
 
-    def _visit_has_text_relation(self, relation):
+    def _visit_has_text_relation(self, rel):
         """generate SQL for a has_text relation"""
-        lhs, rhs = relation.get_parts()
+        lhs, rhs = rel.get_parts()
         const = rhs.children[0]
-        alias = self._fti_table(relation)
+        alias = self._fti_table(rel)
         jointo = lhs.accept(self)
         restriction = ''
         lhsvar = lhs.variable
-        me_is_principal = lhsvar.stinfo.get('principal') is relation
+        me_is_principal = lhsvar.stinfo.get('principal') is rel
         if me_is_principal:
             if not lhsvar.stinfo['typerels']:
                 # the variable is using the fti table, no join needed
@@ -908,8 +914,8 @@
                 else:
                     etypes = ','.join("'%s'" % etype for etype in lhsvar.stinfo['possibletypes'])
                     restriction = " AND %s.type IN (%s)" % (ealias, etypes)
-        if isinstance(relation.parent, Not):
-            self._state.done.add(relation.parent)
+        if isinstance(rel.parent, Not):
+            self._state.done.add(rel.parent)
             not_ = True
         else:
             not_ = False
@@ -1117,6 +1123,9 @@
         if isinstance(linkedvar, ColumnAlias):
             raise BadRQLQuery('variable %s should be selected by the subquery'
                               % variable.name)
+        mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+        if mapkey in self.attr_map:
+            return self.attr_map[mapkey](self, linkedvar, rel)
         try:
             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
         except KeyError:
--- a/server/test/unittest_rql2sql.py	Thu Jul 09 16:14:22 2009 +0200
+++ b/server/test/unittest_rql2sql.py	Thu Jul 09 16:15:22 2009 +0200
@@ -1436,6 +1436,22 @@
                     '''SELECT COUNT(1)
 WHERE EXISTS(SELECT 1 FROM owned_by_relation AS rel_owned_by0, cw_Affaire AS P WHERE rel_owned_by0.eid_from=P.cw_eid AND rel_owned_by0.eid_to=1 UNION SELECT 1 FROM owned_by_relation AS rel_owned_by1, cw_Note AS P WHERE rel_owned_by1.eid_from=P.cw_eid AND rel_owned_by1.eid_to=1)''')
 
+    def test_attr_map(self):
+        def generate_ref(gen, linkedvar, rel):
+            linkedvar.accept(gen)
+            return 'VERSION_DATA(%s)' % linkedvar._q_sql
+        self.o.attr_map['Affaire.ref'] = generate_ref
+        try:
+            self._check('Any R WHERE X ref R',
+                        '''SELECT VERSION_DATA(X.cw_eid)
+FROM cw_Affaire AS X''')
+            self._check('Any X WHERE X ref 1',
+                        '''SELECT X.cw_eid
+FROM cw_Affaire AS X
+WHERE VERSION_DATA(X.cw_eid)=1''')
+        finally:
+            self.o.attr_map.clear()
+
 
 class SqliteSQLGeneratorTC(PostgresSQLGeneratorTC):