"""Part database backed by the generated ``ee.xml.types`` XML bindings.

Provides a convenience wrapper around ``types.Part`` (:class:`Part`), an
in-memory collection with an MPN index (:class:`PartDb`), and helpers to
read and write a database file (:func:`load_db`, :func:`save_db`).
"""
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, MutableMapping, Optional, TypeVar, Union

from ee import EeException
from ee.money import Money
from ee.xml import types

__all__ = [
    "Part",
    "PartDb",
    "load_db",
    "save_db",
]

_T = TypeVar("_T")


class Part(object):
    """Convenience wrapper around a generated ``types.Part`` element.

    The wrapper edits the underlying element in place: all getters return the
    live lists of ``self.xml`` and all ``add_*`` methods append to them.
    """

    def __init__(self, xml: types.Part):
        # isinstance() rather than an exact type comparison so subclasses of
        # the generated binding are accepted too.
        assert isinstance(xml, types.Part)
        self.xml = xml

        # Replace missing containers with empty ones so the accessors below
        # can read/append without None checks. clean_xml() reverses this.
        if xml.referencesProp is None:
            xml.referencesProp = types.ReferencesList()
        if xml.price_breaksProp is None:
            xml.price_breaksProp = types.PriceBreakList()
        if xml.linksProp is None:
            xml.linksProp = types.LinkList()
        if xml.factsProp is None:
            xml.factsProp = types.FactList()

    def clean_xml(self):
        """Reset every empty container on the underlying element to None.

        Presumably done so that exporting the element omits empty
        sub-elements (assumption about the generated export code).
        After this call the list accessors of this wrapper raise
        AttributeError until a new ``Part`` is built around the element.
        """
        x = self.xml
        refs = x.referencesProp
        if not (refs.part_referenceProp or
                refs.schematic_referenceProp or
                refs.part_numberProp or
                refs.supplier_part_numberProp):
            x.referencesProp = None

        if not x.price_breaksProp.price_break:
            x.price_breaksProp = None

        if not x.linksProp.link:
            x.linksProp = None

        if not x.factsProp.fact:
            x.factsProp = None

    @property
    def underlying(self) -> types.Part:
        """The wrapped ``types.Part`` element."""
        return self.xml

    @property
    def uri(self) -> str:
        return self.xml.uriProp

    @property
    def supplier(self) -> str:
        return self.xml.supplierProp

    def _exactly_one(self, items: List[_T], plural: str, singular: str,
                     describe: Callable[[_T], str]) -> _T:
        """Return the single element of *items*.

        :param items: the list to check.
        :param plural: noun used in the "none found" message.
        :param singular: noun used in the "more than one" message.
        :param describe: maps an element to the text listed in the
            "more than one" message.
        :raises EeException: if *items* is empty or has more than one element.
        """
        if not items:
            raise EeException("This part does not contain any {}{}".format(
                plural, ", uri=" + self.uri if self.uri else ""))
        if len(items) != 1:
            raise EeException("This part does not contain exactly one {}: {}".format(
                singular, ", ".join([describe(item) for item in items])))
        return items[0]

    # Part number ref

    def add_part_reference(self, uri):
        """Append a part reference pointing at *uri*."""
        self.get_part_references().append(types.PartReference(part_uri=uri))

    def get_part_references(self) -> List[types.PartReference]:
        return self.xml.referencesProp.part_referenceProp

    def get_exactly_one_part_reference(self) -> types.PartReference:
        """Return the only part reference; raise EeException if there are zero or several."""
        return self._exactly_one(self.get_part_references(),
                                 "part references", "part reference",
                                 lambda ref: ref.part_uriProp)

    # Schematic references

    def add_schematic_reference(self, ref):
        """Append a schematic reference with the given reference designator text."""
        self.get_schematic_references().append(types.SchematicReference(reference=ref))

    def get_schematic_references(self) -> List[types.SchematicReference]:
        return self.xml.referencesProp.schematic_referenceProp

    def get_only_schematic_reference(self) -> Optional[types.SchematicReference]:
        """Return the first schematic reference, or None if there is none.

        Additional references are silently ignored.
        """
        return next(iter(self.get_schematic_references()), None)

    def get_exactly_one_schematic_reference(self) -> types.SchematicReference:
        """Return the only schematic reference; raise EeException if there are zero or several."""
        return self._exactly_one(self.get_schematic_references(),
                                 "schematic references", "schematic reference",
                                 lambda ref: ref.referenceProp)

    # MPNs

    def add_mpn(self, mpn: str):
        """Append a manufacturer part number."""
        self.get_mpns().append(types.PartNumber(value=mpn))

    def get_mpns(self) -> List[types.PartNumber]:
        return self.xml.referencesProp.part_numberProp

    def get_only_mpn(self) -> Optional[types.PartNumber]:
        """Return the first MPN, or None if there is none (extra MPNs are ignored)."""
        return next(iter(self.get_mpns()), None)

    def get_exactly_one_mpn(self) -> types.PartNumber:
        """Return the only MPN; raise EeException if there are zero or several."""
        return self._exactly_one(self.get_mpns(),
                                 "manufacturer part numbers", "mpn",
                                 lambda pn: pn.valueProp)

    # SPNs

    def add_spn(self, mpn: str):
        """Append a supplier part number.

        The parameter is named ``mpn`` for backward compatibility with
        keyword callers; it holds the *supplier* part number.
        """
        self.get_spns().append(types.SupplierPartNumber(value=mpn))

    def get_spns(self) -> List[types.SupplierPartNumber]:
        return self.xml.referencesProp.supplier_part_numberProp

    def get_only_spn(self) -> Optional[types.SupplierPartNumber]:
        """Return the first SPN, or None if there is none (extra SPNs are ignored)."""
        return next(iter(self.get_spns()), None)

    def get_exactly_one_spn(self) -> types.SupplierPartNumber:
        """Return the only SPN; raise EeException if there are zero or several."""
        # Previously reported "manufacturer part numbers" here by mistake.
        return self._exactly_one(self.get_spns(),
                                 "supplier part numbers", "spn",
                                 lambda pn: pn.valueProp)

    # Price breaks

    def add_price_break(self, quantity, price: Money):
        """Append a price break of *price* at the given *quantity*."""
        amount = types.Amount(value=price.amount, currency=price.currency)
        pb = types.PriceBreak(quantity=quantity, amount=amount)
        self.xml.price_breaksProp.price_break.append(pb)

    # Links

    def get_links(self) -> List[types.Link]:
        return self.xml.linksProp.link

    # Facts

    def get_facts(self) -> List[types.Fact]:
        return self.xml.factsProp.fact

    def find_fact(self, key: str) -> Optional[types.Fact]:
        """Return the first fact whose key equals *key*, or None."""
        return next((f for f in self.get_facts() if f.keyProp == key), None)


