diff options
Diffstat (limited to 'src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java')
-rw-r--r-- | src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java | 365 |
1 files changed, 365 insertions, 0 deletions
diff --git a/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java b/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java new file mode 100644 index 0000000..7dc24ad --- /dev/null +++ b/src/ri-engine/src/main/java/io/trygvis/rules/engine/DbIo.java @@ -0,0 +1,365 @@ +package io.trygvis.rules.engine; + +import ch.qos.logback.core.util.FileUtil; +import com.fasterxml.jackson.annotation.ObjectIdGenerators; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyName; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.introspect.Annotated; +import com.fasterxml.jackson.databind.introspect.JacksonAnnotationIntrospector; +import com.fasterxml.jackson.databind.introspect.ObjectIdInfo; +import com.fasterxml.jackson.databind.type.TypeFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import org.drools.core.common.DefaultFactHandle; +import org.drools.core.factmodel.GeneratedFact; +import org.kie.api.KieBase; +import org.kie.api.runtime.KieContainer; +import org.kie.api.runtime.rule.FactHandle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.*; +import java.util.function.Function; + +@SuppressWarnings("unchecked") +public class DbIo { + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ObjectMapper mapper; + + private static final List<String> prioritizedKeys = List.of("key", "name", "fqdn"); + + public DbIo(KieContainer container, KieBase kieBase) { + var factory = new YAMLFactory(); + factory.enable(YAMLGenerator.Feature.USE_NATIVE_TYPE_ID); + factory.enable(YAMLGenerator.Feature.USE_NATIVE_OBJECT_ID); + mapper = new ObjectMapper(factory); + mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); + var typeFactory = TypeFactory.defaultInstance() + .withClassLoader(new DbClassLoader(container, kieBase)); + mapper.setTypeFactory(typeFactory); + mapper.findAndRegisterModules(); + + mapper.setAnnotationIntrospector(new JacksonAnnotationIntrospector() { + @Override + public ObjectIdInfo findObjectIdInfo(Annotated a) { + final Class<?> klass = a.getRawType(); + if (GeneratedFact.class.isAssignableFrom(klass)) { + System.out.println("klass = " + klass); + + for (String name : prioritizedKeys) { + try { + final String getter = "get" + name.substring(0, 1).toUpperCase() + name.substring(1); + klass.getMethod(getter); + return new ObjectIdInfo(PropertyName.construct(name), null, ObjectIdGenerators.PropertyGenerator.class, null); + } catch (NoSuchMethodException ignore) { + } + } + System.out.println("a.getRawType() = " + klass); + return new ObjectIdInfo(null, null, ObjectIdGenerators.IntSequenceGenerator.class, null); + } + + return super.findObjectIdInfo(a); + } + }); + } + + public List<Object> load(File file) throws IOException { + var parser = mapper.getFactory().createParser(file); + + var objects = parser.<List<DbObject>>readValueAs(new TypeReference<List<DbObject>>() {}); + + var items = new ArrayList<>(); + for (DbObject object : objects) { + try { + var type = mapper.getTypeFactory().findClass(object.type); + var x = mapper.treeToValue(object.data, type); + if (x == null) { + x = type.getDeclaredConstructor().newInstance(); + } + items.add(x); + } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) { + System.out.println("e.getClass() = " + e.getClass().getName()); + System.out.println("e.getMessage() = " + e.getMessage()); + // ignore + } + } + + return items; + } + + public void dump(File file, Collection<FactHandle> factHandles) throws IOException { + dump(file, factHandles, (o) -> true); + } + + // This should just sort by all getters instead. + static class FactCollection<T> { + public final Class<T> type; + public final List<T> values; + + public FactCollection(Class<T> type) { + this.type = type; + this.values = new ArrayList<>(); + } + + public void sort() { + var comparator = comparable(type); + + this.values.sort(comparator); + } + } + + private static final Map<Class<?>, Comparator> comparators = new HashMap<>(); + + private static <A, T extends Comparable<T>> Comparator comparable(Class<A> klass) { + var comparator = comparators.get(klass); + if (comparator != null) { + return comparator; + } + + // TODO: check if klass is a Comparable directly. + + var discoveredFieldsP1 = new LinkedHashMap<String, Function<Object, Object>>(); + var discoveredFieldsP2 = new LinkedHashMap<String, Function<Object, Object>>(); + + var prioritizedTypes = List.of(String.class, int.class, Number.class); + + for (var f : klass.getDeclaredFields()) { + if (f.getDeclaringClass() == Object.class) { + continue; + } + + if (!f.trySetAccessible()) { + continue; + } + + var collection = discoveredFieldsP2; + + if (prioritizedTypes.contains(f.getType())) { + collection = discoveredFieldsP1; + } + + collection.put(f.getName(), (Object o) -> { + try { + return f.get(o); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + +// for (var m : klass.getFields()) { +// if (m.getParameterCount() != 0) { +// continue; +// } +// +// var name = m.getName(); +// +// if (name.startsWith("get") && name.length() > 3 && Character.isUpperCase(name.charAt(4))) { +// name = name.substring(3, 3).toLowerCase() + name.substring(4); +// } else { +// continue; +// } +// +// if (!m.isAccessible()) { +// if (!m.trySetAccessible()) +// return null; +// } +// +// discoveredFields.put(name, m); +// } + +// System.out.printf("Sorting %s by:%n", klass.getName()); + + var discoveredFields = new LinkedHashMap<>(discoveredFieldsP1); + discoveredFields.putAll(discoveredFieldsP2); + + List<Function<Object, Object>> accessors = new ArrayList<>(); + for (String prioritizedKey : prioritizedKeys) { + var m = discoveredFields.remove(prioritizedKey); + if (m == null) { + continue; + } + + accessors.add(m); +// System.out.println(" + " + prioritizedKey); + } + accessors.addAll(discoveredFields.values()); +// discoveredFields.keySet().forEach((s)-> System.out.println(" - " + s)); + + comparator = (a, b) -> { +// if (klass.getName().contains("AcmeServer")) { +// System.out.println("AcmeIo.comparable"); +// } + + for (var method : accessors) { + var x = method.apply(a); + var y = method.apply(b); + + if (x == null && y == null) { + continue; + } + + if (x == null) { + return -1; + } else if (y == null) { + return 1; + } else { + var res = x.toString().compareTo(y.toString()); + if (res != 0) { + return res; + } + } + } + + return 0; + }; + + comparators.put(klass, comparator); + + return comparator; + } + + static record DbObject2(String type, Object data) { + } + + public void dump(File file, Collection<FactHandle> factHandles, Function<Object, Boolean> filter) throws IOException { + FileUtil.createMissingParentDirectories(file); + + var facts = new TreeMap<Class<?>, FactCollection<Object>>(Comparator.comparing(Class::getName)); + for (var handle : factHandles) { + if (handle instanceof DefaultFactHandle h) { + var obj = h.getObject(); + if (!filter.apply(obj)) { + continue; + } + + Class<?> type = obj.getClass(); + var collection = facts.get(type); + + if (collection == null) { + collection = new FactCollection(type); + facts.put(type, collection); + } + + collection.values.add(obj); + } + } + + var objects = new ArrayList<DbObject2>(facts.size()); + for (var e : facts.entrySet()) { + var name = e.getKey().getName(); + + var collection = e.getValue(); + collection.sort(); + for (var fact : collection.values) { + objects.add(new DbObject2(name, fact)); + } + } + + objects.sort(new DbObjectComparator()); + + var factory = mapper.getFactory(); + try (var writer = new FileWriter(file); + var g = factory.createGenerator(writer)) { + g.writeObject(objects); + } + } + + private class DbClassLoader extends ClassLoader { + private final KieContainer container; + private final KieBase kieBase; + + public DbClassLoader(KieContainer container, KieBase kieBase) { + this.container = container; + this.kieBase = kieBase; + } + + @Override + public Class<?> loadClass(String name) throws ClassNotFoundException { + logger.info("Loading class {}", name); + try { + var klass = super.loadClass(name); + logger.info("Found class in super classloader"); + return klass; + } catch (ClassNotFoundException e) { + var i = name.lastIndexOf('.'); + String pkg, simpleName; + if (i == -1) { + pkg = null; + simpleName = name; + } else { + pkg = name.substring(0, i); + simpleName = name.substring(i + 1); + } + + try { + var klass = container.getClassLoader().loadClass(name); + logger.info("Found class in container's classloader"); + return klass; + } catch (ClassNotFoundException ignore) { + } + + try { + logger.info("pkg = {}", pkg); + logger.info("simpleName = {}", simpleName); + var clazz = kieBase.getFactType(pkg, simpleName); + if (clazz != null) { + logger.info("Found class as a FactType"); + return clazz.getFactClass(); + } + } catch (UnsupportedOperationException ignore) { + } + + logger.warn("Class not found: {}", name); + + throw e; + } + } + } + + private static class DbObjectComparator implements Comparator<DbObject2> { + private final List<String> prioritizedPackages = List.of( + "io.trygvis.rules.machine", + "io.trygvis.rules.network", + "io.trygvis.rules.dns", + "io.trygvis.rules.dba", + "io.trygvis.rules", + "io.trygvis.rules.core"); + + @Override + public int compare(DbObject2 a, DbObject2 b) { + var indexA = a.type.lastIndexOf("."); + String packageA = indexA == -1 ? null : a.type.substring(0, indexA); + String classA = indexA == -1 ? a.type : a.type.substring(indexA + 1); + + var indexB = b.type.lastIndexOf("."); + String packageB = indexB == -1 ? null : b.type.substring(0, indexB); + String classB = indexB == -1 ? b.type : b.type.substring(indexB + 1); + + var priIdxA = prioritizedPackages.indexOf(packageA); + var priIdxB = prioritizedPackages.indexOf(packageB); + + if (priIdxA == -1 && priIdxB == -1) { + return classB.compareTo(classA); + } else if (priIdxA == -1) { + return 1; + } else if (priIdxB == -1) { + return -1; + } + return priIdxA - priIdxB; +// var diff = priIdxB - priIdxA; +// if (diff != 0) { +// return diff; +// } +// +// return classB.compareTo(classA); + } + } +} |