aboutsummaryrefslogtreecommitdiff
path: root/src/ee/fact/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/ee/fact/__init__.py')
-rw-r--r--src/ee/fact/__init__.py114
1 files changed, 82 insertions, 32 deletions
diff --git a/src/ee/fact/__init__.py b/src/ee/fact/__init__.py
index 1bdef0d..02e6b2a 100644
--- a/src/ee/fact/__init__.py
+++ b/src/ee/fact/__init__.py
@@ -3,7 +3,7 @@ import logging
import os
from functools import total_ordering
from pathlib import Path
-from typing import MutableMapping, Optional, Mapping, List
+from typing import MutableMapping, Optional, List, Tuple
logger = logging.getLogger(__name__)
@@ -81,10 +81,9 @@ class Object(object):
class DataSet(object):
def __init__(self, name):
self._name = name
- self._object_types = {}
- self._objects_by_type = {} # type: MutableMapping[str, Mapping[str, Object]]
+ self._object_types = {} # type: MutableMapping[str, ObjectType]
+ self._objects_by_type = {} # type: MutableMapping[ObjectType, MutableMapping[str, Object]]
self._frozen = False
- self._changed = False
@property
def name(self):
@@ -93,38 +92,82 @@ class DataSet(object):
def freeze(self):
self._frozen = True
- def get_object_type(self, object_type: str) -> ObjectType:
+ def _assert_not_frozen(self):
+ if self._frozen:
+ raise Exception("This data set is frozen")
+
+ def _check_object_type(self, object_type: str, create: bool) -> \
+ Optional[Tuple[ObjectType, MutableMapping[str, Object]]]:
try:
- return self._object_types[object_type]
+ ot = self._object_types[object_type]
+ objects = self._objects_by_type[ot]
+ return ot, objects,
except KeyError:
+ if not create:
+ return None
+
+ self._assert_not_frozen()
+
ot = ObjectType(object_type)
self._object_types[object_type] = ot
- self._changed = True
- return ot
+ self._objects_by_type[ot] = objects = {}
+ return ot, objects,
- def get_object(self, object_type: str, key: str) -> Object:
- try:
- objects = self._objects_by_type[object_type]
- except KeyError:
- if self._frozen:
- raise Exception("This data set is frozen")
+ def _check_object(self, object_type: str, key: str, create: bool) -> Optional[Object]:
+ t = self._check_object_type(object_type, create)
- objects = {}
- self._objects_by_type[object_type] = objects
- self._changed = True
+ if not t:
+ return None
+ ot, objects = t
try:
return objects[key]
except KeyError:
- if self._frozen:
- raise Exception("This data set is frozen")
+ self._assert_not_frozen()
+
+ if not create:
+ raise Exception("No such type: {}:{}".format(object_type, key))
- ot = self.get_object_type(object_type)
o = Object(self, ot, key)
objects[key] = o
- self._changed = True
return o
+ def get_object_type(self, object_type: str) -> ObjectType:
+ ot, objects = self._check_object_type(object_type, False)
+
+ if not ot:
+ raise Exception("No such object type: {}".format(object_type))
+
+ return ot
+
+ def get_object(self, object_type: str, key: str) -> Object:
+ o = self._check_object(object_type, key, False)
+
+ if not o:
+ raise Exception("No such object: {}:{}".format(object_type, key))
+
+ return o
+
+ def has_object(self, object_type: str, key: str) -> bool:
+ t = self._check_object_type(object_type, False)
+
+ if t:
+ ot, objects = t
+ return key in objects
+
+ return False
+
+ def get_or_create_object(self, object_type: str, key: str) -> Object:
+ return self._check_object(object_type, key, True)
+
+ def create_object(self, object_type: str, key: str) -> Object:
+ self._assert_not_frozen()
+
+ if self.has_object(object_type, key):
+ raise Exception("Object already exist: {}:{}".format(object_type, key))
+
+ return self._check_object(object_type, key, True)
+
def items(self):
from itertools import chain
return list(chain.from_iterable([objects.values() for objects in self._objects_by_type.values()]))
@@ -133,11 +176,11 @@ class DataSet(object):
ds = DataSet(self._name)
for objects in self._objects_by_type.values():
for o in objects.values():
- ds.get_object(o.object_type.name, o.key)._set_from_object(o)
+ ds.create_object(o.object_type.name, o.key)._set_from_object(o)
for objects in other._objects_by_type.values():
for o in objects.values():
- ds.get_object(o.object_type.name, o.key)._set_from_object(o)
+ ds.get_or_create_object(o.object_type.name, o.key)._set_from_object(o)
return ds
@@ -149,8 +192,11 @@ class DataSetManager(object):
def metafile_for_ds(self, ds_name) -> Path:
return self._basedir / ds_name / "data-set.ini"
- def create_rw(self, name, inputs: List[str]) -> "LazyDataSet":
- return LazyDataSet(self, name, inputs)
+ def create_rw(self, name, inputs: List[str] = None) -> "LazyDataSet":
+ return LazyDataSet(self, False, name, inputs if inputs else [])
+
+ def create_ro(self, inputs: List[str]) -> "LazyDataSet":
+ return LazyDataSet(self, True, None, inputs)
def load(self, name, freeze=False) -> DataSet:
ds_dir = Path(name) if Path(name).is_absolute() else self._basedir / name
@@ -175,7 +221,7 @@ class DataSetManager(object):
key = o_path.name[:-4]
logger.info(" Loading key '{}'".format(key))
ini = self._load_ini(o_path)
- o = ds.get_object(ot, key)
+ o = ds.create_object(ot, key)
for k, v in ini.items("values"):
o.set(k, v)
@@ -229,21 +275,25 @@ class DataSetManager(object):
class LazyDataSet(object):
- def __init__(self, dsm: DataSetManager, name, inputs):
+ def __init__(self, dsm: DataSetManager, freeze: bool, name, inputs):
self._dsm = dsm
+ self._freeze = freeze
self._name = name
self._inputs = inputs
- def __enter__(self):
+ def __enter__(self) -> DataSet:
# logger.info("enter: name={}, inputs={}".format(self._name, self._inputs))
ds = DataSet(self._name)
for name in self._inputs:
ds = ds.merge(self._dsm.load(name, freeze=True))
+
+ if self._freeze:
+ ds.freeze()
+
self._ds = ds
- return self._ds
+ return ds
def __exit__(self, *args):
- # logger.info("exit: name={}, inputs={}".format(self._name, self._inputs))
- # logger.info("ds.size={}".format(len(self._ds.items())))
- self._dsm.store(self._ds)
+ if not self._freeze:
+ self._dsm.store(self._ds)
return False