--- a/rtags.py Mon May 11 11:01:40 2009 +0200
+++ b/rtags.py Mon May 11 11:20:38 2009 +0200
@@ -28,8 +28,24 @@
return self.get(*key)
__contains__ = __getitem__
- def _get_keys(self, rtype, tagged, stype, otype):
- assert tagged in ('subject', 'object'), tagged
+ def _get_tagged(self, stype, otype, tagged=None):
+ stype, otype = str(stype), str(otype)
+ if tagged is None:
+ if stype[0] == '!':
+ tagged = 'subject'
+ stype = stype[1:]
+ elif otype[0] == '!':
+ tagged = 'object'
+ otype = otype[1:]
+ else:
+ raise AssertionError('either stype or rtype should have the '
+ 'role mark ("!")')
+ else:
+ assert tagged in ('subject', 'object'), tagged
+ return stype, otype, tagged
+
+ def _get_keys(self, stype, rtype, otype, tagged=None):
+ stype, otype, tagged = self._get_tagged(stype, otype, tagged)
keys = [(rtype, tagged, '*', '*'),
(rtype, tagged, '*', otype),
(rtype, tagged, stype, '*'),
@@ -42,46 +58,43 @@
keys.remove((rtype, tagged, stype, '*'))
return keys
- def tag_attribute(self, tag, stype, attr):
- self._tagdefs[(str(attr), 'subject', str(stype), '*')] = tag
+ def tag_relation(self, stype, rtype, otype, tag, tagged=None):
+ stype, otype, tagged = self._get_tagged(stype, otype, tagged)
+ self._tagdefs[(str(rtype), tagged, stype, otype)] = tag
- def tag_relation(self, tag, relation, tagged):
- assert tagged in ('subject', 'object'), tagged
- stype, rtype, otype = relation
- self._tagdefs[(str(rtype), tagged, str(stype), str(otype))] = tag
+ def tag_attribute(self, stype, attr, tag):
+ self.tag_relation(stype, attr, '*', tag, tagged)
- def del_rtag(self, relation, tagged):
- assert tagged in ('subject', 'object'), tagged
- stype, rtype, otype = relation
- del self._tagdefs[(str(rtype), tagged, str(stype), str(otype))]
+ def del_rtag(self, stype, rtype, otype):
+ stype, otype, tagged = self._get_tagged(stype, otype)
+ del self._tagdefs[(str(rtype), tagged, stype, otype)]
- def get(self, rtype, tagged, stype='*', otype='*'):
- for key in reversed(self._get_keys(rtype, tagged, stype, otype)):
+ def get(self, stype, rtype, otype, tagged=None):
+ for key in reversed(self._get_keys(stype, rtype, otype, tagged)):
try:
return self._tagdefs[key]
except KeyError:
continue
return None
- def etype_get(self, etype, rtype, tagged, ttype='*'):
- if tagged == 'subject':
- return self.get(rtype, tagged, etype, ttype)
- return self.get(rtype, tagged, ttype, etype)
+ def etype_get(self, etype, rtype, role, ttype='*'):
+ if role == 'subject':
+ return self.get(etype, rtype, ttype, role)
+ return self.get(ttype, rtype, etype, role)
class RelationTagsSet(RelationTags):
"""This class associates a set of tags to each key."""
- def tag_relation(self, tag, relation, tagged):
- assert tagged in ('subject', 'object'), tagged
- stype, rtype, otype = relation
+ def tag_relation(self, stype, rtype, otype, tag, tagged=None):
+ stype, otype, tagged = self._get_tagged(stype, otype, tagged)
rtags = self._tagdefs.setdefault((rtype, tagged, stype, otype), set())
rtags.add(tag)
- def get(self, rtype, tagged, stype='*', otype='*'):
+ def get(self, stype, rtype, otype, tagged=None):
rtags = set()
- for key in self._get_keys(rtype, tagged, stype, otype):
+ for key in self._get_keys(stype, rtype, otype, tagged):
try:
rtags.update(self._tagdefs[key])
except KeyError: