aboutsummaryrefslogtreecommitdiff
path: root/src/ee/ds/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/ee/ds/__init__.py')
-rw-r--r--src/ee/ds/__init__.py363
1 files changed, 363 insertions, 0 deletions
diff --git a/src/ee/ds/__init__.py b/src/ee/ds/__init__.py
new file mode 100644
index 0000000..c8c21dd
--- /dev/null
+++ b/src/ee/ds/__init__.py
@@ -0,0 +1,363 @@
+import configparser
+import logging
+import os
+import csv
+from functools import total_ordering
+from pathlib import Path
+from typing import MutableMapping, Optional, List, Tuple, Union, Iterator
+
+logger = logging.getLogger(__name__)
+
+
+@total_ordering
+class ObjectType(object):
+ def __init__(self, name: str):
+ self._name = name
+ self._fields = []
+ self._objects = {}
+
+ def __eq__(self, o) -> bool:
+ other = o # type: ObjectType
+ return isinstance(o, ObjectType) and self._name == other._name
+
+ def __lt__(self, o: object) -> bool:
+ if not isinstance(o, ObjectType):
+ return True
+
+ other = o # type: ObjectType
+ return self._name < other._name
+
+ def __hash__(self) -> int:
+ return self._name.__hash__()
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def fields(self):
+ return self._fields
+
+ def index_of(self, field: str, create: bool = False) -> Optional[int]:
+ try:
+ return self._fields.index(field)
+ except ValueError as e:
+ if not create:
+ return None
+
+ self._fields.append(field)
+ return len(self._fields) - 1
+
+
+class Object(object):
+ def __init__(self, ds: "DataSet", ot: ObjectType, key: str):
+ self._ds = ds
+ self._ot = ot
+ self._key = key
+ self._data = []
+
+ @property
+ def object_type(self):
+ return self._ot
+
+ @property
+ def key(self):
+ return self._key
+
+ def set(self, key: str, value: str):
+ if self._ds._frozen:
+ raise Exception("This data set is frozen")
+ idx = self._ot.index_of(key, create=True)
+ self._data.insert(idx, value)
+
+ def _set_from_object(self, other: "Object"):
+ for k in other._ot.fields:
+ self.set(k, other.get(k))
+
+ def get(self, key: str) -> Optional[str]:
+ idx = self._ot.index_of(key)
+ return self._data[idx] if idx is not None and idx < len(self._data) else None
+
+ def get_all(self, *keys: str) -> Optional[List[str]]:
+ values = []
+ for key in keys:
+ idx = self._ot.index_of(key)
+ if not idx or idx >= len(self._data):
+ return None
+ values.append(self._data[idx])
+ return values
+
+
+class DataSet(object):
+ def __init__(self, name):
+ self._name = name
+ self._object_types = {} # type: MutableMapping[str, ObjectType]
+ self._objects_by_type = {} # type: MutableMapping[ObjectType, MutableMapping[str, Object]]
+ self._frozen = False
+
+ @property
+ def name(self):
+ return self._name
+
+ def __len__(self):
+ return sum((len(objects) for objects in self._objects_by_type.values()))
+
+ def freeze(self):
+ self._frozen = True
+
+ def _assert_not_frozen(self):
+ if self._frozen:
+ raise Exception("This data set is frozen")
+
+ def _check_object_type(self, object_type: str, create: bool) -> \
+ Optional[Tuple[ObjectType, MutableMapping[str, Object]]]:
+ try:
+ ot = self._object_types[object_type]
+ objects = self._objects_by_type[ot]
+ return ot, objects,
+ except KeyError:
+ if not create:
+ return None
+
+ self._assert_not_frozen()
+
+ ot = ObjectType(object_type)
+ self._object_types[object_type] = ot
+ self._objects_by_type[ot] = objects = {}
+ return ot, objects,
+
+ def _check_object(self, object_type: str, key: str, create: bool) -> Optional[Object]:
+ t = self._check_object_type(object_type, create)
+
+ if not t:
+ return None
+
+ ot, objects = t
+ try:
+ return objects[key]
+ except KeyError:
+ self._assert_not_frozen()
+
+ if not create:
+ raise Exception("No such type: {}:{}".format(object_type, key))
+
+ o = Object(self, ot, key)
+ objects[key] = o
+ return o
+
+ def get_object_type(self, object_type: str) -> ObjectType:
+ ot, objects = self._check_object_type(object_type, False)
+
+ if not ot:
+ raise Exception("No such object type: {}".format(object_type))
+
+ return ot
+
+ def get_object(self, object_type: str, key: str) -> Object:
+ o = self._check_object(object_type, key, False)
+
+ if not o:
+ raise Exception("No such object: {}:{}".format(object_type, key))
+
+ return o
+
+ def has_object(self, object_type: str, key: str) -> bool:
+ t = self._check_object_type(object_type, False)
+
+ if t:
+ ot, objects = t
+ return key in objects
+
+ return False
+
+ def get_or_create_object(self, object_type: str, key: str) -> Object:
+ return self._check_object(object_type, key, True)
+
+ def create_object(self, object_type: str, key: str) -> Object:
+ self._assert_not_frozen()
+
+ if self.has_object(object_type, key):
+ raise Exception("Object already exist: {}:{}".format(object_type, key))
+
+ return self._check_object(object_type, key, True)
+
+ def items(self) -> Iterator[Object]:
+ for objects in self._objects_by_type.values():
+ for o in objects.values():
+ yield o
+
+ def merge(self, other: "DataSet") -> "DataSet":
+ ds = DataSet(self._name)
+ for objects in self._objects_by_type.values():
+ for o in objects.values():
+ ds.create_object(o.object_type.name, o.key)._set_from_object(o)
+
+ for objects in other._objects_by_type.values():
+ for o in objects.values():
+ ds.get_or_create_object(o.object_type.name, o.key)._set_from_object(o)
+
+ return ds
+
+
+class DataSetManager(object):
+ def __init__(self, basedir: Union[Path, str]):
+ self._basedir = Path(basedir)
+ self._csv = {} # type: MutableMapping[str, Tuple[str, Path]]
+
+ def cookie_for_ds(self, ds_name) -> Path:
+ try:
+ return self._csv[ds_name][1]
+ except KeyError:
+ return self._basedir / ds_name / "data-set.ini"
+
+ def create_rw(self, name, inputs: List[str] = None) -> "LazyDataSet":
+ return LazyDataSet(self, False, name, inputs if inputs else [])
+
+ def create_ro(self, inputs: List[str]) -> "LazyDataSet":
+ return LazyDataSet(self, True, None, inputs)
+
+ def add_ds(self, ds_type: str, name: str, object_type: str, path: str = None):
+ if ds_type == "csv":
+ if name in self._csv:
+ raise Exception("Data source already exists: {}".format(name))
+
+ self._csv[name] = object_type, Path(path),
+ else:
+ raise Exception("Unknown data source type: {}".format(ds_type))
+
+ def ds_type(self, name: str):
+ return "csv" if name in self._csv else "ini-dir"
+
+ def load(self, name, freeze=False) -> DataSet:
+ try:
+ object_type, path = self._csv[name]
+
+ if not freeze:
+ raise Exception("CSV data sources must be frozen")
+
+ return DataSetManager._load_csv(name, object_type, path, freeze)
+ except KeyError:
+ return self._load_ini_dir(name, freeze)
+
+ @staticmethod
+ def _load_csv(name: str, object_type: str, path: Path, freeze: bool) -> DataSet:
+ logger.info("Loading CSV file {}".format(path))
+ ds = DataSet(name)
+
+ with open(str(path), newline='') as f:
+ r = csv.reader(f)
+
+ header = next(r, None)
+ for row in r:
+ if len(row) == 0:
+ continue
+
+ key = row[0]
+
+ o = ds.create_object(object_type, key)
+ for idx, value in zip(range(0, min(len(row), len(header))), row):
+ o.set(header[idx], value)
+
+ logger.debug("Loaded {} objects".format(len(ds)))
+ return ds
+
+ def _load_ini_dir(self, name: str, freeze: bool) -> DataSet:
+ ds_dir = Path(name) if Path(name).is_absolute() else self._basedir / name
+ ds_dir = ds_dir if ds_dir.is_dir() else ds_dir.parent
+
+ logger.info("Loading DS from '{}'".format(ds_dir))
+
+ ini = self._load_ini(ds_dir / "data-set.ini")
+ name = ini.get("data-set", "name")
+
+ ds = DataSet(name)
+ count = 0
+ for ot_path in ds_dir.glob("*"):
+ if not ot_path.is_dir():
+ continue
+
+ ot = ot_path.name
+ logger.info(" Loading type '{}'".format(ot))
+ for o_path in ot_path.glob("*.ini"):
+ count += 1
+
+ key = o_path.name[:-4]
+ logger.info(" Loading key '{}'".format(key))
+ ini = self._load_ini(o_path)
+ o = ds.create_object(ot, key)
+ for k, v in ini.items("values"):
+ o.set(k, v)
+
+ if freeze:
+ ds.freeze()
+
+ logger.info("Loaded {} items".format(count))
+ return ds
+
+ def store(self, ds: DataSet):
+ ds_dir = self._basedir / ds.name
+ items = list(ds.items())
+ logger.info("Storing DS '{}' with {} objects to {}".format(ds.name, len(items), ds_dir))
+
+ os.makedirs(ds_dir, exist_ok=True)
+ ini = self._blank_ini()
+ ini.add_section("data-set")
+ ini.set("data-set", "name", ds.name)
+ self._store_ini(ini, ds_dir / "data-set.ini")
+
+ for o in items:
+ ot = o.object_type
+ key = o.key
+
+ ot_dir = ds_dir / ot.name
+ os.makedirs(ot_dir, exist_ok=True)
+ ini = self._blank_ini()
+ ini.add_section("meta")
+ ini.set("meta", "type", ot.name)
+
+ ini.add_section("values")
+ for k in ot.fields:
+ v = o.get(k)
+ if v:
+ ini.set("values", k, str(v))
+ self._store_ini(ini, ot_dir / "{}.ini".format(key))
+
+ @staticmethod
+ def _blank_ini():
+ return configparser.ConfigParser(interpolation=None)
+
+ def _load_ini(self, path: Path):
+ ini = self._blank_ini()
+ if len(ini.read(str(path))) != 1:
+ raise IOError("Could not load ini file: {}".format(path))
+ return ini
+
+ @staticmethod
+ def _store_ini(ini, path):
+ with open(path, "w") as f:
+ ini.write(f)
+
+
+class LazyDataSet(object):
+ def __init__(self, dsm: DataSetManager, freeze: bool, name, inputs):
+ self._dsm = dsm
+ self._freeze = freeze
+ self._name = name
+ self._inputs = inputs
+
+ def __enter__(self) -> DataSet:
+ # logger.info("enter: name={}, inputs={}".format(self._name, self._inputs))
+ ds = DataSet(self._name)
+ for name in self._inputs:
+ ds = ds.merge(self._dsm.load(name, freeze=True))
+
+ if self._freeze:
+ ds.freeze()
+
+ self._ds = ds
+ return ds
+
+ def __exit__(self, *args):
+ if not self._freeze:
+ self._dsm.store(self._ds)
+ return False