package com.ewjordan.util.objectWrap;
import java.util.Random;
import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.MultivariateVectorialFunction;
import org.apache.commons.math.genetics.Chromosome;
import org.apache.commons.math.genetics.ChromosomePair;
import org.apache.commons.math.genetics.CrossoverPolicy;
import org.apache.commons.math.genetics.ElitisticListPopulation;
import org.apache.commons.math.genetics.GeneticAlgorithm;
import org.apache.commons.math.genetics.MutationPolicy;
import org.apache.commons.math.genetics.Population;
import org.apache.commons.math.genetics.SelectionPolicy;
import org.apache.commons.math.genetics.StoppingCondition;
import org.apache.commons.math.genetics.TournamentSelection;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.OptimizationException;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.apache.commons.math.optimization.direct.NelderMeadSimplex;
import org.apache.commons.math.optimization.direct.SimplexOptimizer;
import org.apache.commons.math.optimization.general.ConjugateGradientFormula;
import org.apache.commons.math.optimization.general.NonLinearConjugateGradientOptimizer;
/**
* This class optimizes OptimizableWrappedObjects against
* their fitness functions. It uses Apache Commons Math's nonlinear
* conjugate gradient and Nelder-Mead direct-search optimizers, or its
* genetic algorithm implementation.
*
* The routines here generate the necessary crossover, mutation,
* and calculus quantities (finite-difference derivatives) using default
* implementations. Be aware, however, that these defaults will not always
* be ideal; in some cases you may have faster or more appropriate ways
* to compute them.
*
* The default crossover policy is a single point crossover at a random
* place along the chromosome - for a wrapped object, this means that
* we simply switch over from reading one object's fields to the other's.
* This means that the order of field declaration in each object matters.
* For pure genetic-algorithm use, it is suggested that a decoding
* procedure be used to transform the chromosome (i.e. the object's fields)
* into the final structure that is then tested for fitness.
*
* @author Eric
*
*/
public class ObjectOptimizer {

    /** Step size h used by the central-difference derivative estimates. */
    private static final double DERIVATIVE_DELTA = 0.001;

    /**
     * Arity of the default tournament selection. Populations handed to the genetic
     * algorithm must contain at least this many chromosomes.
     */
    private static final int DEFAULT_TOURNAMENT_ARITY = 10;

    /** Source of randomness for choosing crossover points. */
    private static final Random rand = new Random();

    /**
     * Limit passed as the first argument of every Apache {@code optimize(...)} call made by
     * this class. NOTE(review): in this Commons Math API that argument appears to be the
     * maximum number of evaluations rather than iterations — confirm.
     */
    private static int maxIterations = 100;

    /**
     * Sets the limit passed to the Apache optimizers by the {@code optimize*} methods.
     *
     * @param val the new limit
     */
    public static void setMaxIterations(int val) {
        maxIterations = val;
    }

    /**
     * Returns the limit passed to the Apache optimizers by the {@code optimize*} methods.
     *
     * @return the current limit (default 100)
     */
    public static int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Default single-point crossover. It picks a random cut point in
     * {@code [0, min(first.size(), second.size()))} and builds two children with
     * {@code WrappedObjectChromosome.cross}, once in each parent order.
     * Throws IllegalArgumentException for any other chromosome type.
     */
    private static final CrossoverPolicy defaultWrappedObjectCrossoverPolicy = new CrossoverPolicy() {
        @Override
        public ChromosomePair crossover(Chromosome first, Chromosome second) {
            if (!(first instanceof WrappedObjectChromosome)
                    || !(second instanceof WrappedObjectChromosome)) {
                throw new IllegalArgumentException(
                        "WrappedObjectChromosomes are the only types accepted here");
            }
            WrappedObjectChromosome wrappedFirst = (WrappedObjectChromosome) first;
            WrappedObjectChromosome wrappedSecond = (WrappedObjectChromosome) second;
            // Only cut within the members that both parents have.
            int nMembers = Math.min(wrappedFirst.size(), wrappedSecond.size());
            int crossPoint = rand.nextInt(nMembers);
            WrappedObjectChromosome crossedFirst =
                    WrappedObjectChromosome.cross(wrappedFirst, wrappedSecond, crossPoint);
            WrappedObjectChromosome crossedSecond =
                    WrappedObjectChromosome.cross(wrappedSecond, wrappedFirst, crossPoint);
            return new ChromosomePair(crossedFirst, crossedSecond);
        }
    };

