diff -r 14c9a0a1aca1 -r 694f6a50e138 rtags.py --- a/rtags.py Mon May 11 11:01:40 2009 +0200 +++ b/rtags.py Mon May 11 11:20:38 2009 +0200 @@ -28,8 +28,24 @@ return self.get(*key) __contains__ = __getitem__ - def _get_keys(self, rtype, tagged, stype, otype): - assert tagged in ('subject', 'object'), tagged + def _get_tagged(self, stype, otype, tagged=None): + stype, otype = str(stype), str(otype) + if tagged is None: + if stype[0] == '!': + tagged = 'subject' + stype = stype[1:] + elif otype[0] == '!': + tagged = 'object' + otype = otype[1:] + else: + raise AssertionError('either stype or rtype should have the ' + 'role mark ("!")') + else: + assert tagged in ('subject', 'object'), tagged + return stype, otype, tagged + + def _get_keys(self, stype, rtype, otype, tagged=None): + stype, otype, tagged = self._get_tagged(stype, otype, tagged) keys = [(rtype, tagged, '*', '*'), (rtype, tagged, '*', otype), (rtype, tagged, stype, '*'), @@ -42,46 +58,43 @@ keys.remove((rtype, tagged, stype, '*')) return keys - def tag_attribute(self, tag, stype, attr): - self._tagdefs[(str(attr), 'subject', str(stype), '*')] = tag + def tag_relation(self, stype, rtype, otype, tag, tagged=None): + stype, otype, tagged = self._get_tagged(stype, otype, tagged) + self._tagdefs[(str(rtype), tagged, stype, otype)] = tag - def tag_relation(self, tag, relation, tagged): - assert tagged in ('subject', 'object'), tagged - stype, rtype, otype = relation - self._tagdefs[(str(rtype), tagged, str(stype), str(otype))] = tag + def tag_attribute(self, stype, attr, tag): + self.tag_relation(stype, attr, '*', tag, tagged) - def del_rtag(self, relation, tagged): - assert tagged in ('subject', 'object'), tagged - stype, rtype, otype = relation - del self._tagdefs[(str(rtype), tagged, str(stype), str(otype))] + def del_rtag(self, stype, rtype, otype): + stype, otype, tagged = self._get_tagged(stype, otype) + del self._tagdefs[(str(rtype), tagged, stype, otype)] - def get(self, rtype, tagged, stype='*', otype='*'): - for key in reversed(self._get_keys(rtype, tagged, stype, otype)): + def get(self, stype, rtype, otype, tagged=None): + for key in reversed(self._get_keys(stype, rtype, otype, tagged)): try: return self._tagdefs[key] except KeyError: continue return None - def etype_get(self, etype, rtype, tagged, ttype='*'): - if tagged == 'subject': - return self.get(rtype, tagged, etype, ttype) - return self.get(rtype, tagged, ttype, etype) + def etype_get(self, etype, rtype, role, ttype='*'): + if role == 'subject': + return self.get(etype, rtype, ttype, role) + return self.get(ttype, rtype, etype, role) class RelationTagsSet(RelationTags): """This class associates a set of tags to each key.""" - def tag_relation(self, tag, relation, tagged): - assert tagged in ('subject', 'object'), tagged - stype, rtype, otype = relation + def tag_relation(self, stype, rtype, otype, tag, tagged=None): + stype, otype, tagged = self._get_tagged(stype, otype, tagged) rtags = self._tagdefs.setdefault((rtype, tagged, stype, otype), set()) rtags.add(tag) - def get(self, rtype, tagged, stype='*', otype='*'): + def get(self, stype, rtype, otype, tagged=None): rtags = set() - for key in self._get_keys(rtype, tagged, stype, otype): + for key in self._get_keys(stype, rtype, otype, tagged): try: rtags.update(self._tagdefs[key]) except KeyError: