package io.trygvis.queue; import org.slf4j.*; import org.springframework.jdbc.core.*; import org.springframework.transaction.*; import org.springframework.transaction.annotation.*; import org.springframework.transaction.support.*; import javax.sql.*; import java.util.*; import java.util.concurrent.*; import static java.util.Arrays.*; import static java.util.concurrent.TimeUnit.*; import static org.springframework.transaction.annotation.Propagation.MANDATORY; import static org.springframework.transaction.support.TransactionSynchronizationManager.registerSynchronization; public class JdbcAsyncService implements AsyncService { private final Logger log = LoggerFactory.getLogger(getClass()); private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(10, Executors.defaultThreadFactory()); private final Map queues = new HashMap<>(); private final TransactionTemplate transactionTemplate; private final QueueDao queueDao; private final TaskDao taskDao; /** * Accessed from all the queue threads. */ private final Map taskRefs = Collections.synchronizedMap(new WeakHashMap()); public JdbcAsyncService(DataSource dataSource, PlatformTransactionManager transactionManager) { this.transactionTemplate = new TransactionTemplate(transactionManager); this.queueDao = new QueueDao(new JdbcTemplate(dataSource)); this.taskDao = new TaskDao(new JdbcTemplate(dataSource)); } @Transactional(propagation = MANDATORY) public Queue registerQueue(final String name, final int interval, AsyncCallable callable) { log.info("registerQueue: ENTER"); Queue q = queueDao.findByName(name); log.info("q = {}", q); final long interval_; if (q == null) { q = new Queue(name, interval * 1000); queueDao.insert(q); interval_ = interval; } else { // Found an existing queue. Use the Settings from the database. interval_ = q.interval; } final QueueThread queueThread = new QueueThread(q, callable); queues.put(name, queueThread); registerSynchronization(new TransactionSynchronizationAdapter() { public void afterCompletion(int status) { log.info("status = {}", status); if (status == TransactionSynchronization.STATUS_COMMITTED) { executor.scheduleAtFixedRate(new Runnable() { public void run() { queueThread.ping(); } }, 1000, 1000 * interval_, MILLISECONDS); // Thread thread = new Thread(queueThread, name); // thread.setDaemon(true); // thread.start(); queueThread.start(); } } }); log.info("registerQueue: LEAVE"); return q; } public void stopQueue(Queue queue) { QueueThread queueThread = queues.get(queue.name); if (queueThread == null) { throw new RuntimeException("No such queue: '" + queue.name + "'."); } queueThread.shutdown(); } public Queue getQueue(String name) { QueueThread queueThread = queues.get(name); if (queueThread == null) { throw new RuntimeException("No such queue: '" + name + "'."); } return queueThread.queue; } @Transactional(propagation = MANDATORY) public TaskRef schedule(Queue queue, String... args) { log.info("schedule: ENTER"); Date scheduled = new Date(); StringBuilder arguments = new StringBuilder(); for (String arg : args) { arguments.append(arg).append(' '); } long id = taskDao.insert(queue.name, scheduled, arguments.toString()); Task task = new Task(id, queue.name, scheduled, null, 0, null, asList(args)); log.info("task = {}", task); queues.get(queue.name).ping(); // try { // Thread.sleep(500); // } catch (InterruptedException e) { // e.printStackTrace(); // } log.info("schedule: LEAVE"); TaskRef taskRef = new TaskRef(task); taskRefs.put(task.id, taskRef); return taskRef; } @Transactional(readOnly = true) public Task update(Task ref) { return taskDao.findById(ref.id); } class QueueThread extends Thread { public boolean shouldRun = true; public final Queue queue; private final AsyncCallable callable; QueueThread(Queue queue, AsyncCallable callable) { super(queue.name); this.queue = queue; this.callable = callable; } public void ping() { log.info("Sending ping to " + queue); synchronized (this) { notify(); } } public void run() { while (shouldRun) { List tasks = taskDao.findByNameAndCompletedIsNull(queue.name); log.info("Found {} tasks on queue {}", tasks.size(), queue.name); try { for (final Task task : tasks) { try { executeTask(task); } catch (TransactionException | TaskFailureException e) { log.warn("Task execution failed", e); } } } catch (Exception e) { if (!isInterrupted() && !shouldRun) { log.warn("Error while executing tasks.", e); } else { log.warn("Error because queue was signalled to shut down.", e); } } synchronized (this) { try { wait(); } catch (InterruptedException e) { // ignore } } } log.info("Queue has stopped"); synchronized (this) { this.notify(); } } private void executeTask(final Task task) { final Date run = new Date(); log.info("Setting last run on task. date = {}, task = {}", run, task); transactionTemplate.execute(new TransactionCallbackWithoutResult() { protected void doInTransactionWithoutResult(TransactionStatus status) { taskDao.update(task.registerRun()); } }); transactionTemplate.execute(new TransactionCallbackWithoutResult() { protected void doInTransactionWithoutResult(TransactionStatus status) { Task t; try { callable.run(task.arguments); Date completed = new Date(); t = task.registerComplete(completed); log.info("Completed task: {}", t); taskDao.update(t); } catch (Exception e) { throw new TaskFailureException(e); } TaskRef taskRef = taskRefs.get(task.id); if (taskRef != null) { log.info("Notifying listeners on task: {}", t); //noinspection SynchronizationOnLocalVariableOrMethodParameter synchronized (taskRef) { taskRef.notifyAll(); } } } }); } public void shutdown() { log.info("Shutting down queue"); shouldRun = false; synchronized (this) { this.interrupt(); } while (isAlive()) { synchronized (this) { try { this.wait(100); } catch (InterruptedException e) { // ignore } } } } } private static class TaskFailureException extends RuntimeException { public TaskFailureException(Exception e) { super(e); } } }