summaryrefslogtreecommitdiff
path: root/src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java')
-rw-r--r--src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java172
1 files changed, 172 insertions, 0 deletions
diff --git a/src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java b/src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java
new file mode 100644
index 0000000..f2247d3
--- /dev/null
+++ b/src/ri-engine/src/main/java/io/trygvis/rules/engine/Engine.java
@@ -0,0 +1,172 @@
+package io.trygvis.rules.engine;
+
+import org.drools.core.audit.WorkingMemoryConsoleLogger;
+import org.drools.core.base.MapGlobalResolver;
+import org.drools.reflective.classloader.ProjectClassLoader;
+import org.kie.api.KieServices;
+import org.kie.api.event.rule.AgendaEventListener;
+import org.kie.api.event.rule.RuleRuntimeEventListener;
+import org.kie.api.io.Resource;
+import org.kie.api.runtime.KieContainer;
+import org.kie.api.runtime.KieSession;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+import java.io.Closeable;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+
+public class Engine implements Closeable {
+ @SuppressWarnings("FieldCanBeLocal")
+ private final Logger logger = LoggerFactory.getLogger(getClass());
+
+ public final String name;
+ @Nullable
+ public final File output;
+ public final DbIo io;
+ public final KieSession session;
+
+ public Engine(String name, File[] databases, @Nullable File output, String[] agendaGroups, File[] modules)
+ throws IOException {
+ this.name = name;
+ this.output = output;
+
+ logger.info("Getting KieServices");
+
+ var services = KieServices.Factory.get();
+
+ var kieRepository = services.getRepository();
+
+ KieContainer container;
+ TemplateLoader templateLoader;
+ if (modules != null && modules.length > 0) {
+ List<Resource> resources = new ArrayList<>();
+ List<URL> files = new ArrayList<>();
+ for (File path : modules) {
+ if (!path.exists()) {
+ logger.warn("Module path does not exist: {}", path.getAbsolutePath());
+ continue;
+ }
+
+ logger.info("New KieBuilder: {}, file={}, directory={}", path, path.isFile(), path.isDirectory());
+
+ if (path.isFile()) {
+ files.add(path.toURI().toURL());
+ }
+
+ var resource = services.getResources().newFileSystemResource(path);
+ logger.info("resource.getResourceType() = {}", resource.getResourceType());
+ resources.add(resource);
+ }
+
+ var module = kieRepository.addKieModule(resources.get(0), resources.subList(1, resources.size()).toArray(new Resource[0]));
+ logger.info("module.getReleaseId() = {}", module.getReleaseId());
+ var rId = module.getReleaseId();
+
+ logger.info("Creating classpath container, releaseId=" + rId);
+ container = services.newKieContainer(rId);
+
+ templateLoader = new ClasspathTemplateLoader(new URLClassLoader(files.toArray(new URL[0])));
+ } else {
+ var classLoader = ProjectClassLoader.findParentClassLoader();
+ container = services.getKieClasspathContainer(classLoader);
+ templateLoader = new ClasspathTemplateLoader(classLoader);
+ }
+
+ logger.info("Creating KieBase \"{}\"", name);
+ logger.info("Available kie base names: {}", container.getKieBaseNames());
+ var kieBase = container.getKieBase(name);
+
+ session = container.newKieSession(name);
+
+ var l = new WorkingMemoryConsoleLogger(session);
+ session.addEventListener((AgendaEventListener) l);
+ session.addEventListener((RuleRuntimeEventListener) l);
+
+ session.getGlobals().setDelegate(new EngineGlobalResolver(templateLoader));
+
+ logger.info("Loading data");
+ io = new DbIo(container, kieBase);
+
+ List<Object> allObjects = new ArrayList<>();
+ for (File database : databases) {
+ var objects = io.load(database);
+
+ if (objects.isEmpty()) {
+ logger.warn("Did not load any objects, something is wrong");
+ return;
+ }
+
+ logger.info("Loaded {} objects from {}", objects.size(), database);
+ allObjects.addAll(objects);
+ }
+ logger.info("Loaded {} objects", allObjects.size());
+
+ for (var object : allObjects) {
+ logger.info("object = " + object);
+ session.insert(object);
+ }
+
+ for (var agendaGroup : agendaGroups) {
+ logger.info("Setting agenda: " + agendaGroup);
+ session.getAgenda().getAgendaGroup(agendaGroup).setFocus();
+ session.fireAllRules();
+ }
+ }
+
+ @Override
+ public void close() {
+ session.dispose();
+ }
+
+ private static class ClasspathTemplateLoader implements TemplateLoader {
+ private final ClassLoader classLoader;
+
+ private ClasspathTemplateLoader(ClassLoader classLoader) {
+ this.classLoader = classLoader;
+ }
+
+ @Override
+ public String load(String name) throws IOException {
+ var resource = "templates/" + name + ".j2";
+
+ try (var inputStream = classLoader.getResourceAsStream(resource)) {
+ if (inputStream == null) {
+ throw new FileNotFoundException("Classpath resource: " + resource);
+ }
+
+ return new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
+ }
+ }
+ }
+
+ private class EngineGlobalResolver extends MapGlobalResolver {
+ private final TemplateLoader templateLoader;
+
+ public EngineGlobalResolver() {
+ templateLoader = null;
+ }
+
+ public EngineGlobalResolver(TemplateLoader templateLoader) {
+ this.templateLoader = templateLoader;
+ }
+
+ @Override
+ public Object resolveGlobal(String identifier) {
+ if ("te".equals(identifier)) {
+ if (output == null) {
+ throw new IllegalArgumentException("An instance of the TemplateEngine is required, but this job is not configured with a output directory.");
+ }
+ return new JinjavaTemplateEngine(templateLoader, output);
+ }
+ return super.resolveGlobal(identifier);
+ }
+ }
+}