server/sources/rql2sql.py
branchstable
changeset 2354 9b4bac626977
parent 2199 bd0a0f219751
child 2915 651bbe1526b6
--- a/server/sources/rql2sql.py	Thu Jul 09 16:14:22 2009 +0200
+++ b/server/sources/rql2sql.py	Thu Jul 09 16:15:22 2009 +0200
@@ -303,7 +303,7 @@
     protected by a lock
     """
 
-    def __init__(self, schema, dbms_helper, dbencoding='UTF-8'):
+    def __init__(self, schema, dbms_helper, dbencoding='UTF-8', attrmap=None):
         self.schema = schema
         self.dbms_helper = dbms_helper
         self.dbencoding = dbencoding
@@ -313,6 +313,9 @@
         if not self.dbms_helper.union_parentheses_support:
             self.union_sql = self.noparen_union_sql
         self._lock = threading.Lock()
+        if attrmap is None:
+            attrmap = {}
+        self.attr_map = attrmap
 
     def generate(self, union, args=None, varmap=None):
         """return SQL queries and a variable dictionnary from a RQL syntax tree
@@ -853,27 +856,30 @@
                                        relation.r_type)
         return '%s%s' % (lhssql, relation.children[1].accept(self, contextrels))
 
-    def _visit_attribute_relation(self, relation):
+    def _visit_attribute_relation(self, rel):
         """generate SQL for an attribute relation"""
-        lhs, rhs = relation.get_parts()
+        lhs, rhs = rel.get_parts()
         rhssql = rhs.accept(self)
         table = self._var_table(lhs.variable)
         if table is None:
-            assert relation.r_type == 'eid'
+            assert rel.r_type == 'eid'
             lhssql = lhs.accept(self)
         else:
             try:
-                lhssql = self._varmap['%s.%s' % (lhs.name, relation.r_type)]
+                lhssql = self._varmap['%s.%s' % (lhs.name, rel.r_type)]
             except KeyError:
-                if relation.r_type == 'eid':
+                mapkey = '%s.%s' % (self._state.solution[lhs.name], rel.r_type)
+                if mapkey in self.attr_map:
+                    lhssql = self.attr_map[mapkey](self, lhs.variable, rel)
+                elif rel.r_type == 'eid':
                     lhssql = lhs.variable._q_sql
                 else:
-                    lhssql = '%s.%s%s' % (table, SQL_PREFIX, relation.r_type)
+                    lhssql = '%s.%s%s' % (table, SQL_PREFIX, rel.r_type)
         try:
-            if relation._q_needcast == 'TODAY':
+            if rel._q_needcast == 'TODAY':
                 sql = 'DATE(%s)%s' % (lhssql, rhssql)
             # XXX which cast function should be used
-            #elif relation._q_needcast == 'NOW':
+            #elif rel._q_needcast == 'NOW':
             #    sql = 'TIMESTAMP(%s)%s' % (lhssql, rhssql)
             else:
                 sql = '%s%s' % (lhssql, rhssql)
@@ -884,15 +890,15 @@
         else:
             return sql
 
-    def _visit_has_text_relation(self, relation):
+    def _visit_has_text_relation(self, rel):
         """generate SQL for a has_text relation"""
-        lhs, rhs = relation.get_parts()
+        lhs, rhs = rel.get_parts()
         const = rhs.children[0]
-        alias = self._fti_table(relation)
+        alias = self._fti_table(rel)
         jointo = lhs.accept(self)
         restriction = ''
         lhsvar = lhs.variable
-        me_is_principal = lhsvar.stinfo.get('principal') is relation
+        me_is_principal = lhsvar.stinfo.get('principal') is rel
         if me_is_principal:
             if not lhsvar.stinfo['typerels']:
                 # the variable is using the fti table, no join needed
@@ -908,8 +914,8 @@
                 else:
                     etypes = ','.join("'%s'" % etype for etype in lhsvar.stinfo['possibletypes'])
                     restriction = " AND %s.type IN (%s)" % (ealias, etypes)
-        if isinstance(relation.parent, Not):
-            self._state.done.add(relation.parent)
+        if isinstance(rel.parent, Not):
+            self._state.done.add(rel.parent)
             not_ = True
         else:
             not_ = False
@@ -1117,6 +1123,9 @@
         if isinstance(linkedvar, ColumnAlias):
             raise BadRQLQuery('variable %s should be selected by the subquery'
                               % variable.name)
+        mapkey = '%s.%s' % (self._state.solution[linkedvar.name], rel.r_type)
+        if mapkey in self.attr_map:
+            return self.attr_map[mapkey](self, linkedvar, rel)
         try:
             sql = self._varmap['%s.%s' % (linkedvar.name, rel.r_type)]
         except KeyError: