package edu.stanford.nlp.optimization;

import java.util.Random;

import junit.framework.TestCase;

import edu.stanford.nlp.math.ArrayMath;

/**
 * This class both tests a particular DiffFunction and provides a basis
 * for testing whether any DiffFunction's derivative is correct.
 *
 * @author Galen Andrew
 * @author Christopher Manning
 */
public class DiffFunctionTest extends TestCase {

  // private static final double EPS = 1e-6;

  private static final Random r = new Random();

  /**
   * Estimates the gradient of f at x along the given indices, using the
   * symmetric (central) finite-difference formula
   * (f(x + eps) - f(x - eps)) / (2 * eps).
   */
  private static double[] estimateGradient(Function f, double[] x, int[] testIndices, double eps) {
    double[] lowAnswer = new double[testIndices.length];
    double[] answer = new double[testIndices.length];
    for (int i = 0; i < testIndices.length; i++) {
      double orig = x[testIndices[i]];
      x[testIndices[i]] -= eps;
      lowAnswer[i] = f.valueAt(x);
      x[testIndices[i]] = orig + eps;
      answer[i] = f.valueAt(x);
      x[testIndices[i]] = orig; // restore the original value
      // System.err.println("new x is " + x[testIndices[i]]);
      answer[i] = (answer[i] - lowAnswer[i]) / (2.0 * eps);
      // System.err.print(" " + answer[i]);
    }
    // System.err.println("Gradient estimate is: " + Arrays.toString(answer));
    return answer;
  }

  /** Checks the gradient at a random point, for eps from 1e-2 down to 1e-6. */
  public static void gradientCheck(DiffFunction f) {
    for (int deg = -2; deg > -7; deg--) {
      double eps = Math.pow(10, deg);
      System.err.println("testing for eps " + eps);
      gradientCheck(f, eps);
    }
  }

  /** Checks the gradient at a point drawn uniformly from [-0.5, 0.5)^n. */
  public static void gradientCheck(DiffFunction f, double eps) {
    double[] x = new double[f.domainDimension()];
    for (int i = 0; i < x.length; i++) {
      x[i] = Math.random() - 0.5; // 0.03; (i - 0.5) * 4;
    }
    gradientCheck(f, x, eps);
  }

  public static void gradientCheck(DiffFunction f, double[] x, double eps) {
    // just check a few dimensions: the first, the last, and up to 8 random
    // interior ones
    int numDim = Math.min(10, x.length);
    int[] ind = new int[numDim];
    if (numDim == x.length) {
      for (int i = 0; i < ind.length; i++) {
        ind[i] = i;
      }
    } else {
      ind[0] = 0;
      ind[1] = x.length - 1;
      for (int i = 2; i < ind.length; i++) {
        ind[i] = r.nextInt(x.length - 2) + 1;
        // ind[i] = i;
      }
    }
    gradientCheck(f, x, ind, eps);
  }

  /**
   * Compares the analytic gradient from derivativeAt against the central
   * finite-difference estimate on the chosen indices, under the 1-, 2-, and
   * infinity-norms and the Pearson correlation.
   */
  public static void gradientCheck(DiffFunction f, double[] x, int[] ind, double eps) {
    double[] testGrad = estimateGradient(f, x, ind, eps);
    double[] fullGrad = f.derivativeAt(x);
    double[] fGrad = new double[ind.length];
    for (int i = 0; i < ind.length; i++) {
      fGrad[i] = fullGrad[ind[i]];
    }
    double[] diff = ArrayMath.pairwiseSubtract(testGrad, fGrad);
    System.err.println("1-norm: " + ArrayMath.norm_1(diff));
    assertEquals(0.0, ArrayMath.norm_1(diff), 2 * eps);
    System.err.println("2-norm: " + ArrayMath.norm(diff));
    assertEquals(0.0, ArrayMath.norm(diff), 2 * eps);
    System.err.println("inf-norm: " + ArrayMath.norm_inf(diff));
    assertEquals(0.0, ArrayMath.norm_inf(diff), 2 * eps);
    System.err.println("pearson: " + ArrayMath.pearsonCorrelation(testGrad, fGrad));
    // This could throw an exception if all the values were identical, since
    // the standard deviation would then be zero.
    assertEquals(1.0, ArrayMath.pearsonCorrelation(testGrad, fGrad), 2 * eps);
    // ArrayMath.standardize(fGrad);
    // ArrayMath.standardize(testGrad);
    // System.err.printf("test: %s%n", Arrays.toString(testGrad));
    // System.err.printf("full: %s%n", Arrays.toString(fGrad));
  }
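  /**
   * A minimal additional sketch of how gradientCheck applies to any
   * DiffFunction: a quadratic bowl f(x) = 0.5 * x . x, whose exact gradient
   * is x itself. The test name, function, and dimension here are illustrative
   * choices, not fixtures from any particular optimizer.
   */
  public void testQuadraticBowl() {
    gradientCheck(new DiffFunction() {
      @Override
      public double[] derivativeAt(double[] x) {
        // gradient of 0.5 * sum_i x_i^2 is x; clone so the caller cannot
        // mutate our input
        return x.clone();
      }

      @Override
      public double valueAt(double[] x) {
        return 0.5 * ArrayMath.innerProduct(x, x);
      }

      @Override
      public int domainDimension() {
        return 100;
      }
    });
  }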
  public void testXSquaredPlusOne() {
    gradientCheck(new DiffFunction() {
      // On a large vector, this computes f(x) = x . (x + 1) = sum_i (x_i^2 + x_i),
      // whose gradient is 2x + 1.
      @Override
      public double[] derivativeAt(double[] x) {
        return ArrayMath.add(ArrayMath.multiply(x, 2), 1);
      }

      @Override
      public double valueAt(double[] x) {
        return ArrayMath.innerProduct(x, ArrayMath.add(x, 1));
      }

      @Override
      public int domainDimension() {
        return 10000;
      }
    });
  }

}