diff options
Diffstat (limited to 'src/main/java')
-rwxr-xr-x | src/main/java/io/trygvis/queue/AsyncService.java | 27 | ||||
-rw-r--r-- | src/main/java/io/trygvis/queue/JdbcAsyncService.java | 253 | ||||
-rwxr-xr-x | src/main/java/io/trygvis/queue/Queue.java | 28 | ||||
-rw-r--r-- | src/main/java/io/trygvis/queue/QueueDao.java | 36 | ||||
-rwxr-xr-x | src/main/java/io/trygvis/queue/Task.java | 60 | ||||
-rw-r--r-- | src/main/java/io/trygvis/queue/TaskDao.java | 54 | ||||
-rw-r--r-- | src/main/java/io/trygvis/queue/TaskRef.java | 23 | ||||
-rw-r--r-- | src/main/java/io/trygvis/queue/spring/QueueSpringConfig.java | 17 |
8 files changed, 498 insertions, 0 deletions
diff --git a/src/main/java/io/trygvis/queue/AsyncService.java b/src/main/java/io/trygvis/queue/AsyncService.java new file mode 100755 index 0000000..c12f794 --- /dev/null +++ b/src/main/java/io/trygvis/queue/AsyncService.java @@ -0,0 +1,27 @@ +package io.trygvis.queue; + +import java.util.*; + +public interface AsyncService { + + /** + * @param name + * @param interval how often the queue should be polled for missed tasks in seconds. + */ + Queue registerQueue(String name, int interval, AsyncCallable callable); + + void stopQueue(Queue queue); + + Queue getQueue(String name); + + TaskRef schedule(Queue queue, String... args); + + /** + * Polls for a new state of the execution. + */ + Task update(Task ref); + + interface AsyncCallable { + void run(List<String> args) throws Exception; + } +} diff --git a/src/main/java/io/trygvis/queue/JdbcAsyncService.java b/src/main/java/io/trygvis/queue/JdbcAsyncService.java new file mode 100644 index 0000000..ae6cdba --- /dev/null +++ b/src/main/java/io/trygvis/queue/JdbcAsyncService.java @@ -0,0 +1,253 @@ +package io.trygvis.queue; + +import org.slf4j.*; +import org.springframework.jdbc.core.*; +import org.springframework.transaction.*; +import org.springframework.transaction.annotation.*; +import org.springframework.transaction.support.*; + +import javax.sql.*; +import java.util.*; +import java.util.concurrent.*; + +import static java.util.Arrays.*; +import static java.util.concurrent.TimeUnit.*; +import static org.springframework.transaction.annotation.Propagation.MANDATORY; +import static org.springframework.transaction.support.TransactionSynchronizationManager.registerSynchronization; + +public class JdbcAsyncService implements AsyncService { + private final Logger log = LoggerFactory.getLogger(getClass()); + + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(10, Executors.defaultThreadFactory()); + + private final Map<String, QueueThread> queues = new HashMap<>(); + + private final TransactionTemplate transactionTemplate; + + private final QueueDao queueDao; + + private final TaskDao taskDao; + + /** + * Accessed from all the queue threads. + */ + private final Map<Long, TaskRef> taskRefs = Collections.synchronizedMap(new WeakHashMap<Long, TaskRef>()); + + public JdbcAsyncService(DataSource dataSource, PlatformTransactionManager transactionManager) { + this.transactionTemplate = new TransactionTemplate(transactionManager); + this.queueDao = new QueueDao(new JdbcTemplate(dataSource)); + this.taskDao = new TaskDao(new JdbcTemplate(dataSource)); + } + + @Transactional(propagation = MANDATORY) + public Queue registerQueue(final String name, final int interval, AsyncCallable callable) { + log.info("registerQueue: ENTER"); + + Queue q = queueDao.findByName(name); + + log.info("q = {}", q); + + final long interval_; + if (q == null) { + q = new Queue(name, interval * 1000); + queueDao.insert(q); + interval_ = interval; + } else { + // Found an existing queue. Use the Settings from the database. + interval_ = q.interval; + } + + final QueueThread queueThread = new QueueThread(q, callable); + queues.put(name, queueThread); + + registerSynchronization(new TransactionSynchronizationAdapter() { + public void afterCompletion(int status) { + log.info("status = {}", status); + if (status == TransactionSynchronization.STATUS_COMMITTED) { + executor.scheduleAtFixedRate(new Runnable() { + public void run() { + queueThread.ping(); + } + }, 1000, 1000 * interval_, MILLISECONDS); +// Thread thread = new Thread(queueThread, name); +// thread.setDaemon(true); +// thread.start(); + queueThread.start(); + } + } + }); + + log.info("registerQueue: LEAVE"); + return q; + } + + public void stopQueue(Queue queue) { + QueueThread queueThread = queues.get(queue.name); + + if (queueThread == null) { + throw new RuntimeException("No such queue: '" + queue.name + "'."); + } + + queueThread.shutdown(); + } + + public Queue getQueue(String name) { + QueueThread queueThread = queues.get(name); + + if (queueThread == null) { + throw new RuntimeException("No such queue: '" + name + "'."); + } + + return queueThread.queue; + } + + @Transactional(propagation = MANDATORY) + public TaskRef schedule(Queue queue, String... args) { + log.info("schedule: ENTER"); + + Date scheduled = new Date(); + + StringBuilder arguments = new StringBuilder(); + for (String arg : args) { + arguments.append(arg).append(' '); + } + + long id = taskDao.insert(queue.name, scheduled, arguments.toString()); + Task task = new Task(id, queue.name, scheduled, null, 0, null, asList(args)); + log.info("task = {}", task); + queues.get(queue.name).ping(); +// try { +// Thread.sleep(500); +// } catch (InterruptedException e) { +// e.printStackTrace(); +// } + + log.info("schedule: LEAVE"); + TaskRef taskRef = new TaskRef(task); + taskRefs.put(task.id, taskRef); + return taskRef; + } + + @Transactional(readOnly = true) + public Task update(Task ref) { + return taskDao.findById(ref.id); + } + + class QueueThread extends Thread { + public boolean shouldRun = true; + + public final Queue queue; + + private final AsyncCallable callable; + + QueueThread(Queue queue, AsyncCallable callable) { + super(queue.name); + this.queue = queue; + this.callable = callable; + } + + public void ping() { + log.info("Sending ping to " + queue); + synchronized (this) { + notify(); + } + } + + public void run() { + while (shouldRun) { + List<Task> tasks = taskDao.findByNameAndCompletedIsNull(queue.name); + + log.info("Found {} tasks on queue {}", tasks.size(), queue.name); + + try { + for (final Task task : tasks) { + try { + executeTask(task); + } catch (TransactionException | TaskFailureException e) { + log.warn("Task execution failed", e); + } + } + } catch (Exception e) { + if (!isInterrupted() && !shouldRun) { + log.warn("Error while executing tasks.", e); + } else { + log.warn("Error because queue was signalled to shut down.", e); + } + } + + synchronized (this) { + try { + wait(); + } catch (InterruptedException e) { + // ignore + } + } + } + + log.info("Queue has stopped"); + + synchronized (this) { + this.notify(); + } + } + + private void executeTask(final Task task) { + final Date run = new Date(); + log.info("Setting last run on task. date = {}, task = {}", run, task); + transactionTemplate.execute(new TransactionCallbackWithoutResult() { + protected void doInTransactionWithoutResult(TransactionStatus status) { + taskDao.update(task.registerRun()); + } + }); + + transactionTemplate.execute(new TransactionCallbackWithoutResult() { + protected void doInTransactionWithoutResult(TransactionStatus status) { + Task t; + + try { + callable.run(task.arguments); + Date completed = new Date(); + t = task.registerComplete(completed); + log.info("Completed task: {}", t); + taskDao.update(t); + } catch (Exception e) { + throw new TaskFailureException(e); + } + + TaskRef taskRef = taskRefs.get(task.id); + + if (taskRef != null) { + log.info("Notifying listeners on task: {}", t); + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (taskRef) { + taskRef.notifyAll(); + } + } + } + }); + } + + public void shutdown() { + log.info("Shutting down queue"); + shouldRun = false; + synchronized (this) { + this.interrupt(); + } + while (isAlive()) { + synchronized (this) { + try { + this.wait(100); + } catch (InterruptedException e) { + // ignore + } + } + } + } + } + + private static class TaskFailureException extends RuntimeException { + public TaskFailureException(Exception e) { + super(e); + } + } +} diff --git a/src/main/java/io/trygvis/queue/Queue.java b/src/main/java/io/trygvis/queue/Queue.java new file mode 100755 index 0000000..7e1fbff --- /dev/null +++ b/src/main/java/io/trygvis/queue/Queue.java @@ -0,0 +1,28 @@ +package io.trygvis.queue; + +/** + * TODO: concurrency control. min/max number of consumers. + * TODO: check interval. How often should the fallback check the queue. Resolution in seconds. + */ +public class Queue<A, B> { + + public final String name; + + public final long interval; + + public Queue(String name, long interval) { + this.name = name; + this.interval = interval; + } + + public Queue withInterval(int interval) { + return new Queue(name, interval); + } + + public String toString() { + return "Queue{" + + "name='" + name + '\'' + + ", interval=" + interval + + '}'; + } +} diff --git a/src/main/java/io/trygvis/queue/QueueDao.java b/src/main/java/io/trygvis/queue/QueueDao.java new file mode 100644 index 0000000..fcf975f --- /dev/null +++ b/src/main/java/io/trygvis/queue/QueueDao.java @@ -0,0 +1,36 @@ +package io.trygvis.queue; + +import org.springframework.beans.factory.annotation.*; +import org.springframework.jdbc.core.*; +import org.springframework.stereotype.*; + +import java.sql.*; + +import static org.springframework.dao.support.DataAccessUtils.*; + +public class QueueDao { + + private final JdbcTemplate jdbcTemplate; + + public QueueDao(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + } + + public Queue findByName(String name) { + return singleResult(jdbcTemplate.query("SELECT name, interval FROM queue WHERE name=?", new QueueRowMapper(), name)); + } + + public void insert(Queue q) { + jdbcTemplate.update("INSERT INTO queue(name, interval) VALUES(?, ?)", q.name, q.interval); + } + + public void update(Queue q) { + jdbcTemplate.update("UPDATE queue SET interval=? WHERE name=?", q.interval, q.name); + } + + private class QueueRowMapper implements RowMapper<Queue> { + public Queue mapRow(ResultSet rs, int rowNum) throws SQLException { + return new Queue(rs.getString(1), rs.getLong(2)); + } + } +} diff --git a/src/main/java/io/trygvis/queue/Task.java b/src/main/java/io/trygvis/queue/Task.java new file mode 100755 index 0000000..2a061bd --- /dev/null +++ b/src/main/java/io/trygvis/queue/Task.java @@ -0,0 +1,60 @@ +package io.trygvis.queue; + +import java.util.*; + +import static java.util.Collections.unmodifiableList; + +/** + * TODO: next run on failures. A default strategy from the queue, let the task be able to re-schedule itself. + */ +public class Task { + + public final long id; + + public final String queue; + + public final Date scheduled; + + public final Date lastRun; + + public final int runCount; + + public final Date completed; + + public final List<String> arguments; + + Task(long id, String queue, Date scheduled, Date lastRun, int runCount, Date completed, List<String> arguments) { + this.id = id; + this.queue = queue; + this.scheduled = scheduled; + this.lastRun = lastRun; + this.runCount = runCount; + this.completed = completed; + + this.arguments = unmodifiableList(new ArrayList<>(arguments)); + } + + public Task registerRun() { + return new Task(id, queue, scheduled, new Date(), runCount + 1, completed, arguments); + } + + public Task registerComplete(Date completed) { + return new Task(id, queue, scheduled, lastRun, runCount, new Date(), arguments); + } + + public String toString() { + return "Task{" + + "id=" + id + + ", queue=" + queue + + ", scheduled=" + scheduled + + ", lastRun=" + lastRun + + ", runCount=" + runCount + + ", completed=" + completed + + ", arguments='" + arguments + '\'' + + '}'; + } + + public boolean isDone() { + return completed != null; + } +} diff --git a/src/main/java/io/trygvis/queue/TaskDao.java b/src/main/java/io/trygvis/queue/TaskDao.java new file mode 100644 index 0000000..84ce958 --- /dev/null +++ b/src/main/java/io/trygvis/queue/TaskDao.java @@ -0,0 +1,54 @@ +package io.trygvis.queue; + +import org.springframework.jdbc.core.*; + +import java.sql.*; +import java.util.Date; +import java.util.*; + +import static java.util.Arrays.*; + +public class TaskDao { + + private final JdbcTemplate jdbcTemplate; + + public TaskDao(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + } + + public long insert(String queue, Date scheduled, String arguments) { + jdbcTemplate.update("INSERT INTO task(id, run_count, queue, scheduled, arguments) " + + "VALUES(nextval('task_seq'), 0, ?, ?, ?)", queue, scheduled, arguments); + return jdbcTemplate.queryForObject("SELECT currval('task_seq')", Long.class); + } + + public Task findById(long id) { + return jdbcTemplate.queryForObject("SELECT " + TaskRowMapper.fields + " FROM task WHERE id=?", + new TaskRowMapper(), id); + } + + public List<Task> findByNameAndCompletedIsNull(String name) { + return jdbcTemplate.query("SELECT " + TaskRowMapper.fields + " FROM task WHERE queue=? AND completed IS NULL", + new TaskRowMapper(), name); + } + + public void update(Task task) { + jdbcTemplate.update("UPDATE task SET scheduled=?, last_run=?, run_count=?, completed=? WHERE id=?", + task.scheduled, task.lastRun, task.runCount, task.completed, task.id); + } + + private class TaskRowMapper implements RowMapper<Task> { + public static final String fields = "id, queue, scheduled, last_run, run_count, completed, arguments"; + + public Task mapRow(ResultSet rs, int rowNum) throws SQLException { + return new Task( + rs.getLong(1), + rs.getString(2), + rs.getTimestamp(3), + rs.getTimestamp(4), + rs.getInt(5), + rs.getTimestamp(6), + asList(rs.getString(7).split(" "))); + } + } +} diff --git a/src/main/java/io/trygvis/queue/TaskRef.java b/src/main/java/io/trygvis/queue/TaskRef.java new file mode 100644 index 0000000..e465ee7 --- /dev/null +++ b/src/main/java/io/trygvis/queue/TaskRef.java @@ -0,0 +1,23 @@ +package io.trygvis.queue; + +/** + * TODO: Implement waitForSuccess(). Rename waitFor to waitForAny(). Let waitForAny() throw TargetInvocationException + * which contains the exception that the task threw. + */ +public class TaskRef { + private Task task; + + public TaskRef(Task task) { + this.task = task; + } + + public Task getTask() { + return task; + } + + public void waitFor(int timeout) throws InterruptedException { + synchronized (this) { + wait(timeout); + } + } +} diff --git a/src/main/java/io/trygvis/queue/spring/QueueSpringConfig.java b/src/main/java/io/trygvis/queue/spring/QueueSpringConfig.java new file mode 100644 index 0000000..0eb7695 --- /dev/null +++ b/src/main/java/io/trygvis/queue/spring/QueueSpringConfig.java @@ -0,0 +1,17 @@ +package io.trygvis.queue.spring; + +import io.trygvis.queue.*; +import org.springframework.context.annotation.*; +import org.springframework.transaction.*; +import org.springframework.transaction.annotation.*; + +import javax.sql.*; + +@Configuration +@EnableTransactionManagement +public class QueueSpringConfig { + @Bean + public AsyncService asyncService(DataSource dataSource, PlatformTransactionManager platformTransactionManager) { + return new JdbcAsyncService(dataSource, platformTransactionManager); + } +} |