# Reconstructed from Mercurial changeset b97547f5f1fa, which adds
# goa/rqlinterpreter.py (Wed Nov 05 15:52:50 2008 +0100).
"""provide a minimal RQL support for google appengine dbmodel

:organization: Logilab
:copyright: 2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
__docformat__ = "restructuredtext en"

import sys
from datetime import datetime

from mx.DateTime import DateTimeType, DateTimeDeltaType

from rql import RQLHelper, nodes
from logilab.common.compat import any

from cubicweb import Binary
from cubicweb.rset import ResultSet
from cubicweb.goa import mx2datetime, datetime2mx
from cubicweb.server import SQL_CONNECT_HOOKS

from google.appengine.api.datastore import Key, Get, Query, Entity
from google.appengine.api.datastore_types import Text, Blob
from google.appengine.api.datastore_errors import EntityNotFoundError, BadKeyError


def etype_from_key(key):
    """return the kind stored in the given datastore key"""
    return Key(key).kind()


def poss_var_types(myvar, ovar, kind, solutions):
    """return the set of types `myvar` takes in the solutions where `ovar`
    is of type `kind`
    """
    return frozenset(etypes[myvar] for etypes in solutions
                     if etypes[ovar] == kind)


def expand_result(results, result, myvar, values, dsget=None):
    """bind `myvar` in `result` to each of the given values

    `result` itself (a dictionary which must be in `results`) receives the
    first value; a copy of it is appended to `results` for each remaining
    value. When there is no value at all, `result` is removed from `results`.

    :param dsget: optional callable applied to each value before binding it;
                  values are used unchanged when it is None
    """
    if dsget is None:
        values = list(values)
    else:
        values = [dsget(value) for value in values]
    if values:
        result[myvar] = values.pop(0)
        for value in values:
            newresult = result.copy()
            newresult[myvar] = value
            results.append(newresult)
    else:
        results.remove(result)


def _resolve(restrictions, solutions, fixed):
    """run one datastore query per possible type of the searched variable,
    each query being filled by every given restriction, and return the list
    of fetched objects

    When the searched variable is already in `fixed`, only objects equal to
    the fixed value are kept.

    :raise EidMismatch: if the searched variable is fixed and no fetched
                        object is equal to its value
    """
    varname = restrictions[0].searched_var
    isfixed = varname in fixed
    # looked up once here so that the name is bound even when `solutions`
    # gives no type to iterate on
    value = fixed.get(varname)
    objs = []
    for etype in frozenset(etypes[varname] for etypes in solutions):
        query = Query(etype)
        for restriction in restrictions:
            restriction.fill_query(fixed, query)
        pobjs = query.Run()
        if isfixed:
            objs += (x for x in pobjs if x == value)
        else:
            objs += pobjs
    if isfixed and not objs:
        raise EidMismatch(varname, value)
    return objs


def _resolve_not(restrictions, solutions, fixed):
    """resolve a single negated relation restriction by running a '<' and
    a '>' query for each possible type of the searched variable

    :raise NotImplementedError: if there is more than one restriction or
                                if the constraint variable isn't fixed
    """
    restr = restrictions[0]
    constrvarname = restr.constraint_var
    if len(restrictions) > 1 or constrvarname not in fixed:
        raise NotImplementedError()
    varname = restr.searched_var
    objs = []
    for etype in frozenset(etypes[varname] for etypes in solutions):
        for operator in ('<', '>'):
            query = Query(etype)
            restr.fill_query(fixed, query, operator)
            objs += query.Run()
    return objs


def _print_results(rlist):
    """debugging helper: return a string for a list of result dictionaries"""
    return '[%s]' % ', '.join(_print_result(r) for r in rlist)


def _print_result(rdict):
    """debugging helper: return a string for a result dictionary, entities
    being displayed using their key
    """
    string = []
    for k, v in rdict.iteritems():
        if isinstance(v, Entity):
            string.append('%s: %s' % (k, v.key()))
        elif isinstance(v, list):
            string.append('%s: [%s]' % (k, ', '.join(str(i) for i in v)))
        else:
            string.append('%s: %s' % (k, v))
    return '{%s}' % ', '.join(string)


class EidMismatch(Exception):
    """raised by `_resolve` when a fixed variable matches none of the
    objects returned by the datastore
    """
    def __init__(self, varname, value):
        self.varname = varname
        self.value = value


class Restriction(object):
    """base class for objects built from a relation node of the syntax tree

    The `schema` attribute is set on this class by `RQLInterpreter.__init__`.
    """
    supported_operators = ('=',)

    def __init__(self, rel):
        operator = rel.children[1].operator
        if operator not in self.supported_operators:
            raise NotImplementedError('unsupported operator')
        self.rel = rel
        self.operator = operator
        self.rtype = rel.r_type
        self.var = rel.children[0]

    def __repr__(self):
        return '<%s for %s>' % (self.__class__.__name__, self.rel)

    @property
    def rhs(self):
        """first child of the relation's right hand side node"""
        return self.rel.children[1].children[0]