    /**
     * Default mutation policy. It delegates to {@code WrappedObjectChromosome.mutate}.
     * Throws IllegalArgumentException for any other chromosome type.
     */
    private static final MutationPolicy defaultWrappedObjectMutationPolicy = new MutationPolicy() {
        @Override
        public Chromosome mutate(Chromosome original) {
            if (!(original instanceof WrappedObjectChromosome)) {
                throw new IllegalArgumentException(
                        "WrappedObjectChromosomes are the only types accepted here");
            }
            return WrappedObjectChromosome.mutate((WrappedObjectChromosome) original);
        }
    };

    /** Default selection: tournament selection with {@link #DEFAULT_TOURNAMENT_ARITY} entrants. */
    private static final SelectionPolicy defaultWrappedObjectSelectionPolicy =
            new TournamentSelection(DEFAULT_TOURNAMENT_ARITY);

    /**
     * Evolves a population seeded from {@code obj} and returns the fittest chromosome
     * found in the final generation.
     *
     * The initial population contains the chromosome wrapping {@code obj}. The remaining
     * slots, up to {@code populationLimit}, are filled with mutants of it made by the
     * default mutation policy. A population of one cannot be evolved: the default
     * tournament selection requires at least as many chromosomes as its arity (10).
     *
     * @param obj the object whose wrapped fields form the seed chromosome
     * @param generations the number of generations to evolve
     * @param crossoverRate crossover rate passed to {@link GeneticAlgorithm}
     * @param mutationRate mutation rate passed to {@link GeneticAlgorithm}
     * @param populationLimit population size. Must be at least 10 for the default
     *        tournament selection.
     * @param elitismRate elitism rate passed to {@link ElitisticListPopulation}
     * @return the fittest chromosome of the final population
     */
    public static WrappedObjectChromosome geneticallyOptimize(OptimizableWrappedObject obj,
            final int generations, double crossoverRate,
            double mutationRate, int populationLimit, double elitismRate) {
        final GeneticAlgorithm alg = new GeneticAlgorithm(
                defaultWrappedObjectCrossoverPolicy, crossoverRate,
                defaultWrappedObjectMutationPolicy, mutationRate,
                defaultWrappedObjectSelectionPolicy);

        Population initial = new ElitisticListPopulation(populationLimit, elitismRate);
        WrappedObjectChromosome seed = new WrappedObjectChromosome(obj);
        initial.addChromosome(seed);
        // A single-chromosome population makes tournament selection throw. Fill the
        // remaining slots with mutants of the seed so the algorithm has something to select.
        while (initial.getPopulationSize() < populationLimit) {
            initial.addChromosome(defaultWrappedObjectMutationPolicy.mutate(seed));
        }

        Population finalPopulation = alg.evolve(initial, new StoppingCondition() {
            @Override
            public boolean isSatisfied(Population population) {
                // evolve() checks this before producing each generation and increments the
                // counter afterwards. So ">=" stops after exactly `generations` generations,
                // whereas ">" would run one extra.
                return alg.getGenerationsEvolved() >= generations;
            }
        });
        return (WrappedObjectChromosome) finalPopulation.getFittestChromosome();
    }

    /** Returns the default mutation policy used by {@link #geneticallyOptimize}. */
    static final MutationPolicy getWrappedObjectMutationPolicy() {
        return defaultWrappedObjectMutationPolicy;
    }

    /** Returns the default crossover policy used by {@link #geneticallyOptimize}. */
    static final CrossoverPolicy getWrappedObjectCrossoverPolicy() {
        return defaultWrappedObjectCrossoverPolicy;
    }

    /**
     * Wraps {@code opt} in a chromosome whose fitness is {@code opt.getValue()}.
     *
     * @param opt the object to wrap
     * @return a new chromosome reporting {@code opt.getValue()} as its fitness
     */
    static final WrappedObjectChromosome getChromosome(final OptimizableWrappedObject opt) {
        return new WrappedObjectChromosome(opt) {
            @Override
            public double fitness() {
                return opt.getValue();
            }
        };
    }

