package org.nuxeo.ecm.core.redis;
import static org.junit.Assert.assertEquals;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.inject.Inject;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.nuxeo.ecm.core.test.CoreFeature;
import org.nuxeo.ecm.core.work.AbstractWork;
import org.nuxeo.ecm.core.work.api.Work;
import org.nuxeo.ecm.core.work.api.WorkManager;
import org.nuxeo.ecm.core.work.api.WorkQueueMetrics;
import org.nuxeo.runtime.api.Framework;
import org.nuxeo.runtime.test.runner.Features;
import org.nuxeo.runtime.test.runner.FeaturesRunner;

import redis.clients.jedis.Jedis;
@Features({ RedisFeature.class, CoreFeature.class })
@RunWith(FeaturesRunner.class)
public class TestRedisWorkShutdown {
/** Logger used by {@link MyWork} to trace when a work starts waiting. */
static Log log = LogFactory.getLog(TestRedisWorkShutdown.class);
/**
 * Counted down by {@link MyWork#work()} each time a work reaches its wait point; the initial count of 2 matches
 * the two works scheduled by the test, which awaits this latch before triggering the standby.
 */
static CountDownLatch canShutdown = new CountDownLatch(2);
/**
 * Awaited (for at most one minute) by {@link MyWork#work()}; the test releases it once the standby call has
 * returned or failed. Not final: the test replaces it with a fresh latch before resuming the runtime.
 */
static CountDownLatch canProceed = new CountDownLatch(1);
/**
 * Work used to observe the suspend/resume cycle.
 * <p>
 * Its progress starts at 0 out of 2 and is advanced by one on every call to {@link #work()}. While the advanced
 * progress is still below the total, the work signals {@code canShutdown}, blocks on {@code canProceed} and then
 * reports itself as suspended. Once the progress has reached the total, {@link #work()} returns immediately.
 * <p>
 * NOTE(review): the second, non-blocking run presumably happens on the instance re-scheduled after the runtime
 * resumes, which relies on the progress being part of the serialized state - confirm against AbstractWork.
 */
public static class MyWork extends AbstractWork {

    private static final long serialVersionUID = 1L;

    /**
     * Creates the work with the given id and an initial progress of 0 out of 2.
     *
     * @param id the work id, also returned by {@link #toString()}
     */
    MyWork(String id) {
        super(id);
        setProgress(new Progress(0, 2));
    }

    /** @return the fixed title {@code "waiting work"} */
    @Override
    public String getTitle() {
        return "waiting work";
    }

    /**
     * Advances the current progress by one step, keeping the same total, and stores it on this work.
     *
     * @return the newly stored progress
     */
    Progress nextProgress() {
        Progress previous = getProgress();
        Progress advanced = new Progress(previous.getCurrent() + 1, previous.getTotal());
        setProgress(advanced);
        return advanced;
    }

    /**
     * Advances the progress; if the total is not reached yet, waits for the test to release {@code canProceed},
     * asserts that the work is being suspended and marks it as suspended.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if the wait is interrupted (the
     *             interrupt flag is restored first)
     */
    @Override
    public void work() {
        Progress step = nextProgress();
        if (step.getCurrent() >= step.getTotal()) {
            // total reached: nothing left to do on this run
            return;
        }
        try {
            log.debug(id + " waiting for shutdown");
            // tell the test this work is now running and about to block
            canShutdown.countDown();
            canProceed.await(1, TimeUnit.MINUTES);
            Assert.assertTrue(isSuspending());
            suspended();
        } catch (InterruptedException cause) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(cause);
        }
    }

    /** @return the work id */
    @Override
    public String toString() {
        return id;
    }
}
/** Work manager exercised by the test; supplied through {@code @Inject}. */
@Inject
WorkManager works;
/**
 * Asserts that the metrics reported by the work manager for the {@code "default"} queue equal the given counters.
 *
 * @param scheduled expected scheduled counter
 * @param running expected running counter
 * @param completed expected completed counter
 * @param cancelled expected cancelled counter
 */
void assertMetrics(long scheduled, long running, long completed, long cancelled) {
    WorkQueueMetrics expected = new WorkQueueMetrics("default", scheduled, running, completed, cancelled);
    assertEquals(expected, works.getMetrics("default"));
}
/**
 * Schedules two {@link MyWork} instances, puts the runtime in standby while both are blocked, checks through
 * {@link ScheduledRetriever} that two entries are present in the scheduled set of the default queue, then resumes
 * the runtime and waits for the work manager to report completion.
 *
 * @throws InterruptedException if the test thread is interrupted while waiting on a latch or on completion
 */
