import base64
import urllib.parse
import zlib
from pathlib import Path
from typing import Mapping, MutableMapping, List, Optional
from xml.dom import minidom

from lxml import etree

from ee.db import ObjDb
from ee.part import PartDb, save_db, load_db, Part, AssemblyPart


def decompress(input_stream, output_stream):
    """Pretty-print the XML model stored in the first ``<diagram>`` of a draw.io file.

    A compressed ``<diagram>`` body is decoded as base64 -> raw DEFLATE ->
    URL-encoded text. If the ``<diagram>`` instead contains an
    ``<mxGraphModel>`` element directly (uncompressed diagram), that element
    is printed as-is.

    :param input_stream: file name or file object accepted by ``minidom.parse``.
    :param output_stream: text stream the pretty-printed XML is written to.
    :raises IndexError: if the document has no ``<diagram>`` element.
    """
    doc = minidom.parse(input_stream)
    diagram_elem = doc.getElementsByTagName("diagram")[0]

    models = diagram_elem.getElementsByTagName("mxGraphModel")
    if models:
        # Uncompressed diagram: the model is already plain XML.
        xml = models[0].toxml()
    else:
        # Gather all text children rather than relying on firstChild, which
        # may be missing (empty diagram) or not a text node at all.
        payload = "".join(node.data for node in diagram_elem.childNodes
                          if node.nodeType == node.TEXT_NODE)
        # Negative wbits selects a raw DEFLATE stream (no zlib header/checksum).
        decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
        inflated = decompressor.decompress(base64.b64decode(payload))
        inflated += decompressor.flush()
        xml = urllib.parse.unquote(inflated.decode("ascii"))

    pretty_xml = minidom.parseString(xml).toprettyxml(indent="  ")
    print(pretty_xml, file=output_stream)


class GraphObject(object):
    """A vertex (non-edge cell) of a draw.io diagram.

    Attributes:
        id: the cell id from the diagram file.
        attrs: custom properties of the cell; the mapping passed in is kept
            (not copied) and mutated by ``add_attr``.
        value: the cell's label text, or None.
        incoming: edges ending at this vertex, keyed by edge id.
        outgoing: edges starting at this vertex, keyed by edge id.
    """

    def __init__(self, id_: str, attrs, value):
        self.id, self.attrs, self.value = id_, attrs, value
        # Edge maps start empty; GraphModel.create registers edges into them.
        self.incoming: MutableMapping[str, "GraphEdge"] = dict()
        self.outgoing: MutableMapping[str, "GraphEdge"] = dict()

    def add_attr(self, key, value):
        """Set custom property *key* to *value*, overwriting any previous value."""
        self.attrs[key] = value

    def add_incoming(self, edge: "GraphEdge"):
        """Record *edge* as ending here; an edge with the same id is replaced."""
        edge_key = edge.id
        self.incoming[edge_key] = edge

    def add_outgoing(self, edge: "GraphEdge"):
        """Record *edge* as starting here; an edge with the same id is replaced."""
        edge_key = edge.id
        self.outgoing[edge_key] = edge


class GraphEdge(object):
    """An edge (connector cell) of a draw.io diagram.

    ``source_id``/``target_id`` are the raw cell ids read from the file.
    ``source``/``target`` start as None and are later set to the resolved
    GraphObject endpoints by ``GraphModel.create``.
    """

    def __init__(self, id_, attrs, source_id, target_id, value):
        self.id, self.attrs, self.value = id_, attrs, value
        self.source_id, self.target_id = source_id, target_id

        # Resolved endpoints; None until the edge is wired into a model.
        self.source: Optional[GraphObject] = None
        self.target: Optional[GraphObject] = None

    def add_attr(self, key, value):
        """Set custom property *key* to *value*, overwriting any previous value."""
        self.attrs[key] = value