class MultipleRestriction(object):
    """group several restrictions on the same variable so that they are
    resolved by the same datastore queries
    """
    def __init__(self, restrictions):
        self.restrictions = restrictions

    def resolve(self, solutions, fixed):
        return _resolve(self.restrictions, solutions, fixed)


class VariableSelection(Restriction):
    """relation used to complete and/or filter already computed results

    `prefix` is 's' or 'o' and tells which side of the relation is the
    searched variable (see `searched_var` / `constraint_var`).
    """
    def __init__(self, rel, dsget, prefix='s'):
        Restriction.__init__(self, rel)
        self._dsget = dsget
        self._not = self.rel.neged(strict=True)
        self._prefix = prefix + '_'

    def __repr__(self):
        return '<%s%s for %s>' % (self._prefix[0], self.__class__.__name__,
                                  self.rel)

    @property
    def searched_var(self):
        if self._prefix == 's_':
            return self.var.name
        return self.rhs.name

    @property
    def constraint_var(self):
        if self._prefix == 's_':
            return self.rhs.name
        return self.var.name

    def _possible_values(self, myvar, ovar, entity, solutions, dsprefix):
        """return the values stored on `entity` for this relation (looked up
        under `dsprefix` + relation type) whose kind is a possible type for
        `myvar`, or the entity's own key for the 'identity' relation
        """
        if self.rtype == 'identity':
            return (entity.key(),)
        value = entity.get(dsprefix + self.rtype)
        if value is None:
            return []
        if not isinstance(value, list):
            value = [value]
        vartypes = poss_var_types(myvar, ovar, entity.kind(), solutions)
        return (v for v in value if v.kind() in vartypes)

    def complete_and_filter(self, solutions, results):
        """update the `results` list in place according to this relation:
        bind the missing variable in each result, or filter out results
        which don't satisfy the relation when both variables are bound
        """
        if not results:
            # nothing to complete nor to filter (and results[0] below would
            # fail)
            return
        myvar = self.rhs.name
        ovar = self.var.name
        rtype = self.rtype
        if self.schema.rschema(rtype).is_final():
            # negation should be detected by rql.stcheck:
            # "Any C WHERE NOT X attr C" doesn't make sense
            for result in results:
                result[myvar] = result[ovar].get('s_' + rtype)
        elif self.var.name in results[0]:
            if self.rhs.name in results[0]:
                self.filter(solutions, results)
            else:
                if self._not:
                    raise NotImplementedError()
                # iterate on a copy since expand_result modifies `results`
                for result in results[:]:
                    values = self._possible_values(myvar, ovar, result[ovar],
                                                   solutions, 's_')
                    expand_result(results, result, myvar, values, self._dsget)
        else:
            assert self.rhs.name in results[0]
            self.object_complete_and_filter(solutions, results)

    def filter(self, solutions, results):
        """remove from `results` the results whose bound values don't satisfy
        the relation (or do satisfy it when the relation is negated)
        """
        myvar = self.rhs.name
        ovar = self.var.name
        # cache of possible values per subject entity key
        newsols = {}
        for result in results[:]:
            entity = result[ovar]
            key = entity.key()
            if key not in newsols:
                values = self._possible_values(myvar, ovar, entity, solutions,
                                               's_')
                newsols[key] = frozenset(values)
            if self._not:
                if result[myvar].key() in newsols[key]:
                    results.remove(result)
            elif result[myvar].key() not in newsols[key]:
                results.remove(result)

    def object_complete_and_filter(self, solutions, results):
        """bind the relation's subject variable in each result, starting from
        the already bound object variable
        """
        if self._not:
            raise NotImplementedError()
        myvar = self.var.name
        ovar = self.rhs.name
        # iterate on a copy since expand_result modifies `results`
        for result in results[:]:
            values = self._possible_values(myvar, ovar, result[ovar],
                                           solutions, 'o_')
            expand_result(results, result, myvar, values, self._dsget)


