/**
* Copyright (C) 2014 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.service;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.Is.is;
import static org.testng.AssertJUnit.assertFalse;
import static org.testng.AssertJUnit.assertTrue;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.testng.annotations.AfterClass;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.opengamma.util.test.TestGroup;
@Test(groups = TestGroup.UNIT)
@SuppressWarnings("unchecked")
public class ServiceContextAwareExecutorServiceTest {
private final ExecutorService _underlying = Executors.newSingleThreadExecutor();
private final ExecutorService _executor = new ServiceContextAwareExecutorService(_underlying);
@BeforeMethod
public void setUp() throws Exception {
ThreadLocalServiceContext.init(ServiceContext.of(Collections.<Class<?>, Object>emptyMap()));
}
@AfterMethod
public void tearDown() throws Exception {
ThreadLocalServiceContext.init(null);
}
@Test
public void submitCallable() throws ExecutionException, InterruptedException {
assertFalse(_underlying.submit(callable()).get());
assertTrue(_executor.submit(callable()).get());
}
@Test
public void submitRunnable() throws InterruptedException {
final ArrayBlockingQueue<Boolean> queue = new ArrayBlockingQueue<>(1);
Runnable r = new Runnable() {
@Override
public void run() {
boolean b = ThreadLocalServiceContext.getInstance() != null;
try {
queue.put(b);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
};
_underlying.submit(r);
assertFalse(queue.take());
_executor.submit(r);
assertTrue(queue.take());
_underlying.submit(r, null);
assertFalse(queue.take());
_executor.submit(r, null);
assertTrue(queue.take());
}
@Test
public void invokeAll() throws InterruptedException {
List<Callable<Boolean>> tasks2 = Lists.newArrayList(callable(), callable());
List<Future<Boolean>> futures2 = _executor.invokeAll(tasks2);
assertTrue(Iterables.all(futures2, predicate()));
List<Callable<Boolean>> tasks1 = Lists.newArrayList(callable(), callable());
List<Future<Boolean>> futures1 = _underlying.invokeAll(tasks1);
assertTrue(Iterables.all(futures1, Predicates.not(predicate())));
}
@Test
public void amendContext() throws InterruptedException, ExecutionException {
ThreadLocalServiceContext.init(ServiceContext.of(String.class, "StringService"));
ExecutorService executor = new ServiceContextAwareExecutorService(_underlying);
assertThat(executor.submit(stringCallable()).get(), is("WithString"));
// Now change to a different context, without StringService
ThreadLocalServiceContext.init(ServiceContext.of(Integer.class, 42));
assertThat(executor.submit(stringCallable()).get(), is("WithoutString"));
}
private Callable<String> stringCallable() {
return new Callable<String>() {
@Override
public String call() throws Exception {
if (ThreadLocalServiceContext.getInstance() == null) {
return "None";
}
try {
ThreadLocalServiceContext.getInstance().get(String.class);
return "WithString";
} catch (IllegalArgumentException e) {
return "WithoutString";
}
}
};
}
private Predicate<Future<Boolean>> predicate() {
return new Predicate<Future<Boolean>>() {
@Override
public boolean apply(Future<Boolean> future) {
try {
return future.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
};
}
private Callable<Boolean> callable() {
return new Callable<Boolean>() {
@Override
public Boolean call() throws Exception {
return ThreadLocalServiceContext.getInstance() != null;
}
};
}
}