web/request.py
author Adrien Di Mascio <Adrien.DiMascio@logilab.fr>
Wed, 12 Nov 2008 16:07:40 +0100
changeset 28 9b7067bfaa15
parent 0 b97547f5f1fa
child 495 f8b1edfe9621
permissions -rw-r--r--
introduce html_headers.on_load() method as a shortcut for add_post_inline_script('''jQuery(document).ready(...''')

"""abstract class for http request

:organization: Logilab
:copyright: 2001-2008 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
:contact: http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""
__docformat__ = "restructuredtext en"

import Cookie
import sha
import time
import random
import base64
from urlparse import urlsplit
from itertools import count

from rql.utils import rqlvar_maker

from logilab.common.decorators import cached

# XXX move _MARKER here once AppObject.external_resource has been removed
from cubicweb.dbapi import DBAPIRequest
from cubicweb.common.appobject import _MARKER 
from cubicweb.common.mail import header
from cubicweb.common.uilib import remove_html_tags
from cubicweb.common.utils import SizeConstrainedList, HTMLHead
from cubicweb.web import (INTERNAL_FIELD_VALUE, LOGGER, NothingToEdit, RequestError,
                       StatusResponse)


def list_form_param(form, param, pop=False):
    """get param from form parameters and return its value as a list,
    skipping internal markers if any

    * if the parameter isn't defined, return an empty list
    * if the parameter is a single (unicode) value, return a list
      containing that value
    * if the parameter is already a list or tuple, just skip internal
      markers

    if pop is True, the parameter is removed from the form dictionnary
    """
    if pop:
        try:
            value = form.pop(param)
        except KeyError:
            return []
    else:
        value = form.get(param, ())
    if value is None:
        value = ()
    elif not isinstance(value, (list, tuple)):
        value = [value]
    return [v for v in value if v != INTERNAL_FIELD_VALUE]



class CubicWebRequestBase(DBAPIRequest):
    """abstract HTTP request, should be extended according to the HTTP backend"""    
    
    def __init__(self, vreg, https, form=None):
        super(CubicWebRequestBase, self).__init__(vreg)
        self.message = None
        self.authmode = vreg.config['auth-mode']
        self.https = https
        # raw html headers that can be added from any view
        self.html_headers = HTMLHead()
        # form parameters
        self.setup_params(form)
        # dictionnary that may be used to store request data that has to be
        # shared among various components used to publish the request (views,
        # controller, application...)
        self.data = {}
        # search state: 'normal' or 'linksearch' (eg searching for an object
        # to create a relation with another)
        self.search_state = ('normal',) 
        # tabindex generator
        self.tabindexgen = count()
        self.next_tabindex = self.tabindexgen.next
        # page id, set by htmlheader template
        self.pageid = None
        self.varmaker = rqlvar_maker()
        self.datadir_url = self._datadir_url()

    def set_connection(self, cnx, user=None):
        """method called by the session handler when the user is authenticated
        or an anonymous connection is open
        """
        super(CubicWebRequestBase, self).set_connection(cnx, user)
        # get request language:
        vreg = self.vreg
        if self.user:
            try:
                # 1. user specified language
                lang = vreg.typed_value('ui.language',
                                        self.user.properties['ui.language'])
                self.set_language(lang)
                return
            except KeyError, ex:
                pass
        if vreg.config['language-negociation']:
            # 2. http negociated language
            for lang in self.header_accept_language():
                if lang in self.translations:
                    self.set_language(lang)
                    return
        # 3. default language
        self.set_default_language(vreg)
            
    def set_language(self, lang):
        self._ = self.__ = self.translations[lang]
        self.lang = lang
        self.debug('request language: %s', lang)
        
    # input form parameters management ########################################
    
    # common form parameters which should be protected against html values
    # XXX can't add 'eid' for instance since it may be multivalued
    # dont put rql as well, if query contains < and > it will be corrupted!
    no_script_form_params = set(('vid', 
                                 'etype', 
                                 'vtitle', 'title',
                                 '__message',
                                 '__redirectvid', '__redirectrql'))
        
    def setup_params(self, params):
        """WARNING: we're intentionaly leaving INTERNAL_FIELD_VALUE here

        subclasses should overrides to 
        """
        if params is None:
            params = {}
        self.form = params
        encoding = self.encoding
        for k, v in params.items():
            if isinstance(v, (tuple, list)):
                v = [unicode(x, encoding) for x in v]
                if len(v) == 1:
                    v = v[0]
            if k in self.no_script_form_params:
                v = self.no_script_form_param(k, value=v)
            if isinstance(v, str):
                v = unicode(v, encoding)
            if k == '__message':
                self.set_message(v)
                del self.form[k]
            else:
                self.form[k] = v
    
    def no_script_form_param(self, param, default=None, value=None):
        """ensure there is no script in a user form param

        by default return a cleaned string instead of raising a security
        exception

        this method should be called on every user input (form at least) fields
        that are at some point inserted in a generated html page to protect
        against script kiddies
        """
        if value is None:
            value = self.form.get(param, default)
        if not value is default and value:
            # safety belt for strange urls like http://...?vtitle=yo&vtitle=yo
            if isinstance(value, (list, tuple)):
                self.error('no_script_form_param got a list (%s). Who generated the URL ?',
                           repr(value))
                value = value[0]
            return remove_html_tags(value)
        return value
        
    def list_form_param(self, param, form=None, pop=False):
        """get param from form parameters and return its value as a list,
        skipping internal markers if any
        
        * if the parameter isn't defined, return an empty list
        * if the parameter is a single (unicode) value, return a list
          containing that value
        * if the parameter is already a list or tuple, just skip internal
          markers

        if pop is True, the parameter is removed from the form dictionnary
        """
        if form is None:
            form = self.form
        return list_form_param(form, param, pop)            
    

    def reset_headers(self):
        """used by AutomaticWebTest to clear html headers between tests on
        the same resultset
        """
        self.html_headers = HTMLHead()
        return self

    # web state helpers #######################################################
    
    def set_message(self, msg):
        assert isinstance(msg, unicode)
        self.message = msg
    
    def update_search_state(self):
        """update the current search state"""
        searchstate = self.form.get('__mode')
        if not searchstate:
            searchstate = self.get_session_data('search_state', 'normal')
        self.set_search_state(searchstate)

    def set_search_state(self, searchstate):
        """set a new search state"""
        if searchstate is None or searchstate == 'normal':
            self.search_state = (searchstate or 'normal',)
        else:
            self.search_state = ('linksearch', searchstate.split(':'))
            assert len(self.search_state[-1]) == 4
        self.set_session_data('search_state', searchstate)

    def update_breadcrumbs(self):
        """stores the last visisted page in session data"""
        searchstate = self.get_session_data('search_state')
        if searchstate == 'normal':
            breadcrumbs = self.get_session_data('breadcrumbs', None)
            if breadcrumbs is None:
                breadcrumbs = SizeConstrainedList(10)
                self.set_session_data('breadcrumbs', breadcrumbs)
            breadcrumbs.append(self.url())

    def last_visited_page(self):
        breadcrumbs = self.get_session_data('breadcrumbs', None)
        if breadcrumbs:
            return breadcrumbs.pop()
        return self.base_url()

    def register_onetime_callback(self, func, *args):
        cbname = 'cb_%s' % (
            sha.sha('%s%s%s%s' % (time.time(), func.__name__,
                                  random.random(), 
                                  self.user.login)).hexdigest())
        def _cb(req):
            try:
                ret = func(req, *args)
            except TypeError:
                from warnings import warn
                warn('user callback should now take request as argument')
                ret = func(*args)            
            self.unregister_callback(self.pageid, cbname)
            return ret
        self.set_page_data(cbname, _cb)
        return cbname
    
    def unregister_callback(self, pageid, cbname):
        assert pageid is not None
        assert cbname.startswith('cb_')
        self.info('unregistering callback %s for pageid %s', cbname, pageid)
        self.del_page_data(cbname)

    def clear_user_callbacks(self):
        if self.cnx is not None:
            sessdata = self.session_data()
            callbacks = [key for key in sessdata if key.startswith('cb_')]
            for callback in callbacks:
                self.del_session_data(callback)
    
    # web edition helpers #####################################################
    
    @cached # so it's writed only once
    def fckeditor_config(self):
        self.html_headers.define_var('fcklang', self.lang)
        self.html_headers.define_var('fckconfigpath',
                                     self.build_url('data/fckcwconfig.js'))

    def edited_eids(self, withtype=False):
        """return a list of edited eids"""
        yielded = False
        # warning: use .keys since the caller may change `form`
        form = self.form
        try:
            eids = form['eid']
        except KeyError:
            raise NothingToEdit(None, {None: self._('no selected entities')})
        if isinstance(eids, basestring):
            eids = (eids,)
        for peid in eids:
            if withtype:
                typekey = '__type:%s' % peid
                assert typekey in form, 'no entity type specified'
                yield peid, form[typekey]
            else:
                yield peid
            yielded = True
        if not yielded:
            raise NothingToEdit(None, {None: self._('no selected entities')})

    # minparams=3 by default: at least eid, __type, and some params to change
    def extract_entity_params(self, eid, minparams=3):
        """extract form parameters relative to the given eid"""
        params = {}
        eid = str(eid)
        form = self.form
        for param in form:
            try:
                name, peid = param.split(':', 1)
            except ValueError:
                if not param.startswith('__') and param != "eid":
                    self.warning('param %s mis-formatted', param)
                continue
            if peid == eid:
                value = form[param]
                if value == INTERNAL_FIELD_VALUE:
                    value = None
                params[name] = value
        params['eid'] = eid
        if len(params) < minparams:
            print eid, params
            raise RequestError(self._('missing parameters for entity %s') % eid)
        return params
    
    def get_pending_operations(self, entity, relname, role):
        operations = {'insert' : [], 'delete' : []}
        for optype in ('insert', 'delete'):
            data = self.get_session_data('pending_%s' % optype) or ()
            for eidfrom, rel, eidto in data:
                if relname == rel:
                    if role == 'subject' and entity.eid == eidfrom:
                        operations[optype].append(eidto)
                    if role == 'object' and entity.eid == eidto:
                        operations[optype].append(eidfrom)
        return operations
    
    def get_pending_inserts(self, eid=None):
        """shortcut to access req's pending_insert entry

        This is where are stored relations being added while editing
        an entity. This used to be stored in a temporary cookie.
        """
        pending = self.get_session_data('pending_insert') or ()
        return ['%s:%s:%s' % (subj, rel, obj) for subj, rel, obj in pending
                if eid is None or eid in (subj, obj)]

    def get_pending_deletes(self, eid=None):
        """shortcut to access req's pending_delete entry

        This is where are stored relations being removed while editing
        an entity. This used to be stored in a temporary cookie.
        """
        pending = self.get_session_data('pending_delete') or ()
        return ['%s:%s:%s' % (subj, rel, obj) for subj, rel, obj in pending
                if eid is None or eid in (subj, obj)]

    def remove_pending_operations(self):
        """shortcut to clear req's pending_{delete,insert} entries

        This is needed when the edition is completed (whether it's validated
        or cancelled)
        """
        self.del_session_data('pending_insert')
        self.del_session_data('pending_delete')

    def cancel_edition(self, errorurl):
        """remove pending operations and `errorurl`'s specific stored data
        """
        self.del_session_data(errorurl)
        self.remove_pending_operations()
    
    # high level methods for HTTP headers management ##########################

    # must be cached since login/password are popped from the form dictionary
    # and this method may be called multiple times during authentication
    @cached
    def get_authorization(self):
        """Parse and return the Authorization header"""
        if self.authmode == "cookie":
            try:
                user = self.form.pop("__login")
                passwd = self.form.pop("__password", '')
                return user, passwd.encode('UTF8')
            except KeyError:
                self.debug('no login/password in form params')
                return None, None
        else:
            return self.header_authorization()
    
    def get_cookie(self):
        """retrieve request cookies, returns an empty cookie if not found"""
        try:
            return Cookie.SimpleCookie(self.get_header('Cookie'))
        except KeyError:
            return Cookie.SimpleCookie()

    def set_cookie(self, cookie, key, maxage=300):
        """set / update a cookie key

        by default, cookie will be available for the next 5 minutes.
        Give maxage = None to have a "session" cookie expiring when the
        client close its browser
        """
        morsel = cookie[key]
        if maxage is not None:
            morsel['Max-Age'] = maxage
        # make sure cookie is set on the correct path
        morsel['path'] = self.base_url_path()
        self.add_header('Set-Cookie', morsel.OutputString())

    def remove_cookie(self, cookie, key):
        """remove a cookie by expiring it"""
        morsel = cookie[key]
        morsel['Max-Age'] = 0
        # The only way to set up cookie age for IE is to use an old "expired"
        # syntax. IE doesn't support Max-Age there is no library support for
        # managing 
        # ===> Do _NOT_ comment this line :
        morsel['expires'] = 'Thu, 01-Jan-1970 00:00:00 GMT'
        self.add_header('Set-Cookie', morsel.OutputString())

    def set_content_type(self, content_type, filename=None, encoding=None):
        """set output content type for this request. An optional filename
        may be given
        """
        if content_type.startswith('text/'):
            content_type += ';charset=' + (encoding or self.encoding)
        self.set_header('content-type', content_type)
        if filename:
            if isinstance(filename, unicode):
                filename = header(filename).encode()
            self.set_header('content-disposition', 'inline; filename=%s'
                            % filename)

    # high level methods for HTML headers management ##########################

    def add_js(self, jsfiles, localfile=True):
        """specify a list of JS files to include in the HTML headers
        :param jsfiles: a JS filename or a list of JS filenames
        :param localfile: if True, the default data dir prefix is added to the
                          JS filename
        """
        if isinstance(jsfiles, basestring):
            jsfiles = (jsfiles,)
        for jsfile in jsfiles:
            if localfile:
                jsfile = self.datadir_url + jsfile
            self.html_headers.add_js(jsfile)

    def add_css(self, cssfiles, media=u'all', localfile=True, ieonly=False):
        """specify a CSS file to include in the HTML headers
        :param cssfiles: a CSS filename or a list of CSS filenames
        :param media: the CSS's media if necessary
        :param localfile: if True, the default data dir prefix is added to the
                          CSS filename
        """
        if isinstance(cssfiles, basestring):
            cssfiles = (cssfiles,)
        if ieonly:
            if self.ie_browser():
                add_css = self.html_headers.add_ie_css
            else:
                return # no need to do anything on non IE browsers
        else:
            add_css = self.html_headers.add_css
        for cssfile in cssfiles:
            if localfile:
                cssfile = self.datadir_url + cssfile
            add_css(cssfile, media)
    
    # urls/path management ####################################################
    
    def url(self, includeparams=True):
        """return currently accessed url"""
        return self.base_url() + self.relative_path(includeparams)

    def _datadir_url(self):
        """return url of the application's data directory"""
        return self.base_url() + 'data%s/' % self.vreg.config.instance_md5_version()
    
    def selected(self, url):
        """return True if the url is equivalent to currently accessed url"""
        reqpath = self.relative_path().lower()
        baselen = len(self.base_url())
        return (reqpath == url[baselen:].lower())

    def base_url_prepend_host(self, hostname):
        protocol, roothost = urlsplit(self.base_url())[:2]
        if roothost.startswith('www.'):
            roothost = roothost[4:]
        return '%s://%s.%s' % (protocol, hostname, roothost)

    def base_url_path(self):
        """returns the absolute path of the base url"""
        return urlsplit(self.base_url())[2]
        
    @cached
    def from_controller(self):
        """return the id (string) of the controller issuing the request"""
        controller = self.relative_path(False).split('/', 1)[0]
        registered_controllers = (ctrl.id for ctrl in
                                  self.vreg.registry_objects('controllers'))
        if controller in registered_controllers:
            return controller
        return 'view'
    
    def external_resource(self, rid, default=_MARKER):
        """return a path to an external resource, using its identifier

        raise KeyError  if the resource is not defined
        """
        try:
            value = self.vreg.config.ext_resources[rid]
        except KeyError:
            if default is _MARKER:
                raise
            return default
        if value is None:
            return None
        baseurl = self.datadir_url[:-1] # remove trailing /
        if isinstance(value, list):
            return [v.replace('DATADIR', baseurl) for v in value]
        return value.replace('DATADIR', baseurl)
    external_resource = cached(external_resource, keyarg=1)

    def validate_cache(self):
        """raise a `DirectResponse` exception if a cached page along the way
        exists and is still usable.

        calls the client-dependant implementation of `_validate_cache`
        """
        self._validate_cache()
        if self.http_method() == 'HEAD':
            raise StatusResponse(200, '')
        
    # abstract methods to override according to the web front-end #############
        
    def http_method(self):
        """returns 'POST', 'GET', 'HEAD', etc."""
        raise NotImplementedError()

    def _validate_cache(self):
        """raise a `DirectResponse` exception if a cached page along the way
        exists and is still usable
        """
        raise NotImplementedError()
        
    def relative_path(self, includeparams=True):
        """return the normalized path of the request (ie at least relative
        to the application's root, but some other normalization may be needed
        so that the returned path may be used to compare to generated urls

        :param includeparams:
           boolean indicating if GET form parameters should be kept in the path
        """
        raise NotImplementedError()

    def get_header(self, header, default=None):
        """return the value associated with the given input HTTP header,
        raise KeyError if the header is not set
        """
        raise NotImplementedError()

    def set_header(self, header, value):
        """set an output HTTP header"""
        raise NotImplementedError()

    def add_header(self, header, value):
        """add an output HTTP header"""
        raise NotImplementedError()
    
    def remove_header(self, header):
        """remove an output HTTP header"""
        raise NotImplementedError()
        
    def header_authorization(self):
        """returns a couple (auth-type, auth-value)"""
        auth = self.get_header("Authorization", None)
        if auth:
            scheme, rest = auth.split(' ', 1)
            scheme = scheme.lower()
            try:
                assert scheme == "basic"
                user, passwd = base64.decodestring(rest).split(":", 1)
                # XXX HTTP header encoding: use email.Header?
                return user.decode('UTF8'), passwd
            except Exception, ex:
                self.debug('bad authorization %s (%s: %s)',
                           auth, ex.__class__.__name__, ex)
        return None, None

    def header_accept_language(self):
        """returns an ordered list of preferred languages"""
        acceptedlangs = self.get_header('Accept-Language', '')
        langs = []
        for langinfo in acceptedlangs.split(','):
            try:
                lang, score = langinfo.split(';')
                score = float(score[2:]) # remove 'q='
            except ValueError:
                lang = langinfo
                score = 1.0
            lang = lang.split('-')[0]
            langs.append( (score, lang) )
        langs.sort(reverse=True)
        return (lang for (score, lang) in langs)

    def header_if_modified_since(self):
        """If the HTTP header If-modified-since is set, return the equivalent
        mx date time value (GMT), else return None
        """
        raise NotImplementedError()
    
    # page data management ####################################################

    def get_page_data(self, key, default=None):
        """return value associated to `key` in curernt page data"""
        page_data = self.cnx.get_session_data(self.pageid, {})
        return page_data.get(key, default)
        
    def set_page_data(self, key, value):
        """set value associated to `key` in current page data"""
        self.html_headers.add_unload_pagedata()
        page_data = self.cnx.get_session_data(self.pageid, {})
        page_data[key] = value
        return self.cnx.set_session_data(self.pageid, page_data)
        
    def del_page_data(self, key=None):
        """remove value associated to `key` in current page data
        if `key` is None, all page data will be cleared
        """
        if key is None:
            self.cnx.del_session_data(self.pageid)
        else:
            page_data = self.cnx.get_session_data(self.pageid, {})
            page_data.pop(key, None)
            self.cnx.set_session_data(self.pageid, page_data)

    # user-agent detection ####################################################

    @cached
    def useragent(self):
        return self.get_header('User-Agent', None)

    def ie_browser(self):
        useragent = self.useragent()
        return useragent and 'MSIE' in useragent
    
    def xhtml_browser(self):
        useragent = self.useragent()
        if useragent and ('MSIE' in useragent or 'KHTML' in useragent):
            return False
        return True

from cubicweb import set_log_methods
set_log_methods(CubicWebRequestBase, LOGGER)