package net.sf.openrocket.optimization.general;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import net.sf.openrocket.util.BugException;
/**
* An implementation of a ParallelFunctionCache that evaluates function values
* in parallel and caches them. This allows pre-calculating possibly required
* function values beforehand. If values are not required after all, the
* computation can be aborted assuming the function evaluation supports it.
* <p>
* Note that while this class handles threads and abstracts background execution,
* the public methods themselves are NOT thread-safe and should be called from
* only one thread at a time.
*
* @author Sampo Niskanen <sampo.niskanen@iki.fi>
*/
public class ParallelExecutorCache implements ParallelFunctionCache {
private final Map<Point, Double> functionCache = new HashMap<Point, Double>();
private final Map<Point, Future<Double>> futureMap = new HashMap<Point, Future<Double>>();
private ExecutorService executor;
private Function function;
/**
* Construct a cache that uses the same number of computational threads as there are
* processors available.
*/
public ParallelExecutorCache() {
this(Runtime.getRuntime().availableProcessors());
}
/**
* Construct a cache that uses the specified number of computational threads for background
* computation. The threads that are created are marked as daemon threads.
*
* @param threadCount the number of threads to use in the executor.
*/
public ParallelExecutorCache(int threadCount) {
this(new ThreadPoolExecutor(threadCount, threadCount, 60, TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(),
new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setDaemon(true);
return t;
}
}));
}
/**
* Construct a cache that uses the specified ExecutorService for managing
* computational threads.
*
* @param executor the executor to use for function evaluations.
*/
public ParallelExecutorCache(ExecutorService executor) {
this.executor = executor;
}
@Override
public void compute(Collection<Point> points) {
for (Point p : points) {
compute(p);
}
}
@Override
public void compute(Point point) {
if (isOutsideRange(point)) {
// Point is outside of range
return;
}
if (functionCache.containsKey(point)) {
// Function has already been evaluated at the point
return;
}
if (futureMap.containsKey(point)) {
// Function is being evaluated at the point
return;
}
// Submit point for evaluation
FunctionCallable callable = new FunctionCallable(function, point);
Future<Double> future = executor.submit(callable);
futureMap.put(point, future);
}
@Override
public void waitFor(Collection<Point> points) throws InterruptedException, OptimizationException {
for (Point p : points) {
waitFor(p);
}
}
@Override
public void waitFor(Point point) throws InterruptedException, OptimizationException {
if (isOutsideRange(point)) {
return;
}
if (functionCache.containsKey(point)) {
return;
}
Future<Double> future = futureMap.get(point);
if (future == null) {
throw new IllegalStateException("waitFor called for " + point + " but it is not being computed");
}
try {
double value = future.get();
functionCache.put(point, value);
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof InterruptedException) {
throw (InterruptedException) cause;
}
if (cause instanceof OptimizationException) {
throw (OptimizationException) cause;
}
if (cause instanceof RuntimeException) {
throw (RuntimeException) cause;
}
throw new BugException("Function threw unknown exception while processing", e);
}
}
@Override
public List<Point> abort(Collection<Point> points) {
List<Point> computed = new ArrayList<Point>(Math.min(points.size(), 10));
for (Point p : points) {
if (abort(p)) {
computed.add(p);
}
}
return computed;
}
@Override
public boolean abort(Point point) {
if (isOutsideRange(point)) {
return false;
}
if (functionCache.containsKey(point)) {
return true;
}
Future<Double> future = futureMap.remove(point);
if (future == null) {
throw new IllegalStateException("abort called for " + point + " but it is not being computed");
}
if (future.isDone()) {
// Evaluation has been completed, store value in cache
try {
double value = future.get();
functionCache.put(point, value);
return true;
} catch (Exception e) {
return false;
}
} else {
// Cancel the evaluation
future.cancel(true);
return false;
}
}
@Override
public void abortAll() {
Iterator<Point> iterator = futureMap.keySet().iterator();
while (iterator.hasNext()) {
Point point = iterator.next();
Future<Double> future = futureMap.get(point);
iterator.remove();
if (future.isDone()) {
// Evaluation has been completed, store value in cache
try {
double value = future.get();
functionCache.put(point, value);
} catch (Exception e) {
// Ignore
}
} else {
// Cancel the evaluation
future.cancel(true);
}
}
}
@Override
public double getValue(Point point) {
if (isOutsideRange(point)) {
return Double.MAX_VALUE;
}
Double d = functionCache.get(point);
if (d == null) {
throw new IllegalStateException(point + " is not in function cache. " +
"functionCache=" + functionCache + " futureMap=" + futureMap);
}
return d;
}
@Override
public Function getFunction() {
return function;
}
@Override
public void setFunction(Function function) {
this.function = function;
clearCache();
}
@Override
public void clearCache() {
List<Point> list = new ArrayList<Point>(futureMap.keySet());
abort(list);
functionCache.clear();
}
public ExecutorService getExecutor() {
return executor;
}
/**
* Check whether a point is outside of the valid optimization range.
*/
private boolean isOutsideRange(Point p) {
int n = p.dim();
for (int i = 0; i < n; i++) {
double d = p.get(i);
// Include NaN in disallowed range
if (!(d >= 0.0 && d <= 1.0)) {
return true;
}
}
return false;
}
/**
* A Callable that evaluates a function at a specific point and returns the result.
*/
private class FunctionCallable implements Callable<Double> {
private final Function calledFunction;
private final Point point;
public FunctionCallable(Function function, Point point) {
this.calledFunction = function;
this.point = point;
}
@Override
public Double call() throws InterruptedException, OptimizationException {
return calledFunction.evaluate(point);
}
}
}