From 9e353ef7af8164d33544a890e60d2049de7d61cb Mon Sep 17 00:00:00 2001
From: Trygve Laugstøl <trygvis@inamo.no>
Date: Tue, 5 Jan 2021 19:48:04 +0100
Subject: Better sorting.

---
 src/main/java/io/trygvis/rules/acme/AcmeIo.java   | 159 ++++++++++++++--------
 src/main/resources/io/trygvis/rules/acme/acme.drl |   3 +
 src/main/resources/io/trygvis/rules/acme/vpn.drl  |   1 +
 3 files changed, 108 insertions(+), 55 deletions(-)

(limited to 'src/main')

diff --git a/src/main/java/io/trygvis/rules/acme/AcmeIo.java b/src/main/java/io/trygvis/rules/acme/AcmeIo.java
index 498a4a6..456195d 100644
--- a/src/main/java/io/trygvis/rules/acme/AcmeIo.java
+++ b/src/main/java/io/trygvis/rules/acme/AcmeIo.java
@@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.type.TypeFactory;
 import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
+import org.apache.commons.collections4.OrderedMap;
 import org.drools.core.common.DefaultFactHandle;
 import org.kie.api.KieBase;
 import org.kie.api.runtime.rule.FactHandle;
@@ -11,12 +12,7 @@ import org.kie.api.runtime.rule.FactHandle;
 import java.io.File;
 import java.io.FileWriter;
 import java.io.IOException;
-import java.lang.reflect.InvocationTargetException;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Comparator;
-import java.util.List;
-import java.util.TreeMap;
+import java.util.*;
 import java.util.function.Function;
 
 @SuppressWarnings("unchecked")
@@ -68,72 +64,123 @@ public class AcmeIo {
         }
 
         public void sort() {
-            var comparator = comparable(type, "key");
-
-            if (comparator == null) {
-                comparator = comparable(type, "name");
-            }
-
-            if (comparator == null) {
-                comparator = comparable(type, "fqdn");
-            }
-
-            if (comparator == null) {
-                comparator = Comparator.comparingInt(System::identityHashCode);
-            }
+            var comparator = comparable(type);
 
             this.values.sort(comparator);
         }
     }
 
