package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.InRange;

import org.junit.contrib.theories.DataPoint;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

import java.util.Random;

import static org.junit.Assert.*;

/**
 * Created on 8/26/15.
 * @author keenon
 * <p>
 * This does its best to Quickcheck our optimizers. The strategy here is to generate convex functions that are
 * solvable in closed form, and then test that each optimizer achieves a nearly optimal solution at convergence.
 */
@RunWith(Theories.class)
public class OptimizerTests {
  @DataPoint
  public static AbstractBatchOptimizer backtrackingAdaGrad = new BacktrackingAdaGradOptimizer();
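
  // JUnit's Theories runner invokes every @Theory below once for each assignment of matching @DataPoint values,
  // so any additional optimizer registered here as a @DataPoint is automatically exercised by all of these tests.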

  @Theory
  public void testOptimizeLogLikelihood(AbstractBatchOptimizer optimizer,
                                        @ForAll(sampleSize = 5) @From(LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
                                        @ForAll(sampleSize = 2) @From(LogLikelihoodFunctionTest.WeightsGenerator.class) ConcatVector initialWeights,
                                        @ForAll(sampleSize = 2) @InRange(minDouble = 0.0, maxDouble = 5.0) double l2regularization) throws Exception {
    AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
    ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, true);
    System.err.println("Finished optimizing");

    double logLikelihood = getValueSum(dataset, finalWeights, ll, l2regularization);

    // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
    // likelihood
    Random r = new Random(42);
    for (int i = 0; i < 1000; i++) {
      int size = finalWeights.getNumberOfComponents();
      ConcatVector randomDirection = new ConcatVector(size);
      for (int j = 0; j < size; j++) {
        // Size the perturbation to cover the sparse entry (sparse index + 1) or the whole dense component
        double[] dense = new double[finalWeights.isComponentSparse(j) ? finalWeights.getSparseIndex(j) + 1 : finalWeights.getDenseComponent(j).length];
        for (int k = 0; k < dense.length; k++) {
          dense[k] = (r.nextDouble() - 0.5) * 1.0e-3;
        }
        randomDirection.setDenseComponent(j, dense);
      }

      ConcatVector randomPerturbation = finalWeights.deepClone();
      randomPerturbation.addVectorInPlace(randomDirection, 1.0);

      double randomPerturbedLogLikelihood = getValueSum(dataset, randomPerturbation, ll, l2regularization);

      // Check that we're within a very small margin of error (about 1.0e-3, relative for large values) of the
      // randomly discovered value
      if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood)))) {
        System.err.println("Thought optimal point was: " + logLikelihood);
        System.err.println("Discovered better point: " + randomPerturbedLogLikelihood);
      }

      assertTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood))));
    }
  }
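
  // A minimal sketch of the closed-form strategy the class Javadoc describes, on a function far simpler than the
  // log-likelihood: f_c(w) = -(w - c)^2 for scalar targets c. This assumes getSummaryForInstance() is the only
  // abstract method on AbstractDifferentiableFunction, that it returns the function value while accumulating the
  // gradient into its third argument, and that optimize() maximizes the same penalized average that getValueSum()
  // below computes. Under those assumptions the unique maximum is w* = mean(c) / (1 + l2regularization), so the
  // optimizer's output can be checked against an exact answer. QuadraticFunction and the test beneath it are
  // hypothetical and exist only to illustrate the idea.
  static class QuadraticFunction extends AbstractDifferentiableFunction<Double> {
    @Override
    public double getSummaryForInstance(Double c, ConcatVector weights, ConcatVector gradient) {
      double w = weights.getValueAt(0, 0);
      // d/dw [-(w - c)^2] = -2 * (w - c); accumulate rather than overwrite, in case the gradient vector is shared
      ConcatVector g = new ConcatVector(1);
      g.setDenseComponent(0, new double[]{-2.0 * (w - c)});
      gradient.addVectorInPlace(g, 1.0);
      return -(w - c) * (w - c);
    }
  }

  @Theory
  public void testOptimizeSimpleQuadratic(AbstractBatchOptimizer optimizer) {
    double l2regularization = 0.5;
    Double[] targets = {1.0, 2.0, 3.0};
    ConcatVector start = new ConcatVector(1);
    start.setDenseComponent(0, new double[]{0.0});
    ConcatVector result = optimizer.optimize(targets, new QuadraticFunction(), start, l2regularization, 1.0e-9, true);
    // mean(targets) / (1 + l2regularization) = 2.0 / 1.5
    assertEquals(2.0 / 1.5, result.getValueAt(0, 0), 1.0e-3);
  }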

  /*
  @Theory
  public void testOptimizeLogLikelihoodWithConstraints(AbstractBatchOptimizer optimizer,
                                                       @ForAll(sampleSize = 5) @From(LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
                                                       @ForAll(sampleSize = 2) @From(LogLikelihoodFunctionTest.WeightsGenerator.class) ConcatVector initialWeights,
                                                       @ForAll(sampleSize = 2) @InRange(minDouble = 0.0, maxDouble = 5.0) double l2regularization) throws Exception {
    // Put in some constraints
    Random r = new Random(42);
    int constraintComponent = r.nextInt(initialWeights.getNumberOfComponents());
    double constraintValue = r.nextDouble();
    if (r.nextBoolean()) {
      optimizer.addSparseConstraint(constraintComponent, 0, constraintValue);
    } else {
      optimizer.addDenseConstraint(constraintComponent, new double[]{constraintValue});
    }

    AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
    ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, false);
    System.err.println("Finished optimizing");

    // The constrained component must come back pinned to exactly the constraint value
    assertEquals(constraintValue, finalWeights.getValueAt(constraintComponent, 0), 1.0e-9);

    double logLikelihood = getValueSum(dataset, finalWeights, ll, l2regularization);

    // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
    // likelihood
    for (int i = 0; i < 1000; i++) {
      int size = finalWeights.getNumberOfComponents();
      ConcatVector randomDirection = new ConcatVector(size);
      for (int j = 0; j < size; j++) {
        // Don't perturb the constrained component, since the result is only optimal subject to the constraint
        if (j == constraintComponent) continue;
        double[] dense = new double[finalWeights.isComponentSparse(j) ? finalWeights.getSparseIndex(j) + 1 : finalWeights.getDenseComponent(j).length];
        for (int k = 0; k < dense.length; k++) {
          dense[k] = (r.nextDouble() - 0.5) * 1.0e-3;
        }
        randomDirection.setDenseComponent(j, dense);
      }

      ConcatVector randomPerturbation = finalWeights.deepClone();
      randomPerturbation.addVectorInPlace(randomDirection, 1.0);

      double randomPerturbedLogLikelihood = getValueSum(dataset, randomPerturbation, ll, l2regularization);

      // Check that we're within a very small margin of error (about 1.0e-3, relative for large values) of the
      // randomly discovered value
      if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood)))) {
        System.err.println("Thought optimal point was: " + logLikelihood);
        System.err.println("Discovered better point: " + randomPerturbedLogLikelihood);
      }

      assertTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood))));
    }
  }
  */

  /**
   * Returns the average per-instance value of fn across the dataset, minus the L2 penalty
   * l2regularization * (weights dot weights). This is the quantity the theories above probe for local optimality.
   * For example, two instances scoring -1.0 and -3.0, with a squared weight norm of 4.0 at l2regularization = 0.5,
   * give (-1.0 + -3.0) / 2 - 4.0 * 0.5 = -4.0. The empty ConcatVector passed as the gradient accumulator is
   * discarded, since only the value is needed here.
   */
  private <T> double getValueSum(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction<T> fn, double l2regularization) {
    double value = 0.0;
    for (T t : dataset) {
      value += fn.getSummaryForInstance(t, weights, new ConcatVector(0));
    }
    return (value / dataset.length) - (weights.dotProduct(weights) * l2regularization);
  }
}