diff -r 2cfd50c8a415 -r 78b0819162a8 rtags.py --- a/rtags.py Tue May 12 11:49:30 2009 +0200 +++ b/rtags.py Tue May 12 11:56:12 2009 +0200 @@ -16,9 +16,12 @@ This class associates a single tag to each key. """ - - def __init__(self): + _allowed_values = None + def __init__(self, initfunc=None, allowed_values=None): self._tagdefs = {} + if allowed_values is not None: + self._allowed_values = allowed_values + self._initfunc = initfunc def __repr__(self): return repr(self._tagdefs) @@ -28,24 +31,7 @@ return self.get(*key) __contains__ = __getitem__ - def _get_tagged(self, stype, otype, tagged=None): - stype, otype = str(stype), str(otype) - if tagged is None: - if stype[0] == '!': - tagged = 'subject' - stype = stype[1:] - elif otype[0] == '!': - tagged = 'object' - otype = otype[1:] - else: - raise AssertionError('either stype or rtype should have the ' - 'role mark ("!")') - else: - assert tagged in ('subject', 'object'), repr(tagged) - return stype, otype, tagged - - def _get_keys(self, stype, rtype, otype, tagged=None): - stype, otype, tagged = self._get_tagged(stype, otype, tagged) + def _get_keys(self, stype, rtype, otype, tagged): keys = [(rtype, tagged, '*', '*'), (rtype, tagged, '*', otype), (rtype, tagged, stype, '*'), @@ -58,18 +44,56 @@ keys.remove((rtype, tagged, stype, '*')) return keys - def tag_relation(self, stype, rtype, otype, tag, tagged=None): - stype, otype, tagged = self._get_tagged(stype, otype, tagged) - self._tagdefs[(str(rtype), tagged, stype, otype)] = tag + def init(self, schema): + # XXX check existing keys against schema + for rtype, tagged, stype, otype in self._tagdefs: + assert rtype in schema + if stype != '*': + assert stype in schema + if otype != '*': + assert otype in schema + if self._initfunc is not None: + for eschema in schema.entities(): + for rschema, tschemas, role in eschema.relation_definitions(True): + for tschema in tschemas: + if role == 'subject': + stype, otype = eschema, tschema + else: + stype, otype = tschema, eschema + self._initfunc(stype, rtype, otype, role) + + # rtag declaration api #################################################### + + def tag_attribute(self, key, tag): + key = list(key) + key.append('*') + self.tag_subject_of(key, tag) - def tag_attribute(self, stype, attr, tag): - self.tag_relation(stype, attr, '*', tag, 'subject') + def tag_subject_of(self, key, tag): + key = list(key) + key.append('subject') + self.tag_relation(key, tag) + + def tag_object_of(self, key, tag): + key = list(key) + key.append('object') + self.tag_relation(key, tag) - def del_rtag(self, stype, rtype, otype): - stype, otype, tagged = self._get_tagged(stype, otype) - del self._tagdefs[(str(rtype), tagged, stype, otype)] + def tag_relation(self, key, tag): + #if isinstance(key, basestring): + # stype, rtype, otype = key.split() + #else: + stype, rtype, otype, tagged = [str(k) for k in key] + if self._allowed_values is not None: + assert tag in self._allowed_values + self._tagdefs[(rtype, tagged, stype, otype)] = tag - def get(self, stype, rtype, otype, tagged=None): + # rtag runtime api ######################################################## + + def del_rtag(self, stype, rtype, otype, tagged): + del self._tagdefs[(rtype, tagged, stype, otype)] + + def get(self, stype, rtype, otype, tagged): for key in reversed(self._get_keys(stype, rtype, otype, tagged)): try: return self._tagdefs[key] @@ -87,12 +111,12 @@ class RelationTagsSet(RelationTags): """This class associates a set of tags to each key.""" - def tag_relation(self, stype, rtype, otype, tag, tagged=None): - stype, otype, tagged = self._get_tagged(stype, otype, tagged) + def tag_relation(self, key, tag): + stype, rtype, otype, tagged = [str(k) for k in key] rtags = self._tagdefs.setdefault((rtype, tagged, stype, otype), set()) rtags.add(tag) - def get(self, stype, rtype, otype, tagged=None): + def get(self, stype, rtype, otype, tagged): rtags = set() for key in self._get_keys(stype, rtype, otype, tagged): try: @@ -100,3 +124,16 @@ except KeyError: continue return rtags + + +class RelationTagsBool(RelationTags): + _allowed_values = frozenset((True, False)) + + def tag_subject_of(self, key, tag=True): + super(RelationTagsBool, self).tag_subject_of(key, tag) + + def tag_object_of(self, key, tag=True): + super(RelationTagsBool, self).tag_object_of(key, tag) + + def tag_attribute(self, key, tag=True): + super(RelationTagsBool, self).tag_attribute(key, tag)