class Entry(object):
    """A part stored in a :class:`PartDb` together with its bookkeeping data."""

    def __init__(self, new: bool, part: types.Part):
        self.new = new
        self.part = part
        # Index key: the part's first MPN, or None when it has none. Note that
        # wrapping in Part() also fills in missing containers on *part*.
        first_mpn = Part(part).get_only_mpn()
        self.pn = first_mpn.valueProp if first_mpn is not None else None


class PartDb(object):
    """In-memory list of parts with a lookup index on the first MPN."""

    def __init__(self):
        self.parts: List[Entry] = []
        self.pn_index: MutableMapping[str, Entry] = {}
        # Number of entries added with new=True.
        self.new_entries = 0

    def add_entry(self, part: Union[Part, types.Part], new: bool):
        """Add *part* (a wrapper or a raw element) to the database.

        :param new: marks the entry as newly created; counted in
            ``new_entries``.
        """
        if isinstance(part, Part):
            part = part.underlying
        e = Entry(new, part)
        self.parts.append(e)
        if e.pn:
            # Last writer wins: a later part with the same MPN replaces the
            # earlier one in the index (both stay in self.parts).
            self.pn_index[e.pn] = e
        if e.new:
            self.new_entries += 1

    def iterparts(self, sort=False) -> Iterable[types.Part]:
        """Return the stored parts.

        With ``sort=False`` a lazy generator in insertion order is returned;
        with ``sort=True`` a list sorted by ``uriProp``.
        """
        it = (e.part for e in self.parts)
        return sorted(it, key=lambda p: p.uriProp) if sort else it

    def size(self) -> int:
        return len(self.parts)

    def find_by_pn(self, pn: str) -> Optional[types.Part]:
        """Return the part indexed under MPN *pn*, or None."""
        entry = self.pn_index.get(pn)
        return entry.part if entry else None


def load_db(path: Path) -> PartDb:
    """Parse the XML file at *path* into a :class:`PartDb`.

    All loaded entries are marked as not new.
    """
    db = PartDb()
    # Read as UTF-8 explicitly, matching save_db(), instead of relying on the
    # platform default encoding.
    with path.open("r", encoding="utf-8") as f:
        part_db: types.PartDb = types.parse(f, silence=True)

    if part_db.partsProp is None:
        part_db.partsProp = types.PartList()

    for p in part_db.partsProp.part:
        db.add_entry(p, False)

    return db


def find_root_tag(root) -> Optional[str]:
    """Return the key in ``types.GDSClassesMapping`` whose class is exactly ``type(root)``.

    Presumably the key is the XML tag name for that class; returns None if
    no class matches.
    """
    return next((tag for tag, klass in types.GDSClassesMapping.items() if klass == type(root)), None)


def save_db(path: Path, db: PartDb, sort=False):
    """Write all parts in *db* to *path* as XML.

    :param sort: write parts ordered by URI instead of insertion order.

    NOTE: each part's element is cleaned in place (see Part.clean_xml), so
    empty containers on the parts held by *db* are None after this call.
    """
    part_db = types.PartDb()
    parts = part_db.partsProp = types.PartList()
    for part in db.iterparts(sort=sort):
        p = Part(part)
        p.clean_xml()
        parts.partProp.append(p.underlying)

    # The export carries no encoding declaration, so readers will assume
    # UTF-8; write it as such regardless of the platform default.
    with path.open("w", encoding="utf-8") as f:
        part_db.export(outfile=f, level=0, name_=find_root_tag(part_db))