rtags.py
branchtls-sprint
changeset 1721 694f6a50e138
parent 1548 bd225e776739
child 1723 30c3a713ab61
--- a/rtags.py	Mon May 11 11:01:40 2009 +0200
+++ b/rtags.py	Mon May 11 11:20:38 2009 +0200
@@ -28,8 +28,24 @@
         return self.get(*key)
     __contains__ = __getitem__
 
-    def _get_keys(self, rtype, tagged, stype, otype):
-        assert tagged in ('subject', 'object'), tagged
+    def _get_tagged(self, stype, otype, tagged=None):
+        stype, otype = str(stype), str(otype)
+        if tagged is None:
+            if stype[0] == '!':
+                tagged = 'subject'
+                stype = stype[1:]
+            elif otype[0] == '!':
+                tagged = 'object'
+                otype = otype[1:]
+            else:
+                raise AssertionError('either stype or rtype should have the '
+                                     'role mark ("!")')
+        else:
+            assert tagged in ('subject', 'object'), tagged
+        return stype, otype, tagged
+
+    def _get_keys(self, stype, rtype, otype, tagged=None):
+        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
         keys = [(rtype, tagged, '*', '*'),
                 (rtype, tagged, '*', otype),
                 (rtype, tagged, stype, '*'),
@@ -42,46 +58,43 @@
                 keys.remove((rtype, tagged, stype, '*'))
         return keys
 
-    def tag_attribute(self, tag, stype, attr):
-        self._tagdefs[(str(attr), 'subject', str(stype), '*')] = tag
+    def tag_relation(self, stype, rtype, otype, tag, tagged=None):
+        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
+        self._tagdefs[(str(rtype), tagged, stype, otype)] = tag
 
-    def tag_relation(self, tag, relation, tagged):
-        assert tagged in ('subject', 'object'), tagged
-        stype, rtype, otype = relation
-        self._tagdefs[(str(rtype), tagged, str(stype), str(otype))] = tag
+    def tag_attribute(self, stype, attr, tag):
+        self.tag_relation(stype, attr, '*', tag, tagged)
 
-    def del_rtag(self, relation, tagged):
-        assert tagged in ('subject', 'object'), tagged
-        stype, rtype, otype = relation
-        del self._tagdefs[(str(rtype), tagged, str(stype), str(otype))]
+    def del_rtag(self, stype, rtype, otype):
+        stype, otype, tagged = self._get_tagged(stype, otype)
+        del self._tagdefs[(str(rtype), tagged, stype, otype)]
 
-    def get(self, rtype, tagged, stype='*', otype='*'):
-        for key in reversed(self._get_keys(rtype, tagged, stype, otype)):
+    def get(self, stype, rtype, otype, tagged=None):
+        for key in reversed(self._get_keys(stype, rtype, otype, tagged)):
             try:
                 return self._tagdefs[key]
             except KeyError:
                 continue
         return None
 
-    def etype_get(self, etype, rtype, tagged, ttype='*'):
-        if tagged == 'subject':
-            return self.get(rtype, tagged, etype, ttype)
-        return self.get(rtype, tagged, ttype, etype)
+    def etype_get(self, etype, rtype, role, ttype='*'):
+        if role == 'subject':
+            return self.get(etype, rtype, ttype, role)
+        return self.get(ttype, rtype, etype, role)
 
 
 
 class RelationTagsSet(RelationTags):
     """This class associates a set of tags to each key."""
 
-    def tag_relation(self, tag, relation, tagged):
-        assert tagged in ('subject', 'object'), tagged
-        stype, rtype, otype = relation
+    def tag_relation(self, stype, rtype, otype, tag, tagged=None):
+        stype, otype, tagged = self._get_tagged(stype, otype, tagged)
         rtags = self._tagdefs.setdefault((rtype, tagged, stype, otype), set())
         rtags.add(tag)
 
-    def get(self, rtype, tagged, stype='*', otype='*'):
+    def get(self, stype, rtype, otype, tagged=None):
         rtags = set()
-        for key in self._get_keys(rtype, tagged, stype, otype):
+        for key in self._get_keys(stype, rtype, otype, tagged):
             try:
                 rtags.update(self._tagdefs[key])
             except KeyError: