diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_db.py | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/test/test_db.py b/test/test_db.py new file mode 100644 index 0000000..7fc3cae --- /dev/null +++ b/test/test_db.py @@ -0,0 +1,82 @@ +from typing import Optional + +import pytest + +from ee.db import * + + +class MyPartNumber(object): + def __init__(self, number, supplier=None): + self.supplier = supplier + self.number = number + + def __repr__(self): + if self.supplier: + return "(supplier={}, number={})".format(self.supplier, self.number) + else: + return "(number={})".format(self.number) + + +class MyPart(object): + def __init__(self, mpn: MyPartNumber, spn: Optional[MyPartNumber] = None): + self.mpn = mpn + self.spn = spn + + def __repr__(self): + mpn = self.mpn if self.mpn else "" + spn = self.spn if self.spn else "" + return "(Part: mpn={}, spn={})".format(mpn, spn) + + +suppliers = [("supplier-a", "SA-"), ("supplier-b", "SB-"), ("supplier-c", None)] +mpns = ["RES" + value for value in ["100", "101", "103"]] + + +def test_basic(): + db = ObjDb() + build_db(db) + + assert len(db) == 11 + + +def build_db(db: ObjDb): + db.add(MyPart(MyPartNumber("RES100"))) + db.add(MyPart(MyPartNumber("RES101"))) + db.add(MyPart(MyPartNumber("RES102"))) + db.add(MyPart(MyPartNumber("RES103"))) + db.add(MyPart(MyPartNumber("RES104"))) + + db.add(MyPart(MyPartNumber("RES100"), MyPartNumber("SA-RES100", "supplier-a"))) + db.add(MyPart(MyPartNumber("RES101"), MyPartNumber("SA-RES101", "supplier-a"))) + db.add(MyPart(MyPartNumber("RES102"), MyPartNumber("SA-RES102", "supplier-a"))) + db.add(MyPart(MyPartNumber("RES103"), MyPartNumber("SA-RES103", "supplier-a"))) + + db.add(MyPart(MyPartNumber("RES101"), MyPartNumber("SB-RES101", "supplier-b"))) + db.add(MyPart(MyPartNumber("RES102"), MyPartNumber("SB-RES102", "supplier-b"))) + + +def test_index(): + db = ObjDb() + mpn_idx = db.add_index("mpn", lambda part: part.mpn.number if part.mpn else None) + spn_idx = db.add_unique_index("spn", lambda part: part.spn.number if part.spn else None) + supplier_pn = db.add_multi_index("supplier_pn", + lambda part: (part.spn.supplier, part.spn.number) if part.spn else None) + build_db(db) + + assert len(mpn_idx.values()) == 5 + tmp = mpn_idx.get("RES100") + assert len(mpn_idx.get("RES100")) == 2 + + # for supplier, spn in spn_idx.items(): + # print("SPN: {}={}".format(supplier, spn)) + assert len(spn_idx.items()) == (len(suppliers) - 1) * len(mpns) + + assert len((supplier_pn.items())) == 2 + assert "supplier-a" in supplier_pn + assert "SA-RES103" in supplier_pn.get("supplier-a") + assert "supplier-b" in supplier_pn + + +@pytest.mark.skipif('False') +def test_skip(): + pass |