summaryrefslogtreecommitdiff
path: root/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java')
-rw-r--r--src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java365
1 files changed, 365 insertions, 0 deletions
diff --git a/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java b/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java
new file mode 100644
index 0000000..7dc24ad
--- /dev/null
+++ b/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java
@@ -0,0 +1,365 @@
+package io.trygvis.rules.engine;
+
+import ch.qos.logback.core.util.FileUtil;
+import com.fasterxml.jackson.annotation.ObjectIdGenerators;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.PropertyName;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.introspect.Annotated;
+import com.fasterxml.jackson.databind.introspect.JacksonAnnotationIntrospector;
+import com.fasterxml.jackson.databind.introspect.ObjectIdInfo;
+import com.fasterxml.jackson.databind.type.TypeFactory;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
+import org.drools.core.common.DefaultFactHandle;
+import org.drools.core.factmodel.GeneratedFact;
+import org.kie.api.KieBase;
+import org.kie.api.runtime.KieContainer;
+import org.kie.api.runtime.rule.FactHandle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.util.*;
+import java.util.function.Function;
+
+@SuppressWarnings("unchecked")
+public class DbIo {
+ private final Logger logger = LoggerFactory.getLogger(getClass());
+
+ private final ObjectMapper mapper;
+
+ private static final List<String> prioritizedKeys = List.of("key", "name", "fqdn");
+
+ public DbIo(KieContainer container, KieBase kieBase) {
+ var factory = new YAMLFactory();
+ factory.enable(YAMLGenerator.Feature.USE_NATIVE_TYPE_ID);
+ factory.enable(YAMLGenerator.Feature.USE_NATIVE_OBJECT_ID);
+ mapper = new ObjectMapper(factory);
+ mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
+ var typeFactory = TypeFactory.defaultInstance()
+ .withClassLoader(new DbClassLoader(container, kieBase));
+ mapper.setTypeFactory(typeFactory);
+ mapper.findAndRegisterModules();
+
+ mapper.setAnnotationIntrospector(new JacksonAnnotationIntrospector() {
+ @Override
+ public ObjectIdInfo findObjectIdInfo(Annotated a) {
+ final Class<?> klass = a.getRawType();
+ if (GeneratedFact.class.isAssignableFrom(klass)) {
+ System.out.println("klass = " + klass);
+
+ for (String name : prioritizedKeys) {
+ try {
+ final String getter = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
+ klass.getMethod(getter);
+ return new ObjectIdInfo(PropertyName.construct(name), null, ObjectIdGenerators.PropertyGenerator.class, null);
+ } catch (NoSuchMethodException ignore) {
+ }
+ }
+ System.out.println("a.getRawType() = " + klass);
+ return new ObjectIdInfo(null, null, ObjectIdGenerators.IntSequenceGenerator.class, null);
+ }
+
+ return super.findObjectIdInfo(a);
+ }
+ });
+ }
+
+ public List<Object> load(File file) throws IOException {
+ var parser = mapper.getFactory().createParser(file);
+
+ var objects = parser.<List<DbObject>>readValueAs(new TypeReference<List<DbObject>>() {});
+
+ var items = new ArrayList<>();
+ for (DbObject object : objects) {
+ try {
+ var type = mapper.getTypeFactory().findClass(object.type);
+ var x = mapper.treeToValue(object.data, type);
+ if (x == null) {
+ x = type.getDeclaredConstructor().newInstance();
+ }
+ items.add(x);
+ } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) {
+ System.out.println("e.getClass() = " + e.getClass().getName());
+ System.out.println("e.getMessage() = " + e.getMessage());
+ // ignore
+ }
+ }
+
+ return items;
+ }
+
+ public void dump(File file, Collection<FactHandle> factHandles) throws IOException {
+ dump(file, factHandles, (o) -> true);
+ }
+
+ // This should just sort by all getters instead.
+ static class FactCollection<T> {
+ public final Class<T> type;
+ public final List<T> values;
+
+ public FactCollection(Class<T> type) {
+ this.type = type;
+ this.values = new ArrayList<>();
+ }
+
+ public void sort() {
+ var comparator = comparable(type);
+
+ this.values.sort(comparator);
+ }
+ }
+
+ private static final Map<Class<?>, Comparator> comparators = new HashMap<>();
+
+ private static <A, T extends Comparable<T>> Comparator comparable(Class<A> klass) {
+ var comparator = comparators.get(klass);
+ if (comparator != null) {
+ return comparator;
+ }
+
+ // TODO: check if klass is a Comparable directly.
+
+ var discoveredFieldsP1 = new LinkedHashMap<String, Function<Object, Object>>();
+ var discoveredFieldsP2 = new LinkedHashMap<String, Function<Object, Object>>();
+
+ var prioritizedTypes = List.of(String.class, int.class, Number.class);
+
+ for (var f : klass.getDeclaredFields()) {
+ if (f.getDeclaringClass() == Object.class) {
+ continue;
+ }
+
+ if (!f.trySetAccessible()) {
+ continue;
+ }
+
+ var collection = discoveredFieldsP2;
+
+ if (prioritizedTypes.contains(f.getType())) {
+ collection = discoveredFieldsP1;
+ }
+
+ collection.put(f.getName(), (Object o) -> {
+ try {
+ return f.get(o);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+// for (var m : klass.getFields()) {
+// if (m.getParameterCount() != 0) {
+// continue;
+// }
+//
+// var name = m.getName();
+//
+// if (name.startsWith("get") && name.length() > 3 && Character.isUpperCase(name.charAt(4))) {
+// name = name.substring(3, 3).toLowerCase() + name.substring(4);
+// } else {
+// continue;
+// }
+//
+// if (!m.isAccessible()) {
+// if (!m.trySetAccessible())
+// return null;
+// }
+//
+// discoveredFields.put(name, m);
+// }
+
+// System.out.printf("Sorting %s by:%n", klass.getName());
+
+ var discoveredFields = new LinkedHashMap<>(discoveredFieldsP1);
+ discoveredFields.putAll(discoveredFieldsP2);
+
+ List<Function<Object, Object>> accessors = new ArrayList<>();
+ for (String prioritizedKey : prioritizedKeys) {
+ var m = discoveredFields.remove(prioritizedKey);
+ if (m == null) {
+ continue;
+ }
+
+ accessors.add(m);
+// System.out.println(" + " + prioritizedKey);
+ }
+ accessors.addAll(discoveredFields.values());
+// discoveredFields.keySet().forEach((s)-> System.out.println(" - " + s));
+
+ comparator = (a, b) -> {
+// if (klass.getName().contains("AcmeServer")) {
+// System.out.println("AcmeIo.comparable");
+// }
+
+ for (var method : accessors) {
+ var x = method.apply(a);
+ var y = method.apply(b);
+
+ if (x == null && y == null) {
+ continue;
+ }
+
+ if (x == null) {
+ return -1;
+ } else if (y == null) {
+ return 1;
+ } else {
+ var res = x.toString().compareTo(y.toString());
+ if (res != 0) {
+ return res;
+ }
+ }
+ }
+
+ return 0;
+ };
+
+ comparators.put(klass, comparator);
+
+ return comparator;
+ }
+
+ static record DbObject2(String type, Object data) {
+ }
+
+ public void dump(File file, Collection<FactHandle> factHandles, Function<Object, Boolean> filter) throws IOException {
+ FileUtil.createMissingParentDirectories(file);
+
+ var facts = new TreeMap<Class<?>, FactCollection<Object>>(Comparator.comparing(Class::getName));
+ for (var handle : factHandles) {
+ if (handle instanceof DefaultFactHandle h) {
+ var obj = h.getObject();
+ if (!filter.apply(obj)) {
+ continue;
+ }
+
+ Class<?> type = obj.getClass();
+ var collection = facts.get(type);
+
+ if (collection == null) {
+ collection = new FactCollection(type);
+ facts.put(type, collection);
+ }
+
+ collection.values.add(obj);
+ }
+ }
+
+ var objects = new ArrayList<DbObject2>(facts.size());
+ for (var e : facts.entrySet()) {
+ var name = e.getKey().getName();
+
+ var collection = e.getValue();
+ collection.sort();
+ for (var fact : collection.values) {
+ objects.add(new DbObject2(name, fact));
+ }
+ }
+
+ objects.sort(new DbObjectComparator());
+
+ var factory = mapper.getFactory();
+ try (var writer = new FileWriter(file);
+ var g = factory.createGenerator(writer)) {
+ g.writeObject(objects);
+ }
+ }
+
+ private class DbClassLoader extends ClassLoader {
+ private final KieContainer container;
+ private final KieBase kieBase;
+
+ public DbClassLoader(KieContainer container, KieBase kieBase) {
+ this.container = container;
+ this.kieBase = kieBase;
+ }
+
+ @Override
+ public Class<?> loadClass(String name) throws ClassNotFoundException {
+ logger.info("Loading class {}", name);
+ try {
+ var klass = super.loadClass(name);
+ logger.info("Found class in super classloader");
+ return klass;
+ } catch (ClassNotFoundException e) {
+ var i = name.lastIndexOf('.');
+ String pkg, simpleName;
+ if (i == -1) {
+ pkg = null;
+ simpleName = name;
+ } else {
+ pkg = name.substring(0, i);
+ simpleName = name.substring(i + 1);
+ }
+
+ try {
+ var klass = container.getClassLoader().loadClass(name);
+ logger.info("Found class in container's classloader");
+ return klass;
+ } catch (ClassNotFoundException ignore) {
+ }
+
+ try {
+ logger.info("pkg = {}", pkg);
+ logger.info("simpleName = {}", simpleName);
+ var clazz = kieBase.getFactType(pkg, simpleName);
+ if (clazz != null) {
+ logger.info("Found class as a FactType");
+ return clazz.getFactClass();
+ }
+ } catch (UnsupportedOperationException ignore) {
+ }
+
+ logger.warn("Class not found: {}", name);
+
+ throw e;
+ }
+ }
+ }
+
+ private static class DbObjectComparator implements Comparator<DbObject2> {
+ private final List<String> prioritizedPackages = List.of(
+ "io.trygvis.rules.machine",
+ "io.trygvis.rules.network",
+ "io.trygvis.rules.dns",
+ "io.trygvis.rules.dba",
+ "io.trygvis.rules",
+ "io.trygvis.rules.core");
+
+ @Override
+ public int compare(DbObject2 a, DbObject2 b) {
+ var indexA = a.type.lastIndexOf(".");
+ String packageA = indexA == -1 ? null : a.type.substring(0, indexA);
+ String classA = indexA == -1 ? a.type : a.type.substring(indexA + 1);
+
+ var indexB = b.type.lastIndexOf(".");
+ String packageB = indexB == -1 ? null : b.type.substring(0, indexB);
+ String classB = indexB == -1 ? b.type : b.type.substring(indexB + 1);
+
+ var priIdxA = prioritizedPackages.indexOf(packageA);
+ var priIdxB = prioritizedPackages.indexOf(packageB);
+
+ if (priIdxA == -1 && priIdxB == -1) {
+ return classB.compareTo(classA);
+ } else if (priIdxA == -1) {
+ return 1;
+ } else if (priIdxB == -1) {
+ return -1;
+ }
+ return priIdxA - priIdxB;
+// var diff = priIdxB - priIdxA;
+// if (diff != 0) {
+// return diff;
+// }
+//
+// return classB.compareTo(classA);
+ }
+ }
+}