class EidRestriction(Restriction):
    """restriction on the eid of a variable"""
    def __init__(self, rel, dsget):
        Restriction.__init__(self, rel)
        self._dsget = dsget

    def resolve(self, kwargs):
        """evaluate the eid using query arguments and return what `dsget`
        returns for it
        """
        value = self.rel.children[1].children[0].eval(kwargs)
        return self._dsget(value)


class RelationRestriction(VariableSelection):
    """relation which may be resolved using a datastore query since its
    constraint variable has a known value
    """

    def _get_value(self, fixed):
        return fixed[self.constraint_var].key()

    def fill_query(self, fixed, query, operator=None):
        """add this restriction's filter to the given datastore query, using
        `operator` instead of the restriction's own operator when given
        """
        restr = '%s%s %s' % (self._prefix, self.rtype,
                             operator or self.operator)
        query[restr] = self._get_value(fixed)

    def resolve(self, solutions, fixed):
        if self.rtype == 'identity':
            if self._not:
                raise NotImplementedError()
            return [fixed[self.constraint_var]]
        if self._not:
            return _resolve_not([self], solutions, fixed)
        return _resolve([self], solutions, fixed)


class NotRelationRestriction(RelationRestriction):
    """negated relation resolved by querying with None as filter value"""

    def _get_value(self, fixed):
        return None

    def resolve(self, solutions, fixed):
        if self.rtype == 'identity':
            raise NotImplementedError()
        return _resolve([self], solutions, fixed)


class AttributeRestriction(RelationRestriction):
    """restriction of an attribute to a constant value"""
    supported_operators = ('=', '>', '>=', '<', '<=', 'ILIKE')

    def __init__(self, rel, kwargs):
        RelationRestriction.__init__(self, rel, None)
        value = self.rhs.eval(kwargs)
        if isinstance(value, (DateTimeType, DateTimeDeltaType)):
            value = mx2datetime(value, 'Datetime')
        self.value = value
        if self.operator == 'ILIKE':
            if value.startswith('%'):
                raise NotImplementedError('LIKE is only supported for prefix search')
            if not value.endswith('%'):
                raise NotImplementedError('LIKE is only supported for prefix search')
            # NOTE(review): a single '>' filter on the prefix also matches
            # values sorting after it without starting with it, and excludes
            # the prefix itself -- confirm this approximation is intended
            self.operator = '>'
            self.value = value[:-1]

    def complete_and_filter(self, solutions, results):
        """remove from `results` the results whose attribute is equal to the
        restriction's value (only used for negated restrictions)
        """
        # check lhs var first in case this is a restriction
        assert self._not
        myvar, rtype, value = self.var.name, self.rtype, self.value
        for result in results[:]:
            if result[myvar].get('s_' + rtype) == value:
                results.remove(result)

    def _get_value(self, fixed):
        return self.value


