"""In-memory object collection with pluggable secondary indexes."""

from typing import (
    AbstractSet,
    Any,
    Callable,
    Generic,
    Iterable,
    List,
    Mapping,
    MutableMapping,
    Tuple,
    TypeVar,
    Union,
)

K = TypeVar('K')
V = TypeVar('V')

# Maps a value to its index key, or to an iterable of keys when the index is
# created with ``multiple=True``.
KeyCallable = Callable[[V], Union[K, Iterable[K]]]

__all__ = [
    "Index",
    "ListIndex",
    "UniqueIndex",
    "MultiIndex",
    "ObjDb",
]


def _extract_keys(key_callable, multiple, value) -> Iterable[Any]:
    """Return the keys ``key_callable`` yields for *value* as an iterable.

    A ``None`` result means "do not index this value" and gives an empty
    iterable. When *multiple* is false the single key is wrapped so callers
    can always loop.
    """
    keys = key_callable(value)
    if keys is None:
        return ()
    return keys if multiple else (keys,)


class Index(Generic[K, V]):
    """Common interface of all index types.

    Every method here is a no-op placeholder returning ``None``; the concrete
    index classes override them.
    """

    def add(self, value: V):
        """Index *value*."""

    def get(self, key: K) -> Iterable[V]:
        """Return the values stored under *key*."""

    def clear(self):
        """Remove all entries from the index."""

    def __iter__(self):
        """Iterate over the keys of the index."""

    def items(self):
        """Return the ``(key, entry)`` pairs of the index."""

    def values(self):
        """Return the entries of the index."""


class ListIndex(Index[K, V]):
    """Non-unique index: each key maps to the list of values sharing it."""

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        """Create an empty index.

        Args:
            name: Label of the index.
            key_callable: Computes the key (or keys) of a value.
            multiple: When true, ``key_callable`` returns an iterable of keys
                and the value is indexed under each of them.
        """
        self.name = name
        self.idx: MutableMapping[K, List[V]] = {}
        self.key_callable = key_callable
        self.multiple = multiple

    def add(self, value: V):
        """Append *value* to the list of every key it maps to.

        ``None`` keys are skipped.
        """
        for key in _extract_keys(self.key_callable, self.multiple, value):
            if key is None:
                continue
            self.idx.setdefault(key, []).append(value)

    def get(self, key: K) -> List[V]:
        """Return the list stored under *key*, or a new empty list."""
        return self.idx.get(key, [])

    def clear(self):
        """Remove all entries from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the keys of the index."""
        return iter(self.idx)

    def items(self) -> AbstractSet[Tuple[K, List[V]]]:
        """Return a view of ``(key, list_of_values)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a view of the per-key value lists."""
        return self.idx.values()


class UniqueIndex(Index[K, V]):
    """Index in which each key maps to at most one value."""

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        """Create an empty index.

        Args:
            name: Label of the index; it appears in duplicate-key errors.
            key_callable: Computes the key (or keys) of a value.
            multiple: When true, ``key_callable`` returns an iterable of keys
                and the value is indexed under each of them.
        """
        self.name = name
        self.idx: MutableMapping[K, V] = {}
        self.key_callable = key_callable
        self.multiple = multiple

    def get(self, key: K) -> Iterable[V]:
        """Return ``[value]`` for *key*, or an empty list when it is absent."""
        # Membership test rather than ``is not None`` so that the lookup does
        # not depend on what the stored value happens to be.
        return [self.idx[key]] if key in self.idx else []

    def get_single(self, key: K) -> V:
        """Return the value stored under *key*.

        Raises:
            KeyError: If *key* is not in the index.
        """
        return self.idx[key]

    def add(self, value: V):
        """Store *value* under every key it maps to.

        ``None`` keys are skipped, matching ``ListIndex``. All keys are
        checked before anything is stored, so a rejected value leaves the
        index untouched.

        Raises:
            KeyError: If any key is already present in the index or occurs
                twice among the keys of *value*.
        """
        keys = [
            key
            for key in _extract_keys(self.key_callable, self.multiple, value)
            if key is not None
        ]
        pending = set()
        for key in keys:
            if key in self.idx or key in pending:
                raise KeyError("Duplicate key in index '{}': key={}, value={}".format(self.name, key, repr(value)))
            pending.add(key)
        for key in keys:
            self.idx[key] = value

    def clear(self):
        """Remove all entries from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the keys of the index."""
        return iter(self.idx)

    def items(self):
        """Return a view of ``(key, value)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a view of the stored values."""
        return self.idx.values()


