about summary refs log tree commit diff
path: root/src/ee/ds/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/ee/ds/__init__.py')
-rw-r--r-- src/ee/ds/__init__.py 138
1 files changed, 80 insertions, 58 deletions
diff --git a/src/ee/ds/__init__.py b/src/ee/ds/__init__.py
index 5899d28..030113b 100644
--- a/src/ee/ds/__init__.py
+++ b/src/ee/ds/__init__.py
@@ -5,7 +5,7 @@ import os
import shutil
from functools import total_ordering
from pathlib import Path
-from typing import MutableMapping, Optional, List, Tuple, Union, Iterator
+from typing import MutableMapping, Optional, List, Tuple, Union, Iterator, Iterable
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class ObjectType(object):
def __init__(self, name: str):
self._name = name
- self._fields = []
+ self._fields = [] # type: List[str]
self._objects = {}
def __eq__(self, o) -> bool:
@@ -42,7 +42,7 @@ class ObjectType(object):
def index_of(self, field: str, create: bool = False) -> Optional[int]:
try:
return self._fields.index(field)
- except ValueError as e:
+ except ValueError:
if not create:
return None
@@ -81,6 +81,13 @@ class Object(object):
idx = self._ot.index_of(key)
return self._data[idx] if idx is not None and idx < len(self._data) else None
+ def get_req(self, key: str) -> str:
+ idx = self._ot.index_of(key)
+ if idx is not None and idx < len(self._data):
+ return self._data[idx]
+ else:
+ raise Exception("No such field: {}".format(key))
+
def get_all(self, *keys: str) -> Optional[List[str]]:
values = []
for key in keys:
@@ -92,18 +99,11 @@ class Object(object):
class DataSet(object):
- def __init__(self, name: Optional[str] = None):
- self._name = name
+ def __init__(self):
self._object_types = {} # type: MutableMapping[str, ObjectType]
self._objects_by_type = {} # type: MutableMapping[ObjectType, MutableMapping[str, Object]]
self._frozen = False
- @property
- def name(self):
- if not self._name:
- raise Exception("Unnamed data set")
- return self._name
-
def __len__(self):
return sum((len(objects) for objects in self._objects_by_type.values()))
@@ -151,11 +151,12 @@ class DataSet(object):
return o
def get_object_type(self, object_type: str) -> ObjectType:
- ot, objects = self._check_object_type(object_type, False)
+ t = self._check_object_type(object_type, False)
- if not ot:
+ if not t:
raise Exception("No such object type: {}".format(object_type))
+ ot, objects = t
return ot
def get_object(self, object_type: str, key: str) -> Object:
@@ -192,7 +193,7 @@ class DataSet(object):
yield o
def merge(self, other: "DataSet") -> "DataSet":
- ds = DataSet(self._name)
+ ds = DataSet()
for objects in self._objects_by_type.values():
for o in objects.values():
ds.create_object(o.object_type.name, o.key)._set_from_object(o)
@@ -203,6 +204,14 @@ class DataSet(object):
return ds
+ def import_object(self, other: Object) -> Object:
+ o = self._check_object(other.object_type.name, other.key, create=True)
+
+ for k in other.object_type.fields:
+ o.set(k, other.get(k))
+
+ return o
+
class DataSetManager(object):
def __init__(self, basedir: Union[Path, str]):
@@ -218,8 +227,15 @@ class DataSetManager(object):
def create_rw(self, name, clean: bool) -> "LazyRwDataSet":
return LazyRwDataSet(self, name, clean)
- def create_ro(self, inputs: List[str]) -> "LazyRoDataSet":
- return LazyRoDataSet(self, inputs)
+ def load_data_sets(self, inputs: List[str], freeze: bool = True) -> DataSet:
+ ds = DataSet()
+ for name in inputs:
+ ds = ds.merge(self.load(name, freeze=True))
+
+ if freeze:
+ ds.freeze()
+
+ return ds
def add_ds(self, ds_type: str, name: str, object_type: str, path: str = None):
if ds_type == "csv":
@@ -233,21 +249,21 @@ class DataSetManager(object):
def ds_type(self, name: str):
return "csv" if name in self._csv else "ini-dir"
- def load(self, name, freeze=False) -> DataSet:
+ def load(self, path, freeze=False) -> DataSet:
try:
- object_type, path = self._csv[name]
+ object_type, path = self._csv[path]
if not freeze:
raise Exception("CSV data sources must be frozen")
- return DataSetManager._load_csv(name, object_type, path, freeze)
+ return DataSetManager._load_csv(object_type, path, freeze)
except KeyError:
- return self._load_ini_dir(name, freeze)
+ return self._load_ini_dir(path, freeze)
@staticmethod
- def _load_csv(name: str, object_type: str, path: Path, freeze: bool) -> DataSet:
- logger.info("Loading CSV file {}".format(path))
- ds = DataSet(name)
+ def _load_csv(object_type: str, path: Path, freeze: bool) -> DataSet:
+ logger.debug("Loading CSV file {}".format(path))
+ ds = DataSet()
with open(str(path), newline='') as f:
r = csv.reader(f)
@@ -263,31 +279,33 @@ class DataSetManager(object):
for idx, value in zip(range(0, min(len(row), len(header))), row):
o.set(header[idx], value)
+ if freeze:
+ ds.freeze()
+
logger.debug("Loaded {} objects".format(len(ds)))
return ds
- def _load_ini_dir(self, name: str, freeze: bool) -> DataSet:
- ds_dir = Path(name) if Path(name).is_absolute() else self._basedir / name
+ def _load_ini_dir(self, _path: str, freeze: bool) -> DataSet:
+ ds_dir = Path(_path) if Path(_path).is_absolute() else self._basedir / _path
ds_dir = ds_dir if ds_dir.is_dir() else ds_dir.parent
- logger.info("Loading DS from '{}'".format(ds_dir))
+ logger.debug("Loading DS from '{}'".format(ds_dir))
- ini = self._load_ini(ds_dir / "data-set.ini")
- name = ini.get("data-set", "name")
+ self._load_ini(ds_dir / "data-set.ini")
- ds = DataSet(name)
+ ds = DataSet()
count = 0
for ot_path in ds_dir.glob("*"):
if not ot_path.is_dir():
continue
ot = ot_path.name
- logger.info(" Loading type '{}'".format(ot))
+ logger.debug(" Loading type '{}'".format(ot))
for o_path in ot_path.glob("*.ini"):
count += 1
key = o_path.name[:-4]
- logger.info(" Loading key '{}'".format(key))
+ logger.debug(" Loading key '{}'".format(key))
ini = self._load_ini(o_path)
o = ds.create_object(ot, key)
for k, v in ini.items("values"):
@@ -296,18 +314,18 @@ class DataSetManager(object):
if freeze:
ds.freeze()
- logger.info("Loaded {} items".format(count))
+ logger.debug("Loaded {} items".format(count))
return ds
- def store(self, ds: DataSet):
- ds_dir = self._basedir / ds.name
+ def store(self, ds: DataSet, ds_name: str):
+ ds_dir = self._basedir / ds_name
items = list(ds.items())
- logger.info("Storing DS '{}' with {} objects to {}".format(ds.name, len(items), ds_dir))
+ logger.info("Storing DS '{}' with {} objects to {}".format(ds_name, len(items), ds_dir))
os.makedirs(ds_dir, exist_ok=True)
ini = self._blank_ini()
ini.add_section("data-set")
- ini.set("data-set", "name", ds.name)
+ ini.set("data-set", "name", ds_name)
self._store_ini(ini, ds_dir / "data-set.ini")
for o in items:
@@ -327,6 +345,30 @@ class DataSetManager(object):
ini.set("values", k, str(v))
self._store_ini(ini, ot_dir / "{}.ini".format(key))
+ # noinspection PyMethodMayBeStatic
+ def store_csv(self, path: Union[str, Path], ds: DataSet, object_type: str,
+ order_by: Union[str, Iterable[str]] = None):
+ items = [o for o in ds.items() if o.object_type.name == object_type]
+
+ if order_by:
+ if isinstance(order_by, str):
+ items = sorted(items, key=lambda o: o.get_req(order_by))
+ elif isinstance(order_by, Iterable):
+ items = sorted(items, key=lambda o: [o.get_req(ob) for ob in order_by])
+ else:
+ raise Exception("Unsupported order_by")
+
+ with open(path, "w") as f:
+ w = csv.writer(f)
+
+ if len(items):
+ header = ds.get_object_type(object_type).fields
+ w.writerow(header)
+
+ for o in items:
+ row = [o.get(k) for k in header]
+ w.writerow(row)
+
@staticmethod
def _blank_ini():
return configparser.ConfigParser(interpolation=None)
@@ -350,26 +392,6 @@ class DataSetManager(object):
shutil.rmtree(self._basedir / name)
-class LazyRoDataSet(object):
- def __init__(self, dsm: DataSetManager, inputs):
- self._dsm = dsm
- self._inputs = inputs
-
- def __enter__(self) -> DataSet:
- # logger.info("enter: name={}, inputs={}".format(self._name, self._inputs))
- ds = DataSet()
- for name in self._inputs:
- ds = ds.merge(self._dsm.load(name, freeze=True))
-
- ds.freeze()
-
- self._ds = ds
- return ds
-
- def __exit__(self, *args):
- return False
-
-
class LazyRwDataSet(object):
def __init__(self, dsm: DataSetManager, name, clean):
self._dsm = dsm
@@ -384,10 +406,10 @@ class LazyRwDataSet(object):
raise IOError("DataSet already exists: {}, cookie={}".format(self._name, cookie))
self._dsm.remove(self._name)
- ds = DataSet(self._name)
+ ds = DataSet()
self._ds = ds
return ds
def __exit__(self, *args):
- self._dsm.store(self._ds)
+ self._dsm.store(self._ds, self._name)
return False