package io.trygvis.rules.engine;

import ch.qos.logback.core.util.FileUtil;
import com.fasterxml.jackson.annotation.ObjectIdGenerators;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyName;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.introspect.Annotated;
import com.fasterxml.jackson.databind.introspect.JacksonAnnotationIntrospector;
import com.fasterxml.jackson.databind.introspect.ObjectIdInfo;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
import org.drools.core.common.DefaultFactHandle;
import org.drools.core.factmodel.GeneratedFact;
import org.kie.api.KieBase;
import org.kie.api.runtime.rule.FactHandle;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.function.Function;

/**
 * Reads and writes collections of fact objects as YAML files.
 *
 * <p>Classes are resolved through an {@link AcmeClassLoader} so that fact types declared in the
 * {@link KieBase} can be loaded by name in addition to ordinary classes.
 */
public class DbIo {
    private final ObjectMapper mapper;

    /** Property names that are preferred, in this order, as object ids and as leading sort keys. */
    private static final List<String> prioritizedKeys = List.of("key", "name", "fqdn");

    /** Field types whose fields are compared before fields of any other type. */
    private static final List<Class<?>> prioritizedTypes = List.of(String.class, int.class, Number.class);

    /** Cache of the reflective comparators built by {@link #buildComparator(Class)}. */
    private static final Map<Class<?>, Comparator<Object>> comparators = new HashMap<>();

    /**
     * Creates a YAML mapper whose type factory resolves classes through the given KIE base.
     *
     * @param kieBase the KIE base used to look up declared fact types by package and name
     */
    public DbIo(KieBase kieBase) {
        var factory = new YAMLFactory();
        factory.enable(YAMLGenerator.Feature.USE_NATIVE_TYPE_ID);
        factory.enable(YAMLGenerator.Feature.USE_NATIVE_OBJECT_ID);
        mapper = new ObjectMapper(factory);
        mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
        var typeFactory = TypeFactory.defaultInstance()
                .withClassLoader(new AcmeClassLoader(kieBase));
        mapper.setTypeFactory(typeFactory);
        mapper.findAndRegisterModules();

        mapper.setAnnotationIntrospector(new JacksonAnnotationIntrospector() {
            /**
             * For {@link GeneratedFact} types, uses the first of {@code prioritizedKeys} that has a
             * matching public getter as the object id property; falls back to an integer sequence.
             * All other types get the default behaviour.
             */
            @Override
            public ObjectIdInfo findObjectIdInfo(Annotated a) {
                final Class<?> klass = a.getRawType();
                if (!GeneratedFact.class.isAssignableFrom(klass)) {
                    return super.findObjectIdInfo(a);
                }

                System.out.println("klass = " + klass);
                for (String name : prioritizedKeys) {
                    final String getter = "get" + name.substring(0, 1).toUpperCase(Locale.ROOT) + name.substring(1);
                    try {
                        // Only the existence of the getter matters; the Method itself is not used.
                        klass.getMethod(getter);
                        return new ObjectIdInfo(PropertyName.construct(name), null, ObjectIdGenerators.PropertyGenerator.class, null);
                    } catch (NoSuchMethodException ignored) {
                        // This key is not a property of the class; try the next candidate.
                    }
                }

                System.out.println("a.getRawType() = " + klass);
                return new ObjectIdInfo(null, null, ObjectIdGenerators.IntSequenceGenerator.class, null);
            }
        });
    }

    /**
     * Loads all objects stored in the given YAML file.
     *
     * <p>Each entry is converted to the class named by its {@code type}. If the conversion yields
     * {@code null}, an instance is created with the no-argument constructor instead. Entries whose
     * class cannot be found or instantiated are skipped; a message is written to standard error.
     *
     * @param file path of the YAML file to read
     * @return the successfully loaded objects, in file order
     * @throws IOException if the file cannot be read or parsed
     */
    public List<Object> load(String file) throws IOException {
        List<DbObject> objects;
        // The parser is created here, so it must be closed here as well.
        try (var parser = mapper.getFactory().createParser(new File(file))) {
            objects = mapper.readValues(parser, DbObject.class).readAll(new ArrayList<DbObject>());
        }

        List<Object> items = new ArrayList<>(objects.size());
        for (DbObject object : objects) {
            try {
                Class<?> type = mapper.getTypeFactory().findClass(object.type);
                Object value = mapper.treeToValue(object.data, type);
                if (value == null) {
                    value = type.getDeclaredConstructor().newInstance();
                }
                items.add(value);
            } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) {
                // Loading is best-effort, but do not hide which entries were dropped.
                System.err.println("Skipping object of type " + object.type + ": " + e);
            }
        }