class GraphModel(object):
    """The vertices and edges of a draw.io diagram, keyed by cell id."""

    def __init__(self, objects: Mapping[str, GraphObject], edges: Mapping[str, GraphEdge]):
        self.objects = objects
        self.edges = edges

    @property
    def roots(self):
        """Vertices with no incoming edges, in ``objects`` iteration order."""
        return [obj for obj in self.objects.values() if not obj.incoming]

    @staticmethod
    def create(objects: Mapping[str, GraphObject], edges: Mapping[str, GraphEdge]) -> "GraphModel":
        """Link every edge to its endpoint vertices and wrap both maps in a model.

        The given GraphObject and GraphEdge instances are mutated in place:
        each vertex gains the edge in its ``outgoing``/``incoming`` map, and
        each edge gets its ``source``/``target`` set.

        :raises KeyError: if an edge's source or target id is not in ``objects``.
        """
        for edge in edges.values():
            src, dst = objects[edge.source_id], objects[edge.target_id]

            src.add_outgoing(edge)
            dst.add_incoming(edge)

            edge.source, edge.target = src, dst

        return GraphModel(objects, edges)


def load_graph(doc) -> GraphModel:
    """Build a GraphModel from a parsed, uncompressed draw.io model document.

    ``doc`` must be an ElementTree-like document (``getroot()``, ``find()``,
    ``.tag``, ``.attrib``) whose root element has a ``<root>`` child holding
    the cells. Two kinds of child are understood:

    * ``<mxCell>``: label taken from ``value``; non-structural attributes
      become ``attrs``.
    * ``<object>``: label taken from ``label``; its own attributes (minus
      ``id``/``label``/``placeholders``) become ``attrs``, while the
      ``vertex``/``edge``/``source``/``target`` flags come from the nested
      ``<mxCell>``.

    Cells flagged as neither or both of vertex/edge are skipped.

    :raises KeyError: for a missing ``<root>``, an unknown tag, a missing
        required attribute or an ``<object>`` without ``<mxCell>``. Per-cell
        errors name the offending element and id.
    """
    def parse_text(s: Optional[str]) -> Optional[str]:
        # Labels use HTML <br> for line breaks; turn them into newlines.
        if s is None:
            return None
        return s.replace("<br>", "\n").strip()

    cells = doc.getroot().find("root")
    if cells is None:
        raise KeyError("Document has no <root> element")

    objects = {}
    edges = {}
    for child in cells:
        try:
            id_ = child.attrib["id"]

            if child.tag == "mxCell":
                value = child.attrib.get("value")
                a = child.attrib
                bad_keys = ("id", "edge", "parent", "vertex", "style", "source", "target", "value")
            elif child.tag == "object":
                value = child.attrib.get("label")
                cell = child.find("mxCell")
                if cell is None:
                    raise KeyError("<object> element has no nested <mxCell>")
                # Structural flags live on the nested cell, custom props on <object>.
                a = cell.attrib
                bad_keys = ("id", "label", "placeholders")
            else:
                raise KeyError("Unknown tag: {}".format(child.tag))

            attrs = {key: val for key, val in child.attrib.items() if key not in bad_keys}
            value = parse_text(value)

            vertex = a.get("vertex") == "1"
            edge = a.get("edge") == "1"

            if vertex == edge:
                # Neither or both flags set (e.g. layer cells): not classifiable.
                continue

            if vertex:
                objects[id_] = GraphObject(id_, attrs, value)
            else:
                edges[id_] = GraphEdge(id_, attrs, a["source"], a["target"], value)

        except KeyError as e:
            id_ = child.attrib.get("id", "unknown")
            raise KeyError("Error while processing {}: {}, id={}".format(child, str(e), id_)) from e

    return GraphModel.create(objects, edges)