    /**
     * Optimizes {@code opt} in place using the nonlinear conjugate gradient method
     * (Fletcher-Reeves) with finite-difference gradients.
     *
     * The starting point is pulled from the wrapped object's current values. The optimum
     * found is written back with {@code setValues} and pushed into the object.
     *
     * @param opt the wrapped object to optimize. Its fields are updated on return.
     * @param goalType whether to minimize or maximize {@code opt.getValue()}
     * @return the number of evaluations reported by the optimizer
     * @throws OptimizationException if the Apache optimizer reports a failure
     * @throws IllegalArgumentException propagated from the optimizer or the objective
     */
    public static final int optimize(final OptimizableWrappedObject opt, GoalType goalType)
            throws OptimizationException, IllegalArgumentException {
        NonLinearConjugateGradientOptimizer optimizer =
                new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES);
        opt.pullValuesFromObject();
        RealPointValuePair pair = optimizer.optimize(maxIterations,
                getDifferentiableMultivariateRealFunction(opt), goalType, opt.getValues());
        opt.setValues(pair.getPoint());
        opt.pushValuesToObject();
        return optimizer.getEvaluations();
    }

    /**
     * Optimizes an arbitrary function using the nonlinear conjugate gradient method
     * (Fletcher-Reeves), starting from {@code startingPoint}.
     *
     * The function and its parameter array are held in a small {@link HasValue} adapter,
     * which is wrapped with {@link WrappedObject} to obtain an optimizable view.
     * NOTE(review): this presumes WrappedObject exposes the adapter's {@code double[] param}
     * field as the optimization variables — confirm. Do not rename or reorder the adapter's
     * fields without checking how WrappedObject reads them.
     *
     * If {@code func} throws IllegalArgumentException at some point, that point is
     * evaluated as 0.0 instead of aborting the optimization.
     *
     * @param func the function to optimize
     * @param goalType whether to minimize or maximize {@code func}
     * @param startingPoint the initial point. It is referenced by the adapter, not copied.
     * @return the optimal point reported by the optimizer
     */
    public static final double[] optimize(final MultivariateRealFunction func, GoalType goalType,
            final double[] startingPoint) throws IllegalArgumentException {
        Object funcPlusParam = new HasValue() {
            MultivariateRealFunction f = func;
            double[] param = startingPoint;

            @Override
            public double getValue() {
                try {
                    return f.value(param);
                } catch (IllegalArgumentException ignored) {
                    // Deliberately best-effort: points the function rejects evaluate to 0.0.
                }
                return 0.0;
            }
        };
        WrappedObject w = new WrappedObject(funcPlusParam);
        final OptimizableWrappedObject opt = w.optimizable();
        NonLinearConjugateGradientOptimizer optimizer =
                new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.FLETCHER_REEVES);
        opt.pullValuesFromObject();
        RealPointValuePair pair = optimizer.optimize(maxIterations,
                getDifferentiableMultivariateRealFunction(opt), goalType, opt.getValues());
        return pair.getPoint();
    }

    /**
     * Optimizes {@code opt} in place using the Nelder-Mead simplex direct-search method.
     *
     * The starting point is pulled from the wrapped object's current values. The optimum
     * found is written back with {@code setValues} and pushed into the object.
     *
     * @param opt the wrapped object to optimize. Its fields are updated on return.
     * @param goalType whether to minimize or maximize {@code opt.getValue()}
     * @return the number of evaluations reported by the optimizer
     * @throws OptimizationException if the Apache optimizer reports a failure
     * @throws IllegalArgumentException propagated from the optimizer or the objective
     */
    public static final int optimizeNelderMead(final OptimizableWrappedObject opt, GoalType goalType)
            throws OptimizationException, IllegalArgumentException {
        NelderMeadSimplex nm = new NelderMeadSimplex(opt.getNumberOfMembers());
        opt.pullValuesFromObject();
        SimplexOptimizer optimizer = new SimplexOptimizer();
        optimizer.setSimplex(nm);
        // Direct search only evaluates the function, so the derivative-free adapter suffices.
        RealPointValuePair pair = optimizer.optimize(maxIterations,
                getMultivariateRealFunction(opt), goalType, opt.getValues());
        opt.setValues(pair.getPoint());
        opt.pushValuesToObject();
        return optimizer.getEvaluations();
    }

    /**
     * Adapts {@code opt} to a {@link MultivariateRealFunction}.
     *
     * Evaluating at a point pushes that point into the wrapped object, reads
     * {@code opt.getValue()}, then pushes back the values the object held before the call.
     * The push/read/restore sequence is synchronized on {@code opt.getObject()}. The
     * returned function reuses one scratch buffer across calls. NOTE(review): sharing one
     * returned instance across threads has not been verified to be safe (see TODO).
     *
     * @param opt the wrapped object to evaluate
     * @return a function whose value at a point is {@code opt.getValue()} at that point
     */
    static final MultivariateRealFunction getMultivariateRealFunction(final OptimizableWrappedObject opt) {
        return new MultivariateRealFunction() {
            /** Scratch copy of the object's values, restored after each evaluation. */
            private double[] prevVals = null;

            @Override
            public double value(double[] point) throws IllegalArgumentException {
                double[] pvals = opt.getValues();
                if (prevVals == null) {
                    prevVals = new double[pvals.length];
                }
                double value;
                synchronized (opt.getObject()) { // TODO: test thread safety here...
                    System.arraycopy(pvals, 0, prevVals, 0, pvals.length);
                    opt.pushValuesToObject(point);
                    value = opt.getValue();
                    opt.pushValuesToObject(prevVals);
                }
                return value;
            }
        };
    }

    /**
     * Adapts {@code opt} to a {@link DifferentiableMultivariateRealFunction}.
     *
     * Values are computed exactly as by {@link #getMultivariateRealFunction}. Partial
     * derivatives and the gradient are estimated by central differences,
     * {@code (f(x + h*e_k) - f(x - h*e_k)) / (2h)} with {@code h = DERIVATIVE_DELTA}.
     * Derivative evaluations work on a private copy of the supplied point, so the
     * caller's array is never modified.
     *
     * @param opt the wrapped object to evaluate
     * @return a differentiable view of {@code opt.getValue()}
     */
    static final DifferentiableMultivariateRealFunction getDifferentiableMultivariateRealFunction(
            final OptimizableWrappedObject opt) {
        return new DifferentiableMultivariateRealFunction() {
            /** Underlying objective, shared with the derivative-free adapter. */
            private final MultivariateRealFunction objective = getMultivariateRealFunction(opt);

            /** Lazily created partial-derivative functions, one per member. */
            private final MultivariateRealFunction[] partials =
                    new MultivariateRealFunction[opt.getNumberOfMembers()];

            private final MultivariateVectorialFunction grad = new MultivariateVectorialFunction() {
                @Override
                public double[] value(double[] point) throws IllegalArgumentException {
                    // One private copy suffices: each difference restores its coordinate.
                    double[] probe = point.clone();
                    double[] gradient = new double[point.length];
                    for (int i = 0; i < point.length; ++i) {
                        gradient[i] = centralDifference(i, probe);
                    }
                    return gradient;
                }
            };

            @Override
            public double value(double[] point) throws IllegalArgumentException {
                return objective.value(point);
            }

            @Override
            public MultivariateRealFunction partialDerivative(final int k) {
                if (partials[k] == null) {
                    partials[k] = new MultivariateRealFunction() {
                        @Override
                        public double value(double[] point) throws IllegalArgumentException {
                            return centralDifference(k, point.clone());
                        }
                    };
                }
                return partials[k];
            }

            @Override
            public MultivariateVectorialFunction gradient() {
                return grad;
            }

            /**
             * Central-difference estimate of the k-th partial derivative at {@code probe}.
             * Temporarily perturbs {@code probe[k]} and restores it before returning.
             */
            private double centralDifference(int k, double[] probe) {
                double center = probe[k];
                probe[k] = center + DERIVATIVE_DELTA;
                double plus = value(probe);
                probe[k] = center - DERIVATIVE_DELTA;
                double minus = value(probe);
                probe[k] = center;
                return (plus - minus) / (2 * DERIVATIVE_DELTA);
            }
        };
    }
}