        return items;
    }

    /**
     * Dumps the objects of all given fact handles to {@code out/<s>.yaml}.
     *
     * @param s           base name of the output file, without extension
     * @param factHandles the fact handles whose objects should be written
     * @throws IOException if the file cannot be written
     */
    public void dump(String s, Collection<? extends FactHandle> factHandles) throws IOException {
        dump(s, factHandles, (o) -> true);
    }

    /**
     * All facts of a single class, sortable into a deterministic order.
     *
     * <p>TODO: this should just sort by all getters instead.
     */
    static class FactCollection {
        public final Class<?> type;
        public final List<Object> values;

        public FactCollection(Class<?> type) {
            this.type = type;
            this.values = new ArrayList<>();
        }

        /** Sorts {@link #values} with the field-based comparator for {@link #type}. */
        public void sort() {
            this.values.sort(comparable(type));
        }
    }

    /**
     * Returns the cached field-based comparator for the given class, building it on first use.
     */
    private static Comparator<Object> comparable(Class<?> klass) {
        return comparators.computeIfAbsent(klass, DbIo::buildComparator);
    }

    /**
     * Builds a comparator that compares instances of {@code klass} field by field.
     *
     * <p>Only fields declared directly on {@code klass} are used. Fields named in
     * {@code prioritizedKeys} come first, then fields whose type is listed in
     * {@code prioritizedTypes}, then the rest. Values are compared by their {@code toString()}
     * representation; {@code null} sorts before non-null.
     */
    private static Comparator<Object> buildComparator(Class<?> klass) {
        // TODO: check if klass is a Comparable directly.

        var preferredFields = new LinkedHashMap<String, Function<Object, Object>>();
        var otherFields = new LinkedHashMap<String, Function<Object, Object>>();

        for (var field : klass.getDeclaredFields()) {
            // A static field has the same value for every instance and can never decide an ordering.
            if (Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            if (!field.trySetAccessible()) {
                continue;
            }

            var target = prioritizedTypes.contains(field.getType()) ? preferredFields : otherFields;
            target.put(field.getName(), (Object o) -> {
                try {
                    return field.get(o);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("Unable to read field " + field, e);
                }
            });
        }

        var discoveredFields = new LinkedHashMap<>(preferredFields);
        discoveredFields.putAll(otherFields);

        List<Function<Object, Object>> accessors = new ArrayList<>();
        for (String prioritizedKey : prioritizedKeys) {
            var accessor = discoveredFields.remove(prioritizedKey);
            if (accessor != null) {
                accessors.add(accessor);
            }
        }
        accessors.addAll(discoveredFields.values());

        return (a, b) -> {
            for (var accessor : accessors) {
                var x = accessor.apply(a);
                var y = accessor.apply(b);

                if (x == null && y == null) {
                    continue;
                }
                if (x == null) {
                    return -1;
                }
                if (y == null) {
                    return 1;
                }

                var res = x.toString().compareTo(y.toString());
                if (res != 0) {
                    return res;
                }
            }
            return 0;
        };
    }

    /** A single entry of the dumped file: the fact's class name and the fact itself. */
    record DbObject2(String type, Object data) {
    }

    /**
     * Dumps the objects of the given fact handles that pass {@code filter} to {@code out/<s>.yaml}.
     *
     * <p>Only handles that are {@link DefaultFactHandle} instances are considered. Facts are
     * grouped by class (ordered by class name), sorted within each class, and the resulting list
     * is finally ordered by package priority, see {@link DbObjectComparator}.
     *
     * @param s           base name of the output file, without extension
     * @param factHandles the fact handles whose objects should be written
     * @param filter      returns {@code true} for objects that should be included
     * @throws IOException if the file cannot be written
     */
    public void dump(String s, Collection<? extends FactHandle> factHandles, Function<Object, Boolean> filter) throws IOException {
        var yamlFile = new File("out", s + ".yaml");
        FileUtil.createMissingParentDirectories(yamlFile);

        var facts = new TreeMap<Class<?>, FactCollection>(Comparator.comparing(Class::getName));
        for (var handle : factHandles) {
            if (handle instanceof DefaultFactHandle h) {
                var obj = h.getObject();
                if (!filter.apply(obj)) {
                    continue;
                }

                facts.computeIfAbsent(obj.getClass(), FactCollection::new).values.add(obj);
            }
        }

        var objects = new ArrayList<DbObject2>(facts.size());
        for (var e : facts.entrySet()) {
            var name = e.getKey().getName();
            var collection = e.getValue();
            collection.sort();
            for (var fact : collection.values) {
                objects.add(new DbObject2(name, fact));
            }
        }

        // List.sort is stable, so the per-class ordering established above survives within a package.
        objects.sort(new DbObjectComparator());

        var factory = mapper.getFactory();
        // Write UTF-8 explicitly instead of relying on the platform default charset.
        try (var writer = new FileWriter(yamlFile, StandardCharsets.UTF_8);
             var g = factory.createGenerator(writer)) {
            g.writeObject(objects);
        }
    }

    /**
     * Class loader that falls back to the fact types declared in a {@link KieBase} when a class
     * cannot be loaded the normal way.
     */
    private static class AcmeClassLoader extends ClassLoader {
        private final KieBase kieBase;

        public AcmeClassLoader(KieBase kieBase) {
            this.kieBase = kieBase;
        }

        /**
         * @throws ClassNotFoundException if neither the regular lookup nor the KIE base knows the class
         */
        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException {
            try {
                return super.loadClass(name);
            } catch (ClassNotFoundException e) {
                var i = name.lastIndexOf('.');
                String pkg = i == -1 ? null : name.substring(0, i);
                String klass = i == -1 ? name : name.substring(i + 1);

                var factType = kieBase.getFactType(pkg, klass);
                if (factType == null) {
                    throw e;
                }
                return factType.getFactClass();
            }
        }
    }

    /**
     * Orders dumped objects by the priority of their package.
     *
     * <p>Objects in a prioritized package come first, in the order the packages are listed.
     * Objects from two non-prioritized packages are ordered by simple class name, descending.
     */
    private static class DbObjectComparator implements Comparator<DbObject2> {
        private final List<String> prioritizedPackages = List.of(
                "io.trygvis.rules.machine",
                "io.trygvis.rules.network",
                "io.trygvis.rules.dns",
                "io.trygvis.rules.dba",
                "io.trygvis.rules",
                "io.trygvis.rules.core");

        @Override
        public int compare(DbObject2 a, DbObject2 b) {
            var priIdxA = prioritizedPackages.indexOf(packageOf(a.type()));
            var priIdxB = prioritizedPackages.indexOf(packageOf(b.type()));

            if (priIdxA == -1 && priIdxB == -1) {
                return simpleNameOf(b.type()).compareTo(simpleNameOf(a.type()));
            } else if (priIdxA == -1) {
                return 1;
            } else if (priIdxB == -1) {
                return -1;
            }

            return Integer.compare(priIdxA, priIdxB);
        }

        /**
         * Returns the package part of a class name, or the empty string if there is none.
         * The empty string is used rather than {@code null} because lists created with
         * {@code List.of} reject {@code null} in {@code indexOf}.
         */
        private static String packageOf(String type) {
            var index = type.lastIndexOf('.');
            return index == -1 ? "" : type.substring(0, index);
        }

        /** Returns the class name without its package. */
        private static String simpleNameOf(String type) {
            var index = type.lastIndexOf('.');
            return index == -1 ? type : type.substring(index + 1);
        }
    }
}