def to_parts(in_path: Path, out_path: Path, part_dbs: List[Path]):
    """Convert a draw.io assembly diagram into a part database file.

    All parts loaded from ``part_dbs`` are indexed by description. Every
    vertex of the diagram at ``in_path`` becomes an ``AssemblyPart``:
    - Its ``part`` attribute (or, failing that, its label) is looked up in
      the index.
    - Its outgoing edges become sub-parts.

    Each root vertex (no incoming edges) is appended to the assembly of a new
    ``PartDb``, which is saved to ``out_path``.

    :raises ValueError: if a cycle of edges is reachable from a root vertex.
    """
    parts: ObjDb[Part] = ObjDb[Part]()
    description_idx = parts.add_index("description", lambda p: p.underlying.descriptionProp)

    def find_part(o: GraphObject) -> Optional[Part]:
        # An explicit "part" attribute takes precedence over the visible label.
        description = o.attrs["part"] if "part" in o.attrs else o.value
        hits = description_idx.get(description)
        # Only an unambiguous match is used; zero or several hits leave the
        # assembly part without a part reference.
        return hits[0] if len(hits) == 1 else None

    def add_part(o: GraphObject, ancestors: set) -> AssemblyPart:
        # ``ancestors`` holds the ids on the current root-to-node path. Seeing
        # one again means the edges form a cycle, which would otherwise recurse
        # until RecursionError.
        if o.id in ancestors:
            raise ValueError("Cycle in diagram at object id={}".format(o.id))

        p = find_part(o)
        uri = p.uri if p else None
        ap = AssemblyPart(uri)

        if uri:
            ap.references.add_part_reference(uri)

        ap.references.add_description_reference(o.value)

        ancestors.add(o.id)
        try:
            for out in o.outgoing.values():
                ap.add_sub_part(add_part(out.target, ancestors))
        finally:
            ancestors.discard(o.id)

        return ap

    for part_db in part_dbs:
        for xml in load_db(part_db).iterparts():
            parts.add(Part(xml))

    doc = etree.parse(str(in_path))
    graph = load_graph(doc)

    db = PartDb()
    assembly = db.assembly

    for root in graph.roots:
        assembly.parts.append(add_part(root, set()))

    save_db(out_path, db)


def to_dot(in_path: Path, out_path: Path):
    """Render the graph of a draw.io model file as a Graphviz DOT digraph.

    - Each vertex becomes a node labelled with its text.
    - A vertex with custom attributes gets an extra plaintext node listing
      ``key=value`` pairs, joined to it by a dotted line.
    - An edge with custom attributes is routed through its own plaintext
      label node (source -> label -> target), so the attributes appear on
      the edge.
    - The attributes are also written as ``//`` comments.
    """
    def quote(s: str) -> str:
        # Escape double quotes for use inside a DOT quoted string.
        return s.replace("\"", "\\\"")

    def node_id(s: str) -> str:
        # Always quote IDs: draw.io ids may start with a digit or contain
        # characters that are not valid in a bare DOT identifier.
        return "\"{}\"".format(quote(s))

    def attr_label(attrs) -> str:
        # Values are escaped exactly once; "\\n" is DOT's line-break escape.
        return "\\n".join("{}={}".format(quote(k), quote(v)) for k, v in attrs.items())

    doc = etree.parse(str(in_path))
    graph = load_graph(doc)

    with open(str(out_path), "w") as f:
        print("digraph parts {", file=f)
        for obj in graph.objects.values():
            if obj.attrs:
                attrs_node = node_id(obj.id + "_attrs")
                print("  {} [shape=plaintext, label=\"{}\"]".format(attrs_node, attr_label(obj.attrs)), file=f)
                print("  {} -> {} [arrowhead=none,style=dotted]".format(attrs_node, node_id(obj.id)), file=f)

            label = "label=\"{}\"".format(quote(obj.value)) if obj.value else ""
            print("  {} [{}];".format(node_id(obj.id), label), file=f)

        for edge in graph.edges.values():
            source = node_id(edge.source.id)
            target = node_id(edge.target.id)

            if edge.attrs:
                print(" // source={}, target={}".format(edge.source.id, edge.target.id), file=f)
                # One label node per edge (keyed by edge id) so several
                # attributed edges into the same target keep distinct labels.
                label_node = node_id(edge.id + "_label")
                print("  {} [shape=plaintext, label=\"{}\"]".format(label_node, attr_label(edge.attrs)), file=f)
                print("  {} -> {}".format(label_node, target), file=f)
                target = label_node
                arrowhead = "none"
            else:
                arrowhead = "normal"

            print("  {} -> {}  [arrowhead={}];".format(source, target, arrowhead), file=f)
            for k, v in edge.attrs.items():
                # Keep each comment on one line so it cannot break the DOT syntax.
                print("  // {}={}".format(k, v).replace("\n", " "), file=f)
        print("}", file=f)