package net.sf.openrocket.optimization.general.multidim;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import net.sf.openrocket.optimization.general.FunctionCache;
import net.sf.openrocket.optimization.general.FunctionOptimizer;
import net.sf.openrocket.optimization.general.OptimizationController;
import net.sf.openrocket.optimization.general.OptimizationException;
import net.sf.openrocket.optimization.general.ParallelFunctionCache;
import net.sf.openrocket.optimization.general.Point;
import net.sf.openrocket.util.Statistics;
/**
* A customized implementation of the parallel multidirectional search algorithm by Dennis and Torczon.
* <p>
* This is a parallel pattern search optimization algorithm. The function evaluations are performed
* using an ExecutorService. By default a ThreadPoolExecutor is used that has as many threads defined
* as the system has processors.
* <p>
* The optimization can be aborted by interrupting the current thread.
*/
public class MultidirectionalSearchOptimizer implements FunctionOptimizer, Statistics {
    private static final Logger log = LoggerFactory.getLogger(MultidirectionalSearchOptimizer.class);

    /** The current simplex; index 0 is treated as the current best point throughout this class. */
    private final List<Point> simplex = new ArrayList<Point>();

    /** Cache used to queue, wait for, abort and look up function evaluations. */
    private ParallelFunctionCache functionExecutor;

    /** Whether expansion points are computed and tested; never enabled within this class. */
    private boolean useExpansion = false;

    /** Whether coordinate search points are computed and tested; never enabled within this class. */
    private boolean useCoordinateSearch = false;

    // Statistics counters, reported by getStatistics() and cleared by resetStatistics()
    private int stepCount = 0;
    private int reflectionAcceptance = 0;
    private int expansionAcceptance = 0;
    private int coordinateAcceptance = 0;
    private int reductionFallback = 0;

    /**
     * Create an optimizer without a function cache. A cache must be provided with
     * {@link #setFunctionCache(FunctionCache)} before calling {@link #optimize}.
     */
    public MultidirectionalSearchOptimizer() {
        // No-op
    }

    /**
     * Create an optimizer that uses the provided function cache.
     *
     * @param functionCache the cache used to evaluate the function.
     */
    public MultidirectionalSearchOptimizer(ParallelFunctionCache functionCache) {
        this.functionExecutor = functionCache;
    }

