rtags.py
branchtls-sprint
changeset 1739 78b0819162a8
parent 1726 08918409815e
child 1740 2292ae32c98f
--- a/rtags.py	Tue May 12 11:49:30 2009 +0200
+++ b/rtags.py	Tue May 12 11:56:12 2009 +0200
@@ -16,9 +16,12 @@
 
     This class associates a single tag to each key.
     """
-
-    def __init__(self):
+    _allowed_values = None
+    def __init__(self, initfunc=None, allowed_values=None):
         self._tagdefs = {}
+        if allowed_values is not None:
+            self._allowed_values = allowed_values
+        self._initfunc = initfunc
 
     def __repr__(self):
         return repr(self._tagdefs)
@@ -28,24 +31,7 @@
         return self.get(*key)
     __contains__ = __getitem__
 
-    def _get_tagged(self, stype, otype, tagged=None):
-        stype, otype = str(stype), str(otype)
-        if tagged is None:
-            if stype[0] == '!':
-                tagged = 'subject'
-                stype = stype[1:]
-            elif otype[0] == '!':
-                tagged = 'object'
-                otype = otype[1:]
-            else:
-                raise AssertionError('either stype or rtype should have the '
-                                     'role mark ("!")')
-        else:
-            assert tagged in ('subject', 'object'), repr(tagged)
-        return stype, otype, tagged
-
-    def _get_keys(self, stype, rtype, otype, tagged=None):
-        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
+    def _get_keys(self, stype, rtype, otype, tagged):
         keys = [(rtype, tagged, '*', '*'),
                 (rtype, tagged, '*', otype),
                 (rtype, tagged, stype, '*'),
@@ -58,18 +44,56 @@
                 keys.remove((rtype, tagged, stype, '*'))
         return keys
 
-    def tag_relation(self, stype, rtype, otype, tag, tagged=None):
-        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
-        self._tagdefs[(str(rtype), tagged, stype, otype)] = tag
+    def init(self, schema):
+        # XXX check existing keys against schema
+        for rtype, tagged, stype, otype in self._tagdefs:
+            assert rtype in schema
+            if stype != '*':
+                assert stype in schema
+            if otype != '*':
+                assert otype in schema
+        if self._initfunc is not None:
+            for eschema in schema.entities():
+                for rschema, tschemas, role in eschema.relation_definitions(True):
+                    for tschema in tschemas:
+                        if role == 'subject':
+                            stype, otype = eschema, tschema
+                        else:
+                            stype, otype = tschema, eschema
+                        self._initfunc(stype, rtype, otype, role)
+
+    # rtag declaration api ####################################################
+
+    def tag_attribute(self, key, tag):
+        key = list(key)
+        key.append('*')
+        self.tag_subject_of(key, tag)
 
-    def tag_attribute(self, stype, attr, tag):
-        self.tag_relation(stype, attr, '*', tag, 'subject')
+    def tag_subject_of(self, key, tag):
+        key = list(key)
+        key.append('subject')
+        self.tag_relation(key, tag)
+
+    def tag_object_of(self, key, tag):
+        key = list(key)
+        key.append('object')
+        self.tag_relation(key, tag)
 
-    def del_rtag(self, stype, rtype, otype):
-        stype, otype, tagged = self._get_tagged(stype, otype)
-        del self._tagdefs[(str(rtype), tagged, stype, otype)]
+    def tag_relation(self, key, tag):
+        #if isinstance(key, basestring):
+        #    stype, rtype, otype = key.split()
+        #else:
+        stype, rtype, otype, tagged = [str(k) for k in key]
+        if self._allowed_values is not None:
+            assert tag in self._allowed_values
+        self._tagdefs[(rtype, tagged, stype, otype)] = tag
 
-    def get(self, stype, rtype, otype, tagged=None):
+    # rtag runtime api ########################################################
+
+    def del_rtag(self, stype, rtype, otype, tagged):
+        del self._tagdefs[(rtype, tagged, stype, otype)]
+
+    def get(self, stype, rtype, otype, tagged):
         for key in reversed(self._get_keys(stype, rtype, otype, tagged)):
             try:
                 return self._tagdefs[key]
@@ -87,12 +111,12 @@
 class RelationTagsSet(RelationTags):
     """This class associates a set of tags to each key."""
 
-    def tag_relation(self, stype, rtype, otype, tag, tagged=None):
-        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
+    def tag_relation(self, key, tag):
+        stype, rtype, otype, tagged = [str(k) for k in key]
         rtags = self._tagdefs.setdefault((rtype, tagged, stype, otype), set())
         rtags.add(tag)
 
-    def get(self, stype, rtype, otype, tagged=None):
+    def get(self, stype, rtype, otype, tagged):
         rtags = set()
         for key in self._get_keys(stype, rtype, otype, tagged):
             try:
@@ -100,3 +124,16 @@
             except KeyError:
                 continue
         return rtags
+
+
+class RelationTagsBool(RelationTags):
+    _allowed_values = frozenset((True, False))
+
+    def tag_subject_of(self, key, tag=True):
+        super(RelationTagsBool, self).tag_subject_of(key, tag)
+
+    def tag_object_of(self, key, tag=True):
+        super(RelationTagsBool, self).tag_object_of(key, tag)
+
+    def tag_attribute(self, key, tag=True):
+        super(RelationTagsBool, self).tag_attribute(key, tag)