class DateAttributeRestriction(AttributeRestriction):
    """just a thin layer on top of `AttributeRestriction` that
    tries to convert date strings such as in :
    Any X WHERE X creation_date >= '2008-03-04'
    """
    def __init__(self, rel, kwargs):
        super(DateAttributeRestriction, self).__init__(rel, kwargs)
        if isinstance(self.value, basestring):
            # a string which doesn't follow the format makes strptime raise
            # ValueError, which is not caught here
            self.value = datetime.strptime(self.value, '%Y-%m-%d')


class AttributeInRestriction(AttributeRestriction):
    """restriction of an attribute to a list of constant values"""
    def __init__(self, rel, kwargs):
        RelationRestriction.__init__(self, rel, None)
        values = []
        for c in self.rel.children[1].iget_nodes(nodes.Constant):
            value = c.eval(kwargs)
            if isinstance(value, (DateTimeType, DateTimeDeltaType)):
                value = mx2datetime(value, 'Datetime')
            values.append(value)
        self.value = values
        # override the operator set by Restriction.__init__. This has to be
        # a plain attribute: a read-only `operator` property on this class
        # would make the assignment done by the base initializer fail
        self.operator = 'in'


class TypeRestriction(AttributeRestriction):
    """pseudo restriction used for variables which have no other restriction:
    every entity of the variable's possible types is fetched

    Base classes initializers are not called, `var` is the only attribute set.
    """
    def __init__(self, var):
        self.var = var

    def __repr__(self):
        return '<%s for %s>' % (self.__class__.__name__, self.var)

    def resolve(self, solutions, fixed):
        objs = []
        for etype in frozenset(etypes[self.var.name] for etypes in solutions):
            objs += Query(etype).Run()
        return objs


def append_result(res, descr, i, j, value, etype):
    """store `value` and its type `etype` as column `j` of row `i` in the
    result list `res` and in the description list `descr`

    Rows are created when the first column (j == 0) is stored. Date and time
    values, datastore Text and Blob values are converted first.
    """
    if value is not None:
        if etype in ('Date', 'Datetime', 'Time'):
            value = datetime2mx(value, etype)
        elif isinstance(value, Text):
            value = unicode(value)
        elif isinstance(value, Blob):
            value = Binary(str(value))
    if j == 0:
        res.append([value])
        descr.append([etype])
    else:
        res[i].append(value)
        descr[i].append(etype)


class ValueResolver(object):
    """syntax tree visitor evaluating a term against a result dictionary"""
    def __init__(self, functions, args, term):
        self.functions = functions
        self.args = args
        self.term = term
        self._solution = self.term.stmt.solutions[0]

    def compute(self, result):
        """return (entity type, value) to which self.term is evaluated according
        to the given result dictionary and to query arguments (self.args)
        """
        return self.term.accept(self, result)

    def visit_function(self, node, result):
        args = tuple(n.accept(self, result)[1] for n in node.children)
        value = self.functions[node.name](*args)
        return node.get_type(self._solution, self.args), value

    def visit_variableref(self, node, result):
        value = result[node.name]
        try:
            etype = value.kind()
            value = str(value.key())
        except AttributeError:
            # value has no kind() / key(): take the type from the solution
            etype = self._solution[node.name]
        return etype, value

    def visit_constant(self, node, result):
        return node.get_type(kwargs=self.args), node.eval(self.args)


