from typing import TypeVar, Generic, Callable, MutableMapping, List, Iterable, Union, Any, Mapping, Tuple, AbstractSet

K = TypeVar('K')  # type of the keys an index is keyed by
V = TypeVar('V')  # type of the objects stored in the database / its indexes

# Maps a stored value to its key(s): a single key, an iterable of keys (for
# indexes created with ``multiple=True``), or ``None`` to leave the value out.
KeyCallable = Callable[[V], Union[K, Iterable[K], None]]

# Public API of this module. ``MultiIndex`` is listed because
# ``ObjDb.add_multi_index`` returns instances of it to callers.
__all__ = [
    "Index",
    "ListIndex",
    "UniqueIndex",
    "MultiIndex",
    "ObjDb",
]


class Index(Generic[K, V]):
    """Common interface shared by the index kinds stored in ``ObjDb``.

    Every base implementation is a no-op returning ``None``; concrete
    subclasses override the methods they support.
    """

    def add(self, value: V):
        """Register ``value`` under the key(s) derived from it."""

    def get(self, key: K) -> Iterable[V]:
        """Return what is stored under ``key``."""

    def clear(self):
        """Remove every entry from the index."""

    def __iter__(self):
        """Iterate over the index keys."""

    def items(self):
        """Return the ``(key, entry)`` pairs of the index."""

    def values(self):
        """Return the stored entries of the index."""


class ListIndex(Index[K, V]):
    """Non-unique index mapping each key to the list of values registered under it.

    :param name: name of the index, kept on ``self.name``.
    :param key_callable: maps a value to its key, to an iterable of keys when
        ``multiple`` is true, or to ``None`` to skip the value.
    :param multiple: when true, ``key_callable`` yields several keys per value.
    """

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        self.name = name
        # key -> values registered under that key, in insertion order.
        self.idx: MutableMapping[K, List[V]] = {}
        self.key_callable = key_callable
        self.multiple = multiple

    def add(self, value: V):
        """Append ``value`` to the list of every key ``key_callable`` yields for it.

        A ``None`` result, or an individual ``None`` key in ``multiple`` mode,
        is skipped.
        """
        keys = self.key_callable(value)
        if keys is None:
            return

        if not self.multiple:
            keys = [keys]

        for key in keys:
            if key is None:
                continue
            self.idx.setdefault(key, []).append(value)

    def get(self, key: K) -> List[V]:
        """Return the values stored under ``key``, or a new empty list.

        NOTE: for an existing key this is the index's own list, so mutating
        the result mutates the index.
        """
        return self.idx.get(key, [])

    def clear(self):
        """Remove every key from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the indexed keys."""
        return iter(self.idx)

    def items(self) -> AbstractSet[Tuple[K, List[V]]]:
        """Return a live view of ``(key, list_of_values)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a live view of the per-key value lists."""
        return self.idx.values()


class UniqueIndex(Index[K, V]):
    """Index in which every key maps to exactly one value.

    :param name: index name, used in duplicate-key error messages.
    :param key_callable: maps a value to its key, to an iterable of keys when
        ``multiple`` is true, or to ``None`` to skip the value.
    :param multiple: when true, ``key_callable`` yields several keys per value.
    """

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        self.name = name
        self.idx: MutableMapping[K, V] = {}  # key -> the single value for it
        self.key_callable = key_callable
        self.multiple = multiple

    def get(self, key: K) -> Iterable[V]:
        """Return ``[value]`` for ``key`` if present, otherwise ``[]``."""
        if key in self.idx:
            return [self.idx[key]]
        return []

    def get_single(self, key: K) -> V:
        """Return the value stored under ``key``.

        :raises KeyError: if ``key`` is not in the index.
        """
        return self.idx[key]

    def add(self, value: V):
        """Register ``value`` under each key ``key_callable`` yields for it.

        A ``None`` result, or an individual ``None`` key in ``multiple`` mode,
        is skipped (as in ``ListIndex``). All keys are validated before any is
        inserted, so a duplicate leaves the index unchanged.

        :raises KeyError: if a key already exists in the index or is yielded
            more than once for the same value.
        """
        keys = self.key_callable(value)
        if keys is None:
            return

        if not self.multiple:
            keys = [keys]

        # Materialise once: ``keys`` may be a one-shot iterator and is walked twice.
        new_keys = [key for key in keys if key is not None]

        seen = set()
        for key in new_keys:
            if key in self.idx or key in seen:
                raise KeyError("Duplicate key in index '{}': key={}, value={}".format(self.name, key, repr(value)))
            seen.add(key)

        for key in new_keys:
            self.idx[key] = value

    def clear(self):
        """Remove every key from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the indexed keys."""
        return self.idx.__iter__()

    def items(self):
        """Return a live view of ``(key, value)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a live view of the indexed values."""
        return self.idx.values()