class MultiIndex(Index[K, V]):
    """Nested index keyed by tuples.

    Each tuple component except the last selects a nested dict; the last
    component selects a list of values inside the innermost dict.
    """

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        """Create an empty index.

        Args:
            name: Label of the index; it appears in error messages.
            key_callable: Computes the tuple key (or keys) of a value.
            multiple: When true, ``key_callable`` returns an iterable of
                tuple keys and the value is indexed under each of them.
        """
        self.name = name
        self.idx: MutableMapping[Any, Any] = {}
        self.key_callable = key_callable
        self.multiple = multiple

    # TODO: this should return a new index
    def get(self, key: K) -> Mapping[K, V]:
        """Return the top-level entry stored under *key*, or an empty dict."""
        entry = self.idx.get(key, None)
        return entry if entry is not None else {}

    def get_single(self, key: K) -> V:
        """Return the top-level entry stored under *key*.

        Raises:
            KeyError: If *key* is not a top-level key of the index.
        """
        return self.idx[key]

    def add(self, value: V):
        """Append *value* to the list addressed by each of its tuple keys.

        Raises:
            KeyError: If a key is not a tuple, is an empty tuple, has ``None``
                in any position but the last, or its depth conflicts with
                entries already stored.
        """
        for tpl in _extract_keys(self.key_callable, self.multiple, value):
            if not isinstance(tpl, tuple):
                raise KeyError("The key must be a tuple, index='{}', key='{}'".format(self.name, repr(tpl)))
            if not tpl:
                raise KeyError("The key must not be an empty tuple, index='{}'".format(self.name))
            *path, leaf = tpl
            # Validate before walking so a bad key never leaves empty nested
            # dicts behind.
            if any(sub_key is None for sub_key in path):
                raise KeyError("Got None sub-key: index='{}', key='{}'".format(self.name, repr(tpl)))

            node = self.idx
            for sub_key in path:
                child = node.get(sub_key, None)
                if child is None:
                    child = {}
                    node[sub_key] = child
                elif not isinstance(child, dict):
                    # A shorter key already stored a value list at this level.
                    raise KeyError(
                        "Key depth conflicts with existing entries: index='{}', key='{}'".format(
                            self.name, repr(tpl)
                        )
                    )
                node = child

            bucket = node.get(leaf, None)
            if bucket is None:
                bucket = []
                node[leaf] = bucket
            elif not isinstance(bucket, list):
                # A longer key already stored a nested dict at this level.
                raise KeyError(
                    "Key depth conflicts with existing entries: index='{}', key='{}'".format(
                        self.name, repr(tpl)
                    )
                )
            bucket.append(value)

    def clear(self):
        """Remove all entries from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the top-level keys of the index."""
        return iter(self.idx)

    def items(self):
        """Return a view of the top-level ``(key, entry)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a view of the top-level entries."""
        return self.idx.values()


class ObjDb(Generic[V]):
    """List of values kept in insertion order, with named indexes over it."""

    def __init__(self):
        """Create an empty collection with no indexes."""
        self.values: List[V] = []
        self._indexes: MutableMapping[str, Index[Any, V]] = {}

    def add(self, value: V):
        """Add *value* to every registered index, then to the value list.

        NOTE: if an index raises, indexes processed before it keep the value
        while ``self.values`` does not; the exception is propagated.
        """
        for idx in self._indexes.values():
            idx.add(value)
        self.values.append(value)

    def __iter__(self):
        """Iterate over the stored values in insertion order."""
        return iter(self.values)

    def __len__(self) -> int:
        """Return the number of stored values."""
        return len(self.values)

    def add_index(self, name, key_callable: KeyCallable, **kwargs) -> ListIndex[Any, V]:
        """Register and return a ``ListIndex``; *kwargs* go to its constructor."""
        return self._add(name, ListIndex(name, key_callable, **kwargs))

    def add_unique_index(self, name, key_callable: KeyCallable, **kwargs) -> UniqueIndex[Any, V]:
        """Register and return a ``UniqueIndex``; *kwargs* go to its constructor."""
        return self._add(name, UniqueIndex(name, key_callable, **kwargs))

    def add_multi_index(self, name, key_callable: KeyCallable, **kwargs) -> MultiIndex[Any, V]:
        """Register and return a ``MultiIndex``; *kwargs* go to its constructor."""
        return self._add(name, MultiIndex(name, key_callable, **kwargs))

    def index(self, name) -> Index:
        """Return the index registered as *name*.

        Raises:
            KeyError: If no index has that name.
        """
        return self._indexes[name]

    def _add(self, name, idx):
        """Fill *idx* with the existing values and register it as *name*.

        The index is registered only after all existing values were added, so
        a failure while filling it leaves the set of indexes unchanged.

        Raises:
            KeyError: If an index called *name* already exists.
        """
        if name in self._indexes:
            raise KeyError("Index already exist: {}".format(name))
        for value in self.values:
            idx.add(value)
        self._indexes[name] = idx
        return idx