class RQLInterpreter(object):
    """algorithm:
    1. visit the restriction clauses and collect restriction for each subject
       of a relation. Different restriction types are:
       * EidRestriction
       * AttributeRestriction
       * RelationRestriction
       * VariableSelection (not really a restriction)
       -> dictionary {varname: [restriction...], ...}
    2. resolve eid restrictions
    3. for each select in union:
         for each solution in select'solutions:
           1. resolve variables which have attribute restriction
           2. resolve relation restriction
           3. resolve selection and add to global results
    """
    def __init__(self, schema):
        self.schema = schema
        # the schema is shared with every restriction through the class
        Restriction.schema = schema
        self.rqlhelper = RQLHelper(schema, {'eid': etype_from_key})
        self._stored_proc = {'LOWER': lambda x: x.lower(),
                             'UPPER': lambda x: x.upper()}
        for cb in SQL_CONNECT_HOOKS.get('sqlite', []):
            cb(self)

    # emulate sqlite connection interface so we can reuse stored procedures
    def create_function(self, name, nbargs, func):
        self._stored_proc[name] = func

    def create_aggregate(self, name, nbargs, func):
        self._stored_proc[name] = func

    def execute(self, operation, parameters=None, eid_key=None, build_descr=True):
        """parse and interpret the `operation` RQL string using the given
        `parameters` and return the resulting ResultSet

        The result set is empty if computing solutions raises BadKeyError.
        `eid_key` and `build_descr` are accepted but not used here.
        """
        rqlst = self.rqlhelper.parse(operation, annotate=True)
        try:
            self.rqlhelper.compute_solutions(rqlst, kwargs=parameters)
        except BadKeyError:
            results, description = [], []
        else:
            results, description = self.interpret(rqlst, parameters)
        return ResultSet(results, operation, parameters, description, rqlst=rqlst)

    def interpret(self, node, kwargs, dsget=None):
        """interpret the given syntax tree and return a 2-uple (results,
        description)

        :param dsget: callable used to fetch objects from the datastore,
                      defaulting to the datastore `Get` function
        :raise NotImplementedError: if the query isn't supported (logged
                                    before being propagated)
        """
        if dsget is None:
            self._dsget = Get
        else:
            self._dsget = dsget
        try:
            return node.accept(self, kwargs)
        except NotImplementedError:
            self.critical('support for query not implemented: %s', node)
            raise

    def visit_union(self, node, kwargs):
        """concatenate results and descriptions of each select of the union"""
        results, description = [], []
        extra = {'kwargs': kwargs}
        for child in node.children:
            pres, pdescr = self.visit_select(child, extra)
            results += pres
            description += pdescr
        return results, description

    def visit_select(self, node, extra):
        """return a 2-uple (results, description) for the given select node"""
        constraints = {}
        if node.where is not None:
            node.where.accept(self, constraints, extra)
        fixed, toresolve, postresolve, postfilters = {}, {}, {}, []
        # extract NOT filters (iterate on a snapshot since entries may be
        # deleted on the way)
        for vname, restrictions in list(constraints.items()):
            for restr in restrictions[:]:
                if isinstance(restr, AttributeRestriction) and restr._not:
                    postfilters.append(restr)
                    restrictions.remove(restr)
                    # only drop lists emptied by the removal above: lists
                    # initialised empty by relation visitors must be kept so
                    # that no TypeRestriction is added for their variable
                    if not restrictions:
                        del constraints[vname]
        # add TypeRestriction for variable which have no restrictions at all
        for varname, var in node.defined_vars.iteritems():
            if varname not in constraints:
                constraints[varname] = [TypeRestriction(var)]
        # compute eid restrictions
        kwargs = extra['kwargs']
        for varname, restrictions in constraints.iteritems():
            for restr in restrictions[:]:
                if isinstance(restr, EidRestriction):
                    assert varname not in fixed
                    try:
                        value = restr.resolve(kwargs)
                        fixed[varname] = value
                    except EntityNotFoundError:
                        return [], []
                    restrictions.remove(restr)
        # combine remaining restrictions
        for varname, restrictions in constraints.iteritems():
            for restr in restrictions:
                if isinstance(restr, AttributeRestriction):
                    toresolve.setdefault(varname, []).append(restr)
                elif isinstance(restr, NotRelationRestriction) or (
                    isinstance(restr, RelationRestriction) and
                    restr.searched_var not in fixed and
                    restr.constraint_var in fixed):
                    toresolve.setdefault(varname, []).append(restr)
                else:
                    postresolve.setdefault(varname, []).append(restr)
            try:
                if len(toresolve[varname]) > 1:
                    toresolve[varname] = MultipleRestriction(toresolve[varname])
                else:
                    toresolve[varname] = toresolve[varname][0]
            except KeyError:
                # nothing to resolve for this variable
                pass
        # resolve additional restrictions
        if fixed:
            partres = [fixed.copy()]
        else:
            partres = []
        for varname, restr in toresolve.iteritems():
            varpartres = partres[:]
            try:
                values = tuple(restr.resolve(node.solutions, fixed))
            except EidMismatch:
                # the exception is fetched using sys.exc_info() since this
                # is accepted by every python version
                ex = sys.exc_info()[1]
                varname = ex.varname
                value = ex.value
                partres = [res for res in partres if res[varname] != value]
                if partres:
                    continue
                # some join failed, no possible results
                return [], []
            if not values:
                # some join failed, no possible results
                return [], []
            if not varpartres:
                # init results
                for value in values:
                    partres.append({varname: value})
            elif varname not in partres[0]:
                # cartesian product
                for res in partres:
                    res[varname] = values[0]
                for res in partres[:]:
                    for value in values[1:]:
                        res = res.copy()
                        res[varname] = value
                        partres.append(res)
            else:
                # union
                for res in varpartres:
                    for value in values:
                        res = res.copy()
                        res[varname] = value
                        partres.append(res)
        # Note: don't check for empty partres since constant selection may still
        # produce result at this point
        # sort to get RelationRestriction before AttributeSelection
        restrictions = sorted((restr for restrictions in postresolve.itervalues()
                               for restr in restrictions),
                              key=lambda x: not isinstance(x, RelationRestriction))
        # compute stuff not doable in the previous step using datastore queries
        for restr in restrictions + postfilters:
            restr.complete_and_filter(node.solutions, partres)
            if not partres:
                # some join failed, no possible results
                return [], []
        if extra.pop('has_exists', False):
            # remove potential duplicates introduced by exists
            toremovevars = [v.name for v in node.defined_vars.itervalues()
                            if v.scope is not node]
            if toremovevars:
                newpartres = []
                for result in partres:
                    for var in toremovevars:
                        del result[var]
                    if result not in newpartres:
                        newpartres.append(result)
                if not newpartres:
                    # some join failed, no possible results
                    return [], []
                partres = newpartres
        if node.orderby:
            # sort on the last term first so that the first one wins
            for sortterm in reversed(node.orderby):
                resolver = ValueResolver(self._stored_proc, kwargs, sortterm.term)
                partres.sort(reverse=not sortterm.asc,
                             key=lambda x: resolver.compute(x)[1])
        if partres:
            if node.offset:
                partres = partres[node.offset:]
            if node.limit:
                partres = partres[:node.limit]
            if not partres:
                return [], []
        # compute results
        res, descr = [], []
        for j, term in enumerate(node.selection):
            resolver = ValueResolver(self._stored_proc, kwargs, term)
            if not partres:
                etype, value = resolver.compute({})
                # only constant selected
                if not res:
                    res.append([])
                    descr.append([])
                res[0].append(value)
                descr[0].append(etype)
            else:
                for i, sol in enumerate(partres):
                    etype, value = resolver.compute(sol)
                    append_result(res, descr, i, j, value, etype)
        return res, descr

    def visit_and(self, node, constraints, extra):
        for child in node.children:
            child.accept(self, constraints, extra)

    def visit_exists(self, node, constraints, extra):
        # flag read by visit_select to remove duplicates
        extra['has_exists'] = True
        self.visit_and(node, constraints, extra)

    def visit_not(self, node, constraints, extra):
        for child in node.children:
            child.accept(self, constraints, extra)
        try:
            # the node is registered by visit_relation when it has been
            # taken in charge
            extra.pop(node)
        except KeyError:
            raise NotImplementedError()

    def visit_relation(self, node, constraints, extra):
        """collect in `constraints` the restriction for the given relation"""
        if node.is_types_restriction():
            return
        rschema = self.schema.rschema(node.r_type)
        neged = node.neged(strict=True)
        if neged:
            # ok, we *may* process this Not node (not implemented error will be
            # raised later if we can't)
            extra[node.parent] = True
        if rschema.is_final():
            self._visit_final_relation(rschema, node, constraints, extra)
        elif neged:
            self._visit_non_final_neged_relation(rschema, node, constraints)
        else:
            self._visit_non_final_relation(rschema, node, constraints)

    def _visit_non_final_relation(self, rschema, node, constraints, not_=False):
        """add a RelationRestriction if one of the relation's variables is
        used by more than one relation, else a VariableSelection

        `not_` is accepted but not used.
        """
        lhs, rhs = node.get_variable_parts()
        for v1, v2, prefix in ((lhs, rhs, 's'), (rhs, lhs, 'o')):
            nbrels = len(v2.variable.stinfo['relations'])
            if nbrels > 1:
                constraints.setdefault(v1.name, []).append(
                    RelationRestriction(node, self._dsget, prefix))
                # just init an empty list for v2 variable to avoid a
                # TypeRestriction being added for it
                constraints.setdefault(v2.name, [])
                break
        else:
            constraints.setdefault(rhs.name, []).append(
                VariableSelection(node, self._dsget, 's'))

    def _visit_non_final_neged_relation(self, rschema, node, constraints):
        """add a NotRelationRestriction if one of the relation's variables is
        not selected and only used by this relation, else handle the relation
        as a non negated one
        """
        lhs, rhs = node.get_variable_parts()
        for v1, v2, prefix in ((lhs, rhs, 's'), (rhs, lhs, 'o')):
            stinfo = v2.variable.stinfo
            if not stinfo['selected'] and len(stinfo['relations']) == 1:
                constraints.setdefault(v1.name, []).append(
                    NotRelationRestriction(node, self._dsget, prefix))
                constraints.setdefault(v2.name, [])
                break
        else:
            self._visit_non_final_relation(rschema, node, constraints, True)

    def _visit_final_relation(self, rschema, node, constraints, extra):
        """add the restriction matching the right hand side of a final
        relation (eid, variable, constant or IN function)
        """
        varname = node.children[0].name
        if rschema.type == 'eid':
            constraints.setdefault(varname, []).append(
                EidRestriction(node, self._dsget))
        else:
            rhs = node.children[1].children[0]
            if isinstance(rhs, nodes.VariableRef):
                constraints.setdefault(rhs.name, []).append(
                    VariableSelection(node, self._dsget))
            elif isinstance(rhs, nodes.Constant):
                if rschema.objects()[0] in ('Datetime', 'Date'): # XXX
                    constraints.setdefault(varname, []).append(
                        DateAttributeRestriction(node, extra['kwargs']))
                else:
                    constraints.setdefault(varname, []).append(
                        AttributeRestriction(node, extra['kwargs']))
            elif isinstance(rhs, nodes.Function) and rhs.name == 'IN':
                constraints.setdefault(varname, []).append(
                    AttributeInRestriction(node, extra['kwargs']))
            else:
                raise NotImplementedError()

    def _not_implemented(self, *args, **kwargs):
        raise NotImplementedError()

    visit_or = _not_implemented
    # shouldn't occur
    visit_set = _not_implemented
    visit_insert = _not_implemented
    visit_delete = _not_implemented


from logging import getLogger
from cubicweb import set_log_methods
set_log_methods(RQLInterpreter, getLogger('cubicweb.goa.rqlinterpreter'))
set_log_methods(Restriction, getLogger('cubicweb.goa.rqlinterpreter'))