[schema] Define a decorator to handle methods override
authorDenis Laxalde <denis.laxalde@logilab.fr>
Fri, 12 Jan 2018 10:17:10 +0100
changeset 12258 46a8146f9703
parent 12257 39cd3c7eb2e8
child 12259 7c4746309ec5
child 12294 038ff1a7259f
[schema] Define a decorator to handle methods override There is a number of external classes (from yams/rql) methods overridden "by hand" in cubicweb/schema.py. Define a single decorator to factor out the pattern. It handles specifying the method name (to avoid conflict with the local namespace) and passing the original method to the new definition when needed.
cubicweb/schema.py
--- a/cubicweb/schema.py	Wed Jan 24 12:03:21 2018 +0100
+++ b/cubicweb/schema.py	Fri Jan 12 10:17:10 2018 +0100
@@ -19,6 +19,7 @@
 
 from __future__ import print_function
 
+from functools import wraps
 import re
 from os.path import join
 from hashlib import md5
@@ -594,8 +595,28 @@
         return text_type(req._(key))
 
 
+def _override_method(cls, method_name=None, pass_original=False):
+    """Override (or set) a method on `cls`."""
+    def decorator(func):
+        name = method_name or func.__name__
+        orig = None
+        if pass_original:
+            orig = getattr(cls, name)
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if orig is not None:
+                kwargs['_orig'] = orig
+            return func(*args, **kwargs)
+
+        setattr(cls, name, wrapper)
+
+    return decorator
+
+
 # Schema objects definition ###################################################
 
+@_override_method(ERSchema, 'display_name')
 def ERSchema_display_name(self, req, form='', context=None):
     """return a internationalized string for the entity/relation type name in
     a given form
@@ -603,9 +624,7 @@
     return display_name(req, self.type, form, context)
 
 
-ERSchema.display_name = ERSchema_display_name
-
-
+@_override_method(PermissionMixIn)
 @cached
 def get_groups(self, action):
     """return the groups authorized to perform <action> on entities of
@@ -624,9 +643,7 @@
         return ()
 
 
-PermissionMixIn.get_groups = get_groups
-
-
+@_override_method(PermissionMixIn)
 @cached
 def get_rqlexprs(self, action):
     """return the rql expressions representing queries to check the user is allowed
@@ -645,10 +662,8 @@
         return ()
 
 
-PermissionMixIn.get_rqlexprs = get_rqlexprs
-
-
-def set_action_permissions(self, action, permissions):
+@_override_method(PermissionMixIn, pass_original=True)
+def set_action_permissions(self, action, permissions, _orig):
     """set the groups and rql expressions allowing to perform <action> on
     entities of this type
 
@@ -658,15 +673,12 @@
     :type permissions: tuple
     :param permissions: the groups and rql expressions allowing the given action
     """
-    orig_set_action_permissions(self, action, tuple(permissions))
+    _orig(self, action, tuple(permissions))
     clear_cache(self, 'get_rqlexprs')
     clear_cache(self, 'get_groups')
 
 
-orig_set_action_permissions = PermissionMixIn.set_action_permissions
-PermissionMixIn.set_action_permissions = set_action_permissions
-
-
+@_override_method(PermissionMixIn)
 def has_local_role(self, action):
     """return true if the action *may* be granted locally (i.e. either rql
     expressions or the owners group are used in security definition)
@@ -682,9 +694,7 @@
     return False
 
 
-PermissionMixIn.has_local_role = has_local_role
-
-
+@_override_method(PermissionMixIn)
 def may_have_permission(self, action, req):
     if action != 'read' and not (self.has_local_role('read') or
                                  self.has_perm(req, 'read')):
@@ -692,9 +702,7 @@
     return self.has_local_role(action) or self.has_perm(req, action)
 
 
-PermissionMixIn.may_have_permission = may_have_permission
-
-
+@_override_method(PermissionMixIn)
 def has_perm(self, _cw, action, **kwargs):
     """return true if the action is granted globally or locally"""
     try:
@@ -704,9 +712,7 @@
         return False
 
 
-PermissionMixIn.has_perm = has_perm
-
-
+@_override_method(PermissionMixIn)
 def check_perm(self, _cw, action, **kwargs):
     # NB: _cw may be a server transaction or a request object.
     #
@@ -749,9 +755,6 @@
     raise Unauthorized(action, str(self))
 
 
-PermissionMixIn.check_perm = check_perm
-
-
 CubicWebRelationDefinitionSchema._RPROPERTIES['eid'] = None
 # remember rproperties defined at this point. Others will have to be serialized in
 # CWAttribute.extra_props
@@ -1464,33 +1467,23 @@
 
 # XXX itou for some Statement methods
 
-def bw_get_etype(self, name):
-    return orig_get_etype(self, bw_normalize_etype(name))
-
-
-orig_get_etype = stmts.ScopeNode.get_etype
-stmts.ScopeNode.get_etype = bw_get_etype
-
-
-def bw_add_main_variable_delete(self, etype, vref):
-    return orig_add_main_variable_delete(self, bw_normalize_etype(etype), vref)
+@_override_method(stmts.ScopeNode, pass_original=True)
+def get_etype(self, name, _orig):
+    return _orig(self, bw_normalize_etype(name))
 
 
-orig_add_main_variable_delete = stmts.Delete.add_main_variable
-stmts.Delete.add_main_variable = bw_add_main_variable_delete
-
-
-def bw_add_main_variable_insert(self, etype, vref):
-    return orig_add_main_variable_insert(self, bw_normalize_etype(etype), vref)
+@_override_method(stmts.Delete, method_name='add_main_variable',
+                  pass_original=True)
+def _add_main_variable_delete(self, etype, vref, _orig):
+    return _orig(self, bw_normalize_etype(etype), vref)
 
 
-orig_add_main_variable_insert = stmts.Insert.add_main_variable
-stmts.Insert.add_main_variable = bw_add_main_variable_insert
+@_override_method(stmts.Insert, method_name='add_main_variable',
+                  pass_original=True)
+def _add_main_variable_insert(self, etype, vref, _orig):
+    return _orig(self, bw_normalize_etype(etype), vref)
 
 
-def bw_set_statement_type(self, etype):
-    return orig_set_statement_type(self, bw_normalize_etype(etype))
-
-
-orig_set_statement_type = stmts.Select.set_statement_type
-stmts.Select.set_statement_type = bw_set_statement_type
+@_override_method(stmts.Select, pass_original=True)
+def set_statement_type(self, etype, _orig):
+    return _orig(self, bw_normalize_etype(etype))