-    private static <A, T extends Comparable<T>> Comparator comparable(Class<A> klass, String name) {
+    private static final Map<Class<?>, Comparator> comparators = new HashMap<>();
 
-        try {
-            var method = klass.getMethod("get" + name.substring(0, 1).toUpperCase() + name.substring(1));
-            if (!method.isAccessible()) {
-                if (!method.trySetAccessible())
-                    return null;
-            }
+    private static <A, T extends Comparable<T>> Comparator comparable(Class<A> klass) {
+        var comparator = comparators.get(klass);
+        if (comparator != null) {
+            return comparator;
+        }
 
-            return (a, b) -> {
-                try {
-                    var x = (T) method.invoke(a);
-                    var y = (T) method.invoke(b);
+        // TODO: check if klass is a Comparable directly.
 
-                    if (x == null && y == null) {
-                        return 0;
-                    }
+        var prioritizedKeys = List.of("key", "name", "fqdn");
 
-                    if (x == null) {
-                        return -1;
-                    } else if (y == null) {
-                        return 1;
-                    }
+        var discoveredFieldsP1 = new LinkedHashMap<String, Function<Object, Object>>();
+        var discoveredFieldsP2 = new LinkedHashMap<String, Function<Object, Object>>();
 
-                    return x.compareTo(y);
-                } catch (IllegalAccessException | InvocationTargetException e) {
-                    throw new RuntimeException(e);
-                }
-            };
-        } catch (NoSuchMethodException ignored) {
-        }
+        var prioritizedTypes = List.of(String.class, int.class, Number.class);
 
-        try {
-            var field = klass.getField(name);
+        for (var f : klass.getDeclaredFields()) {
+            if (f.getDeclaringClass() == Object.class) {
+                continue;
+            }
 
-            return (a, b) -> {
+            if (!f.trySetAccessible()) {
+                continue;
+            }
+
+            var collection = discoveredFieldsP2;
+
+            if (prioritizedTypes.contains(f.getType())) {
+                collection = discoveredFieldsP1;
+            }
+
+            collection.put(f.getName(), (Object o) -> {
                 try {
-                    var x = (T) field.get(a);
-                    var y = (T) field.get(b);
-                    return x.compareTo(y);
+                    return f.get(o);
                 } catch (IllegalAccessException e) {
                     throw new RuntimeException(e);
                 }
-            };
-        } catch (NoSuchFieldException ignored) {
+            });
+        }
+
+//        for (var m : klass.getFields()) {
+//            if (m.getParameterCount() != 0) {
+//                continue;
+//            }
+//
+//            var name = m.getName();
+//
+//            if (name.startsWith("get") && name.length() > 3 && Character.isUpperCase(name.charAt(4))) {
+//                name = name.substring(3, 3).toLowerCase() + name.substring(4);
+//            } else {
+//                continue;
+//            }
+//
+//            if (!m.isAccessible()) {
+//                if (!m.trySetAccessible())
+//                    return null;
+//            }
+//
+//            discoveredFields.put(name, m);
+//        }
+
+//        System.out.printf("Sorting %s by:%n", klass.getName());
+
+        var discoveredFields = new LinkedHashMap<>(discoveredFieldsP1);
+        discoveredFields.putAll(discoveredFieldsP2);
+
+        List<Function<Object, Object>> accessors = new ArrayList<>();
+        for (String prioritizedKey : prioritizedKeys) {
+            var m = discoveredFields.remove(prioritizedKey);
+            if (m == null) {
+                continue;
+            }
+
+            accessors.add(m);
+//            System.out.println("  + " + prioritizedKey);
         }
+        accessors.addAll(discoveredFields.values());
+//        discoveredFields.keySet().forEach((s)-> System.out.println("  - " + s));
+
+        comparator = (a, b) -> {
+//            if (klass.getName().contains("AcmeServer")) {
+//                System.out.println("AcmeIo.comparable");
+//            }
+
+            for (var method : accessors) {
+                var x = method.apply(a);
+                var y = method.apply(b);
+
+                if (x == null && y == null) {
+                    continue;
+                }
 
-        return null;
+                if (x == null) {
+                    return -1;
+                } else if (y == null) {
+                    return 1;
+                } else {
+                    var res = x.toString().compareTo(y.toString());
+                    if (res != 0) {
+                        return res;
+                    }
+                }
+            }
+
+            return 0;
+        };
+
+        comparators.put(klass, comparator);
+
+        return comparator;
     }
 
     public void dump(String s, Collection<FactHandle> factHandles, Function<Object, Boolean> filter) throws IOException {
@@ -183,7 +230,9 @@ public class AcmeIo {
     private static class AcmeClassLoader extends ClassLoader {
         private final KieBase kieBase;
 
-        public AcmeClassLoader(KieBase kieBase) {this.kieBase = kieBase;}
+        public AcmeClassLoader(KieBase kieBase) {
+            this.kieBase = kieBase;
+        }
 
         @Override
         public Class<?> loadClass(String name) throws ClassNotFoundException {
diff --git a/src/main/resources/io/trygvis/rules/acme/acme.drl b/src/main/resources/io/trygvis/rules/acme/acme.drl
index 72d296c..6369c24 100644
--- a/src/main/resources/io/trygvis/rules/acme/acme.drl
+++ b/src/main/resources/io/trygvis/rules/acme/acme.drl
@@ -5,6 +5,7 @@ import io.trygvis.rules.dba.Cluster;
 import io.trygvis.rules.dba.Container;
 
 declare AcmeServer
+    name    : String
     machine : Machine
 end
 
@@ -42,8 +43,10 @@ end
 rule "Create Acme servers"
 when
     $m : Machine(name.startsWith("acme-"))
+    not(AcmeServer(name == $m.name))
 then
     var s = new AcmeServer();
+    s.name = $m.name;
     s.machine = $m;
     insert(s)
 end
diff --git a/src/main/resources/io/trygvis/rules/acme/vpn.drl b/src/main/resources/io/trygvis/rules/acme/vpn.drl
index 082ecc0..90cdce2 100644
--- a/src/main/resources/io/trygvis/rules/acme/vpn.drl
+++ b/src/main/resources/io/trygvis/rules/acme/vpn.drl
@@ -48,6 +48,7 @@ rule "Make DNS entries for all VPN hosts"
 when
     $h : WgHost()
     $net : WgNet(name == $h.net)
+    not(DnsEntry(fqdn == "%s.%s".formatted($h.name, $net.domain), type == "A"))
 then
     var fqdn = "%s.%s".formatted($h.name, $net.domain);
     insert(DnsEntry.a(fqdn))
-- 
cgit v1.2.3