class MultiIndex(Index[K, V]):
    """Nested index keyed by tuples.

    A key ``(a, b, c)`` stores the value in ``idx[a][b][c]``. Every level but
    the last is a ``dict``; the last level holds a list of values.

    :param name: index name, used in error messages.
    :param key_callable: maps a value to a tuple key, to an iterable of tuple
        keys when ``multiple`` is true, or to ``None`` to skip the value.
    :param multiple: when true, ``key_callable`` yields several keys per value.
    """

    def __init__(self, name, key_callable: KeyCallable, multiple=False):
        self.name = name
        # Nested dicts whose leaves are lists of values; depth follows the key tuples.
        self.idx: MutableMapping[Any, Any] = {}
        self.key_callable = key_callable
        self.multiple = multiple

    # TODO: this should return a new index
    def get(self, key: K) -> Mapping[K, V]:
        """Return the sub-tree stored under first-level ``key``, or ``{}``.

        For one-element key tuples the stored entry is a list of values, not
        a mapping.
        """
        return self.idx.get(key, {})

    def get_single(self, key: K) -> V:
        """Return the entry stored under first-level ``key``.

        :raises KeyError: if ``key`` is not present.
        """
        return self.idx[key]

    def add(self, value: V):
        """Append ``value`` to the leaf list addressed by each tuple key.

        :raises KeyError: if a key is not a non-empty tuple, has a ``None``
            sub-key before its last element, or has a different depth than
            keys already stored along the same path. With ``multiple`` keys,
            keys processed before the failing one stay inserted.
        """
        keys = self.key_callable(value)
        if keys is None:
            return

        if not self.multiple:
            keys = [keys]

        for tpl in keys:
            if not isinstance(tpl, tuple):
                raise KeyError("The key must be a tuple, index='{}', key='{}'".format(self.name, repr(tpl)))
            if not tpl:
                raise KeyError("Got empty key tuple: index='{}', key='{}'".format(self.name, repr(tpl)))

            *path, leaf_key = tpl
            # Validate before touching the tree so no empty branches are left behind.
            if any(sub_key is None for sub_key in path):
                raise KeyError("Got None sub-key: index='{}', key='{}'".format(self.name, repr(tpl)))

            parent_idx = self.idx
            for sub_key in path:
                idx = parent_idx.get(sub_key, None)
                if idx is None:
                    idx = {}
                    parent_idx[sub_key] = idx
                elif not isinstance(idx, dict):
                    # A shorter key already ends here with a list of values.
                    raise KeyError("Key is deeper than existing entries: index='{}', key='{}'".format(
                        self.name, repr(tpl)))
                parent_idx = idx

            values = parent_idx.get(leaf_key, None)
            if values is None:
                values = []
                parent_idx[leaf_key] = values
            elif not isinstance(values, list):
                # A longer key already continues from here with a nested dict.
                raise KeyError("Key is shallower than existing entries: index='{}', key='{}'".format(
                    self.name, repr(tpl)))
            values.append(value)

    def clear(self):
        """Remove every entry from the index."""
        return self.idx.clear()

    def __iter__(self):
        """Iterate over the first-level keys."""
        return self.idx.__iter__()

    def items(self):
        """Return a live view of first-level ``(key, sub_tree)`` pairs."""
        return self.idx.items()

    def values(self):
        """Return a live view of the first-level sub-trees."""
        return self.idx.values()


class ObjDb(Generic[V]):
    """In-memory list of objects with any number of named secondary indexes.

    An index created through ``add_index`` / ``add_unique_index`` /
    ``add_multi_index`` is back-filled with the objects already stored and is
    then updated on every ``add``.
    """

    def __init__(self):
        self.values: List[V] = []  # every stored object, in insertion order
        self._indexes: MutableMapping[str, Index[Any, V]] = {}  # index name -> index

    def add(self, value: V):
        """Feed ``value`` to every registered index, then store it.

        If an index raises, ``value`` is not appended to ``values``, but the
        indexes updated before the failing one keep it.
        """
        for index in self._indexes.values():
            index.add(value)
        self.values.append(value)

    def __iter__(self):
        return iter(self.values)

    def __len__(self) -> int:
        stored = self.values
        return len(stored)

    def add_index(self, name, key_callable: KeyCallable, **kwargs) -> ListIndex[Any, V]:
        """Create, back-fill and register a ``ListIndex`` called ``name``."""
        return self._add(name, ListIndex(name, key_callable, **kwargs))

    def add_unique_index(self, name, key_callable: KeyCallable, **kwargs) -> UniqueIndex[Any, V]:
        """Create, back-fill and register a ``UniqueIndex`` called ``name``."""
        return self._add(name, UniqueIndex(name, key_callable, **kwargs))

    def add_multi_index(self, name, key_callable: KeyCallable, **kwargs) -> MultiIndex[Any, V]:
        """Create, back-fill and register a ``MultiIndex`` called ``name``."""
        return self._add(name, MultiIndex(name, key_callable, **kwargs))

    def index(self, name) -> Index:
        """Return the index registered as ``name`` (``KeyError`` if unknown)."""
        return self._indexes[name]

    def _add(self, name, idx):
        """Back-fill ``idx`` with the stored objects and register it as ``name``.

        The index is registered only after back-filling succeeds.

        :raises KeyError: if an index called ``name`` already exists.
        """
        if name in self._indexes:
            raise KeyError("Index already exist: {}".format(name))

        for existing in self.values:
            idx.add(existing)

        self._indexes[name] = idx
        return idx