Diffstat (limited to 'src/ee/fact/__init__.py')
-rw-r--r--  src/ee/fact/__init__.py  88
1 file changed, 76 insertions(+), 12 deletions(-)
diff --git a/src/ee/fact/__init__.py b/src/ee/fact/__init__.py
index 02e6b2a..c8c21dd 100644
--- a/src/ee/fact/__init__.py
+++ b/src/ee/fact/__init__.py
@@ -1,9 +1,10 @@
 import configparser
 import logging
 import os
+import csv
 from functools import total_ordering
 from pathlib import Path
-from typing import MutableMapping, Optional, List, Tuple
+from typing import MutableMapping, Optional, List, Tuple, Union, Iterator
 
 logger = logging.getLogger(__name__)
@@ -37,12 +38,12 @@ class ObjectType(object):
     def fields(self):
        return self._fields
 
-    def index_of(self, field: str, create: bool = False) -> int:
+    def index_of(self, field: str, create: bool = False) -> Optional[int]:
         try:
             return self._fields.index(field)
         except ValueError as e:
             if not create:
-                raise e
+                return None
             self._fields.append(field)
             return len(self._fields) - 1
@@ -75,7 +76,16 @@ class Object(object):
 
     def get(self, key: str) -> Optional[str]:
         idx = self._ot.index_of(key)
-        return self._data[idx] if idx < len(self._data) else None
+        return self._data[idx] if idx is not None and idx < len(self._data) else None
+
+    def get_all(self, *keys: str) -> Optional[List[str]]:
+        values = []
+        for key in keys:
+            idx = self._ot.index_of(key)
+            if idx is None or idx >= len(self._data):
+                return None
+            values.append(self._data[idx])
+        return values
 
 
 class DataSet(object):
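With this hunk, index_of() returns None for unknown fields instead of raising, get() tolerates that, and get_all() is all-or-nothing: it returns a value list only when every requested field is present (note the `idx is None` check, so a field at index 0 still counts as present). A minimal sketch of the new contract; the ObjectType/Object values here are illustrative, not from the repository:

    # Assumes ot is an ObjectType with fields ["key", "name"] and
    # o is an Object on ot holding data ["a1", "Article 1"].
    assert ot.index_of("name") == 1
    assert ot.index_of("missing") is None        # no longer raises ValueError
    assert o.get("missing") is None
    assert o.get_all("key", "name") == ["a1", "Article 1"]
    assert o.get_all("key", "missing") is None   # all-or-nothing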
@@ -89,6 +99,9 @@ class DataSet(object):
     def name(self):
         return self._name
 
+    def __len__(self):
+        return sum(len(objects) for objects in self._objects_by_type.values())
+
     def freeze(self):
         self._frozen = True
@@ -168,9 +181,10 @@ class DataSet(object):
         return self._check_object(object_type, key, True)
 
-    def items(self):
-        from itertools import chain
-        return list(chain.from_iterable([objects.values() for objects in self._objects_by_type.values()]))
+    def items(self) -> Iterator[Object]:
+        for objects in self._objects_by_type.values():
+            for o in objects.values():
+                yield o
 
     def merge(self, other: "DataSet") -> "DataSet":
         ds = DataSet(self._name)
@@ -186,11 +200,15 @@ class DataSet(object):
 
 class DataSetManager(object):
-    def __init__(self, basedir: Path):
+    def __init__(self, basedir: Union[Path, str]):
         self._basedir = Path(basedir)
+        self._csv = {}  # type: MutableMapping[str, Tuple[str, Path]]
 
-    def metafile_for_ds(self, ds_name) -> Path:
-        return self._basedir / ds_name / "data-set.ini"
+    def cookie_for_ds(self, ds_name) -> Path:
+        try:
+            return self._csv[ds_name][1]
+        except KeyError:
+            return self._basedir / ds_name / "data-set.ini"
 
     def create_rw(self, name, inputs: List[str] = None) -> "LazyDataSet":
         return LazyDataSet(self, False, name, inputs if inputs else [])
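The rename from metafile_for_ds() to cookie_for_ds() matches the new behaviour: for a CSV source (registered with add_ds() in the next hunk) the returned path is the CSV file itself, while ini-dir sources keep pointing at their data-set.ini, presumably so callers can use the path for staleness checks either way. A sketch under those assumptions, with made-up paths:

    dsm = DataSetManager("/tmp/datasets")
    dsm.add_ds("csv", "parts", "part", path="parts.csv")
    print(dsm.cookie_for_ds("parts"))   # parts.csv, the file itself
    print(dsm.cookie_for_ds("other"))   # /tmp/datasets/other/data-set.ini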
@@ -198,7 +216,52 @@ class DataSetManager(object):
     def create_ro(self, inputs: List[str]) -> "LazyDataSet":
         return LazyDataSet(self, True, None, inputs)
 
+    def add_ds(self, ds_type: str, name: str, object_type: str, path: str = None):
+        if ds_type == "csv":
+            if name in self._csv:
+                raise Exception("Data source already exists: {}".format(name))
+
+            self._csv[name] = (object_type, Path(path))
+        else:
+            raise Exception("Unknown data source type: {}".format(ds_type))
+
+    def ds_type(self, name: str) -> str:
+        return "csv" if name in self._csv else "ini-dir"
+
     def load(self, name, freeze=False) -> DataSet:
+        try:
+            object_type, path = self._csv[name]
+        except KeyError:
+            return self._load_ini_dir(name, freeze)
+
+        if not freeze:
+            raise Exception("CSV data sources must be frozen")
+
+        return DataSetManager._load_csv(name, object_type, path, freeze)
+
+    @staticmethod
+    def _load_csv(name: str, object_type: str, path: Path, freeze: bool) -> DataSet:
+        logger.info("Loading CSV file {}".format(path))
+        ds = DataSet(name)
+
+        with open(str(path), newline='') as f:
+            r = csv.reader(f)
+
+            header = next(r, None)
+            if header is None:
+                return ds
+
+            for row in r:
+                if len(row) == 0:
+                    continue
+
+                key = row[0]
+
+                o = ds.create_object(object_type, key)
+                for idx, value in enumerate(row[:len(header)]):
+                    o.set(header[idx], value)
+
+        logger.debug("Loaded {} objects".format(len(ds)))
+        return ds
+
+    def _load_ini_dir(self, name: str, freeze: bool) -> DataSet:
         ds_dir = Path(name) if Path(name).is_absolute() else self._basedir / name
         ds_dir = ds_dir if ds_dir.is_dir() else ds_dir.parent
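Taken together, add_ds(), ds_type() and the split load path let a flat CSV file stand in for an ini-dir data set: the first column of each row becomes the object key and the remaining columns are set from the header. A hedged usage sketch; basedir and file names are made up:

    dsm = DataSetManager("/tmp/datasets")
    dsm.add_ds("csv", "parts", "part", path="parts.csv")
    assert dsm.ds_type("parts") == "csv"

    ds = dsm.load("parts", freeze=True)   # freeze=False raises for CSV sources
    print(len(ds))                        # one object per non-empty row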
@@ -233,7 +296,8 @@ class DataSetManager(object):
     def store(self, ds: DataSet):
         ds_dir = self._basedir / ds.name
 
-        logger.info("Storing DS '{}' with {} objects to {}".format(ds.name, len(ds.items()), ds_dir))
+        items = list(ds.items())
+        logger.info("Storing DS '{}' with {} objects to {}".format(ds.name, len(items), ds_dir))
 
         os.makedirs(ds_dir, exist_ok=True)
         ini = self._blank_ini()
@@ -241,7 +305,7 @@
         ini.set("data-set", "name", ds.name)
         self._store_ini(ini, ds_dir / "data-set.ini")
 
-        for o in ds.items():
+        for o in items:
             ot = o.object_type
             key = o.key
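Because items() is now a generator rather than a list, anything that needs a count or a second pass has to materialize it first, which is exactly what the store() change above does. For example:

    items = list(ds.items())   # a generator has no len(); snapshot it once
    print(len(items))
    for o in items:            # the list, unlike the generator, can be re-iterated
        print(o.key)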