package io.trygvis.rules.engine; import ch.qos.logback.core.util.FileUtil; import com.fasterxml.jackson.annotation.ObjectIdGenerators; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyName; import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.introspect.Annotated; import com.fasterxml.jackson.databind.introspect.JacksonAnnotationIntrospector; import com.fasterxml.jackson.databind.introspect.ObjectIdInfo; import com.fasterxml.jackson.databind.type.TypeFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; import org.drools.core.common.InternalFactHandle; import org.drools.core.factmodel.GeneratedFact; import org.kie.api.KieBase; import org.kie.api.runtime.KieContainer; import org.kie.api.runtime.rule.FactHandle; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.*; import java.util.function.Function; @SuppressWarnings("unchecked") public class DbIo { private final Logger logger = LoggerFactory.getLogger(getClass()); private final ObjectMapper mapper; private static final List prioritizedKeys = List.of("key", "name", "fqdn"); public DbIo(KieContainer container, KieBase kieBase) { var factory = new YAMLFactory(); factory.enable(YAMLGenerator.Feature.USE_NATIVE_TYPE_ID); factory.enable(YAMLGenerator.Feature.USE_NATIVE_OBJECT_ID); mapper = new ObjectMapper(factory); mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS); var typeFactory = TypeFactory.defaultInstance() .withClassLoader(new DbClassLoader(container, kieBase)); mapper.setTypeFactory(typeFactory); mapper.findAndRegisterModules(); mapper.setAnnotationIntrospector(new JacksonAnnotationIntrospector() { @Override public ObjectIdInfo findObjectIdInfo(Annotated a) { final Class klass = a.getRawType(); if (GeneratedFact.class.isAssignableFrom(klass)) { System.out.println("klass = " + klass); for (String name : prioritizedKeys) { try { final String getter = "get" + name.substring(0, 1).toUpperCase() + name.substring(1); klass.getMethod(getter); return new ObjectIdInfo(PropertyName.construct(name), null, ObjectIdGenerators.PropertyGenerator.class, null); } catch (NoSuchMethodException ignore) { } } System.out.println("a.getRawType() = " + klass); return new ObjectIdInfo(null, null, ObjectIdGenerators.IntSequenceGenerator.class, null); } return super.findObjectIdInfo(a); } }); } public List load(File file) throws IOException { var parser = mapper.getFactory().createParser(file); var objects = parser.>readValueAs(new TypeReference>() {}); var items = new ArrayList<>(); for (DbObject object : objects) { try { var type = mapper.getTypeFactory().findClass(object.type); var x = mapper.treeToValue(object.data, type); if (x == null) { x = type.getDeclaredConstructor().newInstance(); } items.add(x); } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) { System.out.println("e.getClass() = " + e.getClass().getName()); System.out.println("e.getMessage() = " + e.getMessage()); // ignore } } return items; } public void dump(File file, Collection factHandles) throws IOException { dump(file, factHandles, (o) -> true); } // This should just sort by all getters instead. static class FactCollection { public final Class type; public final List values; public FactCollection(Class type) { this.type = type; this.values = new ArrayList<>(); } public void sort() { var comparator = comparable(type); this.values.sort(comparator); } } private static final Map, Comparator> comparators = new HashMap<>(); private static > Comparator comparable(Class klass) { var comparator = comparators.get(klass); if (comparator != null) { return comparator; } // TODO: check if klass is a Comparable directly. var discoveredFieldsP1 = new LinkedHashMap>(); var discoveredFieldsP2 = new LinkedHashMap>(); var prioritizedTypes = List.of(String.class, int.class, Number.class); for (var f : klass.getDeclaredFields()) { if (f.getDeclaringClass() == Object.class) { continue; } if (!f.trySetAccessible()) { continue; } var collection = discoveredFieldsP2; if (prioritizedTypes.contains(f.getType())) { collection = discoveredFieldsP1; } collection.put(f.getName(), (Object o) -> { try { return f.get(o); } catch (IllegalAccessException e) { throw new RuntimeException(e); } }); } // for (var m : klass.getFields()) { // if (m.getParameterCount() != 0) { // continue; // } // // var name = m.getName(); // // if (name.startsWith("get") && name.length() > 3 && Character.isUpperCase(name.charAt(4))) { // name = name.substring(3, 3).toLowerCase() + name.substring(4); // } else { // continue; // } // // if (!m.isAccessible()) { // if (!m.trySetAccessible()) // return null; // } // // discoveredFields.put(name, m); // } // System.out.printf("Sorting %s by:%n", klass.getName()); var discoveredFields = new LinkedHashMap<>(discoveredFieldsP1); discoveredFields.putAll(discoveredFieldsP2); List> accessors = new ArrayList<>(); for (String prioritizedKey : prioritizedKeys) { var m = discoveredFields.remove(prioritizedKey); if (m == null) { continue; } accessors.add(m); // System.out.println(" + " + prioritizedKey); } accessors.addAll(discoveredFields.values()); // discoveredFields.keySet().forEach((s)-> System.out.println(" - " + s)); comparator = (a, b) -> { // if (klass.getName().contains("AcmeServer")) { // System.out.println("AcmeIo.comparable"); // } for (var method : accessors) { var x = method.apply(a); var y = method.apply(b); if (x == null && y == null) { continue; } if (x == null) { return -1; } else if (y == null) { return 1; } else { var res = x.toString().compareTo(y.toString()); if (res != 0) { return res; } } } return 0; }; comparators.put(klass, comparator); return comparator; } static record DbObject2(String type, Object data) { } public void dump(File file, Collection factHandles, Function filter) throws IOException { FileUtil.createMissingParentDirectories(file); var facts = new TreeMap, FactCollection>(Comparator.comparing(Class::getName)); logger.info("The fact database has {} entries", factHandles.size()); for (var handle : factHandles) { if (handle instanceof InternalFactHandle h) { var obj = h.getObject(); if (!filter.apply(obj)) { continue; } Class type = obj.getClass(); var collection = facts.get(type); if (collection == null) { collection = new FactCollection(type); facts.put(type, collection); } collection.values.add(obj); } else { logger.warn("Not a known FactHandle type when dumping fact: {}", handle.toExternalForm()); } } logger.info("Outputting {} facts", facts.size()); var objects = new ArrayList(facts.size()); for (var e : facts.entrySet()) { var name = e.getKey().getName(); var collection = e.getValue(); collection.sort(); for (var fact : collection.values) { objects.add(new DbObject2(name, fact)); } } objects.sort(new DbObjectComparator()); var factory = mapper.getFactory(); try (var writer = new FileWriter(file); var g = factory.createGenerator(writer)) { g.writeObject(objects); } } private class DbClassLoader extends ClassLoader { private final KieContainer container; private final KieBase kieBase; public DbClassLoader(KieContainer container, KieBase kieBase) { this.container = container; this.kieBase = kieBase; } @Override public Class loadClass(String name) throws ClassNotFoundException { logger.info("Loading class {}", name); try { var klass = super.loadClass(name); logger.info("Found class in super classloader"); return klass; } catch (ClassNotFoundException e) { var i = name.lastIndexOf('.'); String pkg, simpleName; if (i == -1) { pkg = null; simpleName = name; } else { pkg = name.substring(0, i); simpleName = name.substring(i + 1); } try { var klass = container.getClassLoader().loadClass(name); logger.info("Found class in container's classloader"); return klass; } catch (ClassNotFoundException ignore) { } try { logger.info("pkg = {}", pkg); logger.info("simpleName = {}", simpleName); var clazz = kieBase.getFactType(pkg, simpleName); if (clazz != null) { logger.info("Found class as a FactType"); return clazz.getFactClass(); } } catch (UnsupportedOperationException ignore) { } logger.warn("Class not found: {}", name); throw e; } } } private static class DbObjectComparator implements Comparator { private final List prioritizedPackages = List.of( "io.trygvis.rules.machine", "io.trygvis.rules.network", "io.trygvis.rules.dns", "io.trygvis.rules.dba", "io.trygvis.rules", "io.trygvis.rules.core"); @Override public int compare(DbObject2 a, DbObject2 b) { var indexA = a.type.lastIndexOf("."); String packageA = indexA == -1 ? null : a.type.substring(0, indexA); String classA = indexA == -1 ? a.type : a.type.substring(indexA + 1); var indexB = b.type.lastIndexOf("."); String packageB = indexB == -1 ? null : b.type.substring(0, indexB); String classB = indexB == -1 ? b.type : b.type.substring(indexB + 1); var priIdxA = prioritizedPackages.indexOf(packageA); var priIdxB = prioritizedPackages.indexOf(packageB); if (priIdxA == -1 && priIdxB == -1) { return classB.compareTo(classA); } else if (priIdxA == -1) { return 1; } else if (priIdxB == -1) { return -1; } return priIdxA - priIdxB; // var diff = priIdxB - priIdxA; // if (diff != 0) { // return diff; // } // // return classB.compareTo(classA); } } }