package io.trygvis.container.compiler; import org.apache.commons.io.IOUtils; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import javax.tools.Diagnostic; import javax.tools.DiagnosticCollector; import javax.tools.JavaCompiler; import javax.tools.JavaFileObject; import javax.tools.StandardJavaFileManager; import javax.tools.ToolProvider; import java.io.IOException; import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; import static java.util.Arrays.asList; import static javax.tools.JavaCompiler.CompilationTask; import static org.fest.assertions.Assertions.assertThat; public class ProcessorTest { Charset UTF_8 = Charset.forName("utf-8"); @BeforeMethod private void before() { } @DataProvider(name = "data", parallel = true) public static Object[][] data() { return new Object[][]{new Object[]{ new String[]{ "io.trygvis.persistence.test.basic.package-info", "io.trygvis.persistence.test.basic.Person", "io.trygvis.persistence.test.basic.ParentEntity", "io.trygvis.persistence.test.basic.ChildEntity", }, new String[]{ "io.trygvis.persistence.test.basic.Sequences", "io.trygvis.persistence.test.basic.BasicSqlSession", "io.trygvis.persistence.test.basic.BasicSqlSessionFactory", "io.trygvis.persistence.test.basic.PersonDao", "io.trygvis.persistence.test.basic.PersonRow", "io.trygvis.persistence.test.basic.ChildEntityDao", "io.trygvis.persistence.test.basic.ChildEntityRow", } }, new Object[]{ new String[]{ "io.trygvis.persistence.test.inheritance.package-info", "io.trygvis.persistence.test.inheritance.A",}, new String[]{ "io.trygvis.persistence.test.inheritance.Sequences", "io.trygvis.persistence.test.inheritance.InheritanceSqlSession", "io.trygvis.persistence.test.inheritance.InheritanceSqlSessionFactory", "io.trygvis.persistence.test.inheritance.DDao", "io.trygvis.persistence.test.inheritance.DRow", } }, }; } @Test(dataProvider = "data") public void testBasic(String[] files, String[] classes) throws Exception { JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); DiagnosticCollector collector = new DiagnosticCollector<>(); StandardJavaFileManager standardFileManager = compiler.getStandardFileManager(collector, Locale.ENGLISH, UTF_8); InMemoryJavaFileManager fileManager = new InMemoryJavaFileManager(standardFileManager); List sources = new ArrayList<>(); for (String file : files) { sources.add(loadJava(file)); } CompilationTask task = compiler.getTask(null, fileManager, collector, null, null, sources); task.setProcessors(asList(new MyProcessor())); boolean result = task.call(); if (!result) { for (Diagnostic diagnostic : collector.getDiagnostics()) { JavaFileObject source = diagnostic.getSource(); String error = ""; if (source != null) { error += source.toUri().getPath(); } error += ":" + diagnostic.getLineNumber() + ":" + diagnostic.getColumnNumber(); System.out.println(error + ": " + diagnostic.getMessage(Locale.ENGLISH)); } } for (Map.Entry entry : fileManager.codes.entrySet()) { System.out.println("=== " + entry.getKey()); System.out.println(entry.getValue()); } assertThat(fileManager.codes.keySet()).containsOnly((Object[]) classes); assertThat(collector.getDiagnostics()).isEmpty(); assertThat(result).isTrue(); fileManager.close(); } private JavaSourceFromString loadJava(String className) throws IOException { String path = "/" + className.replace('.', '/') + ".java"; URL resource = getClass().getResource(path); if (resource == null) { throw new RuntimeException("Could not load code for: " + path); } return new JavaSourceFromString(className, IOUtils.toString(resource, UTF_8)); } }