@Test
public void worksArePersisted() throws InterruptedException {
    assertMetrics(0, 0, 0, 0);
    try {
        // given two running works
        works.schedule(new MyWork("first"));
        works.schedule(new MyWork("second"));
        // await() returns false on timeout: fail here with an explicit message instead of letting the
        // metrics assertion below report a confusing mismatch
        Assert.assertTrue("works did not reach their wait point within 10 seconds",
                canShutdown.await(10, TimeUnit.SECONDS));
        assertMetrics(0, 2, 0, 0);
        // when I shutdown
        Framework.getRuntime().standby(Instant.now().plus(Duration.ofSeconds(10)));
    } finally {
        // then works are suspending
        // (released in a finally so blocked works never wait out their full minute when the test fails)
        canProceed.countDown();
    }
    // then works are re-scheduled
    try {
        List<Work> scheduled = new ScheduledRetriever().listScheduled();
        assertEquals(2, scheduled.size());
        // fresh latch for anything that runs after the resume
        canProceed = new CountDownLatch(1);
    } finally {
        // when I reboot
        // (always resume, even if the check above failed, so the runtime is not left in standby)
        Framework.getRuntime().resume();
    }
    Assert.assertTrue(works.awaitCompletion(10, TimeUnit.SECONDS));
    // works are completed
    assertMetrics(0, 0, 2, 2);
}
/**
 * Reads the scheduled works of the {@code default} queue straight from Redis, bypassing the work manager.
 * <p>
 * Keys are built by appending a suffix to the prefix returned by {@code RedisAdmin.namespace("work")}: the
 * {@code sched:default} set holds work ids and the {@code data} hash maps a work id to its serialized form.
 */
class ScheduledRetriever {

    /** Key prefix obtained from {@code RedisAdmin.namespace("work")}. */
    String namespace = Framework.getService(RedisAdmin.class)
                                .namespace("work");

    /**
     * Builds a Redis key as UTF-8 bytes.
     *
     * @param value the suffix appended to {@link #namespace}
     * @return the UTF-8 encoding of the namespace followed by {@code value}
     */
    byte[] keyBytes(String value) {
        // StandardCharsets.UTF_8 cannot fail, unlike a charset looked up by name
        return namespace.concat(value).getBytes(StandardCharsets.UTF_8);
    }

    /** @return the key of the set holding the ids of the works scheduled on the default queue */
    byte[] queueBytes() {
        return keyBytes("sched:default");
    }

    /** @return the key of the hash holding the serialized works */
    byte[] dataKey() {
        return keyBytes("data");
    }

    /**
     * Lists the works referenced by the scheduled set of the default queue.
     *
     * @return one entry per id found in the set; an entry is {@code null} when the data hash has no value for
     *         that id
     */
    List<Work> listScheduled() {
        RedisPoolDescriptor config = Framework.getService(RedisAdmin.class).getConfig();
        return config.newExecutor()
                     .execute(new RedisCallable<List<Work>>() {
                         @Override
                         public List<Work> call(Jedis jedis) {
                             Set<byte[]> keys = jedis.smembers(queueBytes());
                             List<Work> list = new ArrayList<>(keys.size());
                             for (byte[] workIdBytes : keys) {
                                 // get data
                                 byte[] workBytes = jedis.hget(dataKey(), workIdBytes);
                                 list.add(deserializeWork(workBytes));
                             }
                             return list;
                         }
                     });
    }

    /**
     * Deserializes a work from its Java-serialized form.
     * <p>
     * NOTE(review): native Java deserialization is only acceptable because the bytes come from the Redis
     * instance driven by this test - do not reuse on untrusted data.
     *
     * @param bytes the serialized work, may be {@code null}
     * @return the work, or {@code null} if {@code bytes} is {@code null}
     * @throws RuntimeException if the bytes cannot be read or the work class cannot be found
     */
    Work deserializeWork(byte[] bytes) {
        if (bytes == null) {
            return null;
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (Work) in.readObject();
        } catch (IOException | ClassNotFoundException cause) {
            throw new RuntimeException("Cannot deserialize work", cause);
        }
    }
}
}