    /**
     * Run the optimization starting from the given point.
     * <p>
     * Each step queues the reflection of the simplex (and, if enabled, the coordinate search and
     * expansion points), waits for the reflection, and then either accepts a new simplex or halves
     * the step. The loop continues until the controller's <code>stepTaken</code> returns
     * <code>false</code> or the thread is interrupted.
     * <p>
     * If the thread is interrupted the optimization stops, the interruption is logged, and the
     * thread's interrupt status is restored before this method returns.
     *
     * @param initial the initial point of the optimization.
     * @param control the controller that is notified after every step and decides whether to continue.
     * @throws OptimizationException declared by the implemented interface method.
     * @throws IllegalStateException if no function cache has been set.
     */
    @Override
    public void optimize(Point initial, OptimizationController control) throws OptimizationException {
        if (functionExecutor == null) {
            throw new IllegalStateException("Function cache has not been set");
        }

        FunctionCacheComparator comparator = new FunctionCacheComparator(functionExecutor);

        final List<Point> pattern = SearchPattern.square(initial.dim());
        log.info("Starting optimization at {} with pattern {}", initial, pattern);

        boolean interrupted = false;
        try {

            boolean simplexComputed = false;
            double step = 0.5;

            // Set up the current simplex
            resetSimplex(initial, pattern, step);

            // Normal iterations
            List<Point> reflection = new ArrayList<Point>(simplex.size());
            List<Point> expansion = new ArrayList<Point>(simplex.size());
            List<Point> coordinateSearch = new ArrayList<Point>(simplex.size());
            boolean continueOptimization = true;

            while (continueOptimization) {

                log.debug("Starting optimization step with simplex {}{}", simplex,
                        simplexComputed ? "" : " (not computed)");
                stepCount++;

                if (!simplexComputed) {
                    // TODO: Could something be computed in parallel?
                    functionExecutor.compute(simplex);
                    functionExecutor.waitFor(simplex);
                    Collections.sort(simplex, comparator);
                    simplexComputed = true;
                }

                Point current = simplex.get(0);
                double currentValue = functionExecutor.getValue(current);

                /*
                 * Compute and queue the next points in likely order of usefulness.
                 * Expansion is unlikely as we're mainly dealing with bounded optimization.
                 */
                createReflection(simplex, reflection);
                if (useCoordinateSearch) {
                    createCoordinateSearch(current, step, coordinateSearch);
                }
                if (useExpansion) {
                    createExpansion(simplex, expansion);
                }

                functionExecutor.compute(reflection);
                if (useCoordinateSearch) {
                    functionExecutor.compute(coordinateSearch);
                }
                if (useExpansion) {
                    functionExecutor.compute(expansion);
                }

                // Check reflection acceptance
                log.debug("Computing reflection");
                functionExecutor.waitFor(reflection);

                if (accept(reflection, currentValue)) {

                    log.debug("Reflection was successful, aborting coordinate search, {} expansion",
                            useExpansion ? "computing" : "skipping");

                    if (useCoordinateSearch) {
                        functionExecutor.abort(coordinateSearch);
                    }

                    simplex.clear();
                    simplex.add(current);
                    simplex.addAll(reflection);
                    Collections.sort(simplex, comparator);

                    if (useExpansion) {

                        /*
                         * Assume expansion to be unsuccessful, queue next reflection while computing expansion.
                         */
                        createReflection(simplex, reflection);

                        functionExecutor.compute(reflection);
                        functionExecutor.waitFor(expansion);

                        if (accept(expansion, currentValue)) {
                            log.debug("Expansion was successful, aborting reflection");
                            functionExecutor.abort(reflection);

                            simplex.clear();
                            simplex.add(current);
                            simplex.addAll(expansion);
                            step *= 2;
                            Collections.sort(simplex, comparator);
                            expansionAcceptance++;
                        } else {
                            log.debug("Expansion failed");
                            reflectionAcceptance++;
                        }

                    } else {
                        reflectionAcceptance++;
                    }

                } else {

                    log.debug("Reflection was unsuccessful, aborting expansion, computing coordinate search");

                    // The expansion list is empty when expansion is not in use
                    functionExecutor.abort(expansion);

                    /*
                     * Assume coordinate search to be unsuccessful, queue contraction step while computing.
                     */
                    halveStep(simplex);
                    functionExecutor.compute(simplex);

                    if (useCoordinateSearch) {
                        functionExecutor.waitFor(coordinateSearch);

                        if (accept(coordinateSearch, currentValue)) {

                            log.debug("Coordinate search successful, reseting simplex");
                            // Remember the queued contraction points so that the ones
                            // not part of the new simplex can be aborted
                            List<Point> toAbort = new LinkedList<Point>(simplex);
                            resetSimplex(current, pattern, step);
                            toAbort.removeAll(simplex);
                            functionExecutor.abort(toAbort);
                            simplexComputed = false;
                            coordinateAcceptance++;

                        } else {
                            log.debug("Coordinate search unsuccessful, halving step.");
                            step /= 2;
                            simplexComputed = false;
                            reductionFallback++;
                        }

                    } else {
                        log.debug("Coordinate search not used, halving step.");
                        step /= 2;
                        simplexComputed = false;
                        reductionFallback++;
                    }

                }

                log.debug("Ending optimization step with simplex {}", simplex);

                continueOptimization = control.stepTaken(current, currentValue, simplex.get(0),
                        functionExecutor.getValue(simplex.get(0)), step);

                // Thread.interrupted() clears the flag; it is restored at the end of this method
                if (Thread.interrupted()) {
                    throw new InterruptedException();
                }

            }

        } catch (InterruptedException e) {
            log.info("Optimization was interrupted with InterruptedException");
            interrupted = true;
        }

        log.info("Finishing optimization at point {} value = {}", simplex.get(0),
                functionExecutor.getValue(simplex.get(0)));
        log.info("Optimization statistics: {}", getStatistics());

        if (interrupted) {
            // Do not swallow the interruption: restore the status so that callers can observe it.
            // Done last so that the final logging above runs with the flag still cleared.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Replace the contents of the simplex with <code>center</code> followed by
     * <code>center + p * step</code> for every point <code>p</code> of the pattern.
     *
     * @param center  the point placed at index 0 of the simplex.
     * @param pattern the pattern offsets.
     * @param step    the multiplier applied to each pattern offset.
     */
    private void resetSimplex(Point center, List<Point> pattern, double step) {
        simplex.clear();
        simplex.add(center);
        for (Point p : pattern) {
            simplex.add(center.add(p.mul(step)));
        }
    }

    /**
     * Fill <code>reflection</code> with the reflection of every point of <code>base</code>, except
     * the first one, through the first point.
     *
     * @param base       the simplex; its first point is the reflection center.
     * @param reflection the list that is cleared and filled with the reflected points.
     */
    private void createReflection(List<Point> base, List<Point> reflection) {
        Point current = base.get(0);
        reflection.clear();

        /*  new = - (old - current) + current = 2*current - old  */
        for (int i = 1; i < base.size(); i++) {
            Point p = base.get(i);
            p = current.mul(2).sub(p);
            reflection.add(p);
        }
    }

    /**
     * Fill <code>expansion</code> with the points <code>3*current - 2*old</code> for every point
     * <code>old</code> of <code>base</code> except the first one, where <code>current</code> is the
     * first point.
     *
     * @param base      the simplex; its first point is the expansion center.
     * @param expansion the list that is cleared and filled with the expanded points.
     */
    private void createExpansion(List<Point> base, List<Point> expansion) {
        Point current = base.get(0);
        expansion.clear();
        for (int i = 1; i < base.size(); i++) {
            Point p = current.mul(3).sub(base.get(i).mul(2));
            expansion.add(p);
        }
    }

    /**
     * Move every point of <code>base</code>, except the first one, halfway towards the first point.
     * The list is modified in place.
     *
     * @param base the simplex to contract.
     */
    private void halveStep(List<Point> base) {
        Point current = base.get(0);
        for (int i = 1; i < base.size(); i++) {
            Point p = base.get(i);

            /* new = (old - current)*0.5 + current = old*0.5 + current*0.5 = (old + current)*0.5 */
            p = p.add(current).mul(0.5);
            base.set(i, p);
        }
    }

    /**
     * Fill <code>coordinateDirections</code> with two points per dimension: <code>current</code>
     * plus and minus an offset whose i'th coordinate is set to <code>step</code>.
     * <p>
     * NOTE(review): assumes that a newly constructed <code>Point</code> has all coordinates zero,
     * so that the offset is along a single axis; confirm against <code>Point</code>.
     *
     * @param current              the center point of the search.
     * @param step                 the offset length.
     * @param coordinateDirections the list that is cleared and filled with the search points.
     */
    private void createCoordinateSearch(Point current, double step, List<Point> coordinateDirections) {
        coordinateDirections.clear();
        for (int i = 0; i < current.dim(); i++) {
            Point p = new Point(current.dim());
            p = p.set(i, step);
            coordinateDirections.add(current.add(p));
            coordinateDirections.add(current.sub(p));
        }
    }

    /**
     * Check whether any of the points has a function value strictly smaller than the current value.
     *
     * @param points       the candidate points; their values are looked up from the function cache.
     * @param currentValue the value to compare against.
     * @return <code>true</code> if at least one point has a value smaller than <code>currentValue</code>.
     */
    private boolean accept(List<Point> points, double currentValue) {
        for (Point p : points) {
            if (functionExecutor.getValue(p) < currentValue) {
                return true;
            }
        }
        return false;
    }

    /**
     * Return the first point of the current simplex.
     *
     * @throws IllegalStateException if the simplex is empty.
     */
    @Override
    public Point getOptimumPoint() {
        if (simplex.isEmpty()) {
            throw new IllegalStateException("Optimization has not been called, simplex is empty");
        }
        return simplex.get(0);
    }

    /**
     * Return the function cache's value for the point returned by {@link #getOptimumPoint()}.
     *
     * @throws IllegalStateException if the simplex is empty.
     */
    @Override
    public double getOptimumValue() {
        return functionExecutor.getValue(getOptimumPoint());
    }

    @Override
    public FunctionCache getFunctionCache() {
        return functionExecutor;
    }

    /**
     * Set the function cache to use.
     *
     * @param functionCache the cache, which must be a {@link ParallelFunctionCache}.
     * @throws IllegalArgumentException if the cache is not a <code>ParallelFunctionCache</code>.
     */
    @Override
    public void setFunctionCache(FunctionCache functionCache) {
        if (!(functionCache instanceof ParallelFunctionCache)) {
            throw new IllegalArgumentException("Function cache needs to be a ParallelFunctionCache: " + functionCache);
        }
        this.functionExecutor = (ParallelFunctionCache) functionCache;
    }

    /**
     * Return the step and acceptance counters as a single bracketed string.
     */
    @Override
    public String getStatistics() {
        return "MultidirectionalSearchOptimizer[stepCount=" + stepCount +
                ", reflectionAcceptance=" + reflectionAcceptance +
                ", expansionAcceptance=" + expansionAcceptance +
                ", coordinateAcceptance=" + coordinateAcceptance +
                ", reductionFallback=" + reductionFallback + "]";
    }

    /**
     * Reset all step and acceptance counters to zero.
     */
    @Override
    public void resetStatistics() {
        stepCount = 0;
        reflectionAcceptance = 0;
        expansionAcceptance = 0;
        coordinateAcceptance = 0;
        reductionFallback = 0;
    }

}