package edu.stanford.nlp.optimization;

import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;

import java.util.Random;
import java.util.ArrayList;
import java.util.List;
import java.text.NumberFormat;
import java.text.DecimalFormat;
import java.io.PrintWriter;
import java.io.FileOutputStream;
import java.io.IOException;

/**
 * Tests a stochastic objective function by checking that batch values and
 * gradients sum to the full value and gradient, and that Hessian-vector
 * products agree with finite-difference approximations.
 *
 * @author Alex Kleeman
 */
public class StochasticDiffFunctionTester {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(StochasticDiffFunctionTester.class);

  static private double EPS = 1e-8;
  static private boolean quiet = false;

  protected int testBatchSize;
  protected int numBatches;
  protected AbstractStochasticCachingDiffFunction thisFunc;

  double[] approxGrad, fullGrad, diff, Hv, HvFD, v, curGrad, gradFD;
  double diffNorm, diffValue, fullValue, approxValue, diffGrad, maxGradDiff = 0.0, maxHvDiff = 0.0;
  Random generator;

  private static NumberFormat nf = new DecimalFormat("00.0");

  public StochasticDiffFunctionTester(Function function) {
    // Make sure the function is stochastic and has derivatives
    if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
      log.info("Attempt to test non-stochastic function using StochasticDiffFunctionTester");
      throw new UnsupportedOperationException();
    }

    thisFunc = (AbstractStochasticCachingDiffFunction) function;
    generator = new Random(System.currentTimeMillis()); // used to generate random test vectors

    // Look for a good batch size to test with by factoring the data size
    testBatchSize = (int) getTestBatchSize(thisFunc.dataDimension());

    // Again, make sure that our calculated batch size is actually valid
    if (testBatchSize < 0 || testBatchSize > thisFunc.dataDimension()
        || (thisFunc.dataDimension() % testBatchSize != 0)) {
      log.info("Invalid testBatchSize found, testing aborted. Data size: "
          + thisFunc.dataDimension() + " batchSize: " + testBatchSize);
      System.exit(1);
    }

    numBatches = thisFunc.dataDimension() / testBatchSize;

    sayln("StochasticDiffFunctionTester created with:");
    sayln("  data dimension = " + thisFunc.dataDimension());
    sayln("  batch size = " + testBatchSize);
    sayln("  number of batches = " + numBatches);
  }

  private void sayln(String s) {
    if (!quiet) {
      log.info(s);
    }
  }

  // Decompose an integer into prime factors.
  // Code was originally from http://www.idinews.com/sourcecode/IntegerFunction.html
  //
  // Upon return, result[0] contains the number of factors (0 if N is 0), and
  // result[1] . . . result[result[0]] contain the factors in ascending order.
  private static long[] primeFactors(long N) {
    long[] fctr = new long[64];   // Result array
    long n = Math.abs(N);         // Guard against negative input
    short fctrIndex = 0;
    if (n > 0) {                  // Guard against zero
      // First handle the special cases 2 and 3
      while (n % 2 == 0) { fctr[++fctrIndex] = 2; n /= 2; }
      while (n % 3 == 0) { fctr[++fctrIndex] = 3; n /= 3; }
      // Then try every 6k-1 and 6k+1 until the divisor exceeds the square root
      // of the current quotient. NOTE: Some trial divisors will be
      // non-primes, e.g. 25, 35, 49, 55. They have no effect, however,
      // since their prime factors will already have been tried.
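      // Worked example (illustrative): for N = 360 = 2^3 * 3^2 * 5, the two
      // loops above strip out {2, 2, 2, 3, 3}, leaving n = 5. The loop below
      // tries no divisor up to sqrt(5), so 5 is stored as the final factor and
      // the returned array begins [6, 2, 2, 2, 3, 3, 5] (count, then factors).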
      for (int k = 5; k * k <= n; k += 6) {
        for (int dvsr = k; dvsr <= k + 2; dvsr += 2) {
          while (n % dvsr == 0) { fctr[++fctrIndex] = dvsr; n /= dvsr; }
        }
      }
      if (n > 1) fctr[++fctrIndex] = n;   // Store the final factor, if any
    }
    fctr[0] = fctrIndex;                  // Store the number of factors
    return fctr;
  }

  /**
   * getTestBatchSize - Takes the size of the data as input and returns a large proper
   * factor of it (the size divided by its largest prime factor). This guarantees that
   * when testing the function we have equally sized batches, while keeping the number
   * of evaluations needed to test the function small.
   *
   * @param size The size of the current data set
   * @return The data size divided by its largest prime factor
   */
  private static long getTestBatchSize(long size) {
    long testBatchSize = 1;
    long[] factors = primeFactors(size);
    long factorCount = factors[0];

    // Calculate the batch size from the factors. primeFactors returns a factor
    // count of 1 for a prime input, whose only equal-sized batches have size 1.
    if (factorCount <= 1) {
      log.info("Attempt to test function on data of prime dimension. This would involve a batchSize of 1 and may take a very long time.");
      System.exit(1);
    } else if (factorCount == 2) {
      testBatchSize = factors[1];
    } else {
      // Multiply all factors but the largest, e.g. for size = 360 = 2*2*2*3*3*5
      // this yields 2*2*2*3*3 = 72.
      for (int f = 1; f < factorCount; f++) {
        testBatchSize *= factors[f];
      }
    }
    return testBatchSize;
  }

  /**
   * Tests that the sum of the stochastically calculated batch gradients (and values)
   * equals the full gradient (and value). This requires using ordered sampling, so if
   * the objective function itself randomizes its inputs this test will likely fail.
   *
   * @param x the point at which to evaluate the function
   * @param functionTolerance the tolerance to place on the infinity norm of the gradient and value
   * @return true if both the gradient and the value checks pass
   */
  public boolean testSumOfBatches(double[] x, double functionTolerance) {
    boolean ret = true;  // flipped to false if any check below fails
    log.info("Making sure that the sum of stochastic gradients equals the full gradient");

    AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = thisFunc.sampleMethod;
    StochasticCalculateMethods tmpMethod = thisFunc.method;

    // Make sure that our function is using ordered sampling. Otherwise we have no guarantees.
    thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;

    if (thisFunc.method == StochasticCalculateMethods.NoneSpecified) {
      log.info("No calculate method has been specified");
    }

    approxValue = 0;
    approxGrad = new double[x.length];
    curGrad = new double[x.length];
    fullGrad = new double[x.length];
    v = new double[x.length];  // direction argument required by the stochastic calls below
    double percent = 0.0;

    // Run through all the batches, summing the calculations to compare against the full gradient
    for (int i = 0; i < numBatches; i++) {
      percent = 100 * ((double) i) / numBatches;
      // update the value
      approxValue += thisFunc.valueAt(x, v, testBatchSize);
      // update the gradient
      thisFunc.returnPreviousValues = true;
      System.arraycopy(thisFunc.derivativeAt(x, v, testBatchSize), 0, curGrad, 0, curGrad.length);
      // update the running sum of batch gradients
      approxGrad = ArrayMath.pairwiseAdd(approxGrad, curGrad);
      double norm = ArrayMath.norm(approxGrad);
      System.err.printf("%5.1f percent complete %6.2f \n", percent, norm);
    }

    // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //  Get the full gradient and value; these should equal the approximations
    // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
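    // Because sampling was Ordered, the batches above partition the data exactly
    // once, so in exact arithmetic  value = sum_b value_b  and
    // grad = sum_b grad_b; the checks below allow floating-point slack of
    // functionTolerance in the infinity norm.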
log.info("About to calculate the full derivative and value"); System.arraycopy(thisFunc.derivativeAt(x),0,fullGrad,0,fullGrad.length); thisFunc.returnPreviousValues = true; fullValue = thisFunc.valueAt(x); diff = new double[x.length]; if( (ArrayMath.norm_inf(diff = ArrayMath.pairwiseSubtract(fullGrad,approxGrad))) < functionTolerance){ sayln(""); sayln("Success: sum of batch gradients equals full gradient"); ret = true; }else{ diffNorm = ArrayMath.norm(diff); sayln(""); sayln("Failure: sum of batch gradients minus full gradient has norm " + diffNorm); ret = false; } if(Math.abs(approxValue - fullValue) < functionTolerance){ sayln(""); sayln("Success: sum of batch values equals full value"); ret = true; }else{ sayln(""); sayln("Failure: sum of batch values minus full value has norm " + Math.abs(approxValue - fullValue)); ret = false; } thisFunc.sampleMethod = tmpSampleMethod; thisFunc.method = tmpMethod; return ret; } /** * * This function tests to make sure that the sum of the stochastic calculated gradients is equal to the * full gradient. This requires using ordered sampling, so if the ObjectiveFunction itself randomizes * the inputs this function will likely fail. * * * @param x is the point to evaluate the function at * @param functionTolerance is the tolerance to place on the infinity norm of the gradient and value * @return boolean indicating success or failure. */ public boolean testDerivatives(double[] x, double functionTolerance){ boolean ret = false; boolean compareHess = true; log.info("Making sure that the stochastic derivatives are ok."); AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = thisFunc.sampleMethod; StochasticCalculateMethods tmpMethod = thisFunc.method; //Make sure that our function is using ordered sampling. Otherwise we have no gaurentees. thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered; if(thisFunc.method==StochasticCalculateMethods.NoneSpecified){ log.info("No calculate method has been specified"); } else if( !thisFunc.method.calculatesHessianVectorProduct() ){ compareHess = false; } approxValue = 0; approxGrad = new double[x.length]; curGrad = new double[x.length]; Hv = new double[x.length]; double percent = 0.0; //This loop runs through all the batches and sums of the calculations to compare against the full gradient for (int i = 0; i < numBatches ; i ++){ percent = 100*((double) i)/(numBatches); //Can't figure out how to get a carriage return??? 
      // update the "hopefully" correct Hessian-vector product
      thisFunc.method = tmpMethod;
      System.arraycopy(thisFunc.HdotVAt(x, v, testBatchSize), 0, Hv, 0, Hv.length);

      // Now get the Hessian-vector product through finite differences
      thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;
      System.arraycopy(thisFunc.derivativeAt(x, v, testBatchSize), 0, gradFD, 0, gradFD.length);
      thisFunc.recalculatePrevBatch = true;
      System.arraycopy(thisFunc.HdotVAt(x, v, gradFD, testBatchSize), 0, HvFD, 0, HvFD.length);

      // Compare the difference
      double DiffHv = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(Hv, HvFD));
      // Keep track of the biggest H.v error
      if (DiffHv > maxHvDiff) { maxHvDiff = DiffHv; }
    }

    if (maxHvDiff < functionTolerance) {
      sayln("");
      sayln("Success: Hessian approximations lined up");
      ret = true;
    } else {
      sayln("");
      sayln("Failure: Hessian approximation at some point was off by " + maxHvDiff);
      ret = false;
    }

    thisFunc.sampleMethod = tmpSampleMethod;
    thisFunc.method = tmpMethod;
    return ret;
  }

  /*
     This function is used to get a lower bound on the condition number. As it stands
     this is pretty straightforward: a random point (x) and vector (v) are generated,
     and the Rayleigh quotient ( v.H(x).v / v.v ) is taken, which provides both a
     lower bound on the largest eigenvalue and an upper bound on the smallest
     eigenvalue. This can then be used to come up with a lower bound on the condition
     number of the Hessian.
  */
  public double testConditionNumber(int samples) {
    double maxSeen = 0.0;
    double minSeen = Double.POSITIVE_INFINITY;  // start high so the first sample sets the minimum
    double[] thisV = new double[thisFunc.domainDimension()];
    double[] thisX = new double[thisV.length];
    gradFD = new double[thisV.length];
    HvFD = new double[thisV.length];
    double thisVHV;

    boolean isNeg = false;
    boolean isPos = false;
    boolean isSemi = false;

    thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;

    for (int j = 0; j < samples; j++) {
      for (int i = 0; i < thisV.length; i++) {
        thisV[i] = generator.nextDouble();
      }
      for (int i = 0; i < thisX.length; i++) {
        thisX[i] = generator.nextDouble();
      }
      log.info("Evaluating Hessian Product");
      System.arraycopy(thisFunc.derivativeAt(thisX, thisV, testBatchSize), 0, gradFD, 0, gradFD.length);
      thisFunc.recalculatePrevBatch = true;
      System.arraycopy(thisFunc.HdotVAt(thisX, thisV, gradFD, testBatchSize), 0, HvFD, 0, HvFD.length);

      thisVHV = ArrayMath.innerProduct(thisV, HvFD);
      if (Math.abs(thisVHV) > maxSeen) {
        maxSeen = Math.abs(thisVHV);
      }
      if (Math.abs(thisVHV) < minSeen) {
        minSeen = Math.abs(thisVHV);
      }
      if (thisVHV < 0) {
        isNeg = true;
      }
      if (thisVHV > 0) {
        isPos = true;
      }
      if (thisVHV == 0) {
        isSemi = true;
      }
      log.info("It: " + j + " C: " + maxSeen / minSeen + " N: " + isNeg + " P: " + isPos + " S: " + isSemi);
    }

    System.out.println("Condition Number of: " + maxSeen / minSeen);
    System.out.println("Is negative: " + isNeg);
    System.out.println("Is positive: " + isPos);
    System.out.println("Is semi: " + isSemi);
    return maxSeen / minSeen;
  }

  public double[] getVariance(double[] x) {
    return getVariance(x, testBatchSize);
  }

  public double[] getVariance(double[] x, int batchSize) {
    double[] ret = new double[4];
    double[] fullHx = new double[thisFunc.domainDimension()];
    double[] thisHx = new double[x.length];
    double[] thisGrad = new double[x.length];
    List<double[]> HxList = new ArrayList<>();

    /*
    PrintWriter file = null;
    NumberFormat nf = new DecimalFormat("0.000E0");
    try {
      file = new PrintWriter(new FileOutputStream("var.out"), true);
    } catch (IOException e) {
      log.info("Caught IOException outputting List to file: " + e.getMessage());
      System.exit(1);
    }
    */

    // Get the full Hessian-vector product H.x over all of the data
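    // The comparison strategy (as implemented below): each batch product is
    // rescaled by (data size / batch size) so it estimates the full H.x; we
    // then track the cosine similarity and the norm ratio between the two,
    // accumulating their means and variances with Welford-style online updates.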
    thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
    System.arraycopy(thisFunc.derivativeAt(x, x, thisFunc.dataDimension()), 0, thisGrad, 0, thisGrad.length);
    System.arraycopy(thisFunc.HdotVAt(x, x, thisGrad, thisFunc.dataDimension()), 0, fullHx, 0, fullHx.length);
    double fullNorm = ArrayMath.norm(fullHx);
    double hessScale = ((double) thisFunc.dataDimension()) / ((double) batchSize);

    thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.RandomWithReplacement;

    int n = 100;
    double simDelta;
    double ratDelta;
    double simMean = 0;
    double ratMean = 0;
    double simS = 0;
    double ratS = 0;
    int k = 0;

    log.info(fullHx[4] + " " + x[4]);  // debug: spot-check one coordinate

    for (int i = 0; i < n; i++) {
      System.arraycopy(thisFunc.derivativeAt(x, x, batchSize), 0, thisGrad, 0, thisGrad.length);
      System.arraycopy(thisFunc.HdotVAt(x, x, thisGrad, batchSize), 0, thisHx, 0, thisHx.length);

      ArrayMath.multiplyInPlace(thisHx, hessScale);

      double thisNorm = ArrayMath.norm(thisHx);
      // Cosine similarity and norm ratio between the batch and full products
      double sim = ArrayMath.innerProduct(thisHx, fullHx) / (thisNorm * fullNorm);
      double rat = thisNorm / fullNorm;

      // Welford's online mean/variance updates
      k += 1;
      simDelta = sim - simMean;
      simMean += simDelta / k;
      simS += simDelta * (sim - simMean);

      ratDelta = rat - ratMean;
      ratMean += ratDelta / k;
      ratS += ratDelta * (rat - ratMean);

      //file.println( nf.format(sim) + " , " + nf.format(rat));
    }

    double simVar = simS / (k - 1);
    double ratVar = ratS / (k - 1);
    //file.close();

    ret[0] = simMean;
    ret[1] = simVar;
    ret[2] = ratMean;
    ret[3] = ratVar;
    return ret;
  }

  public void testVariance(double[] x) {
    int[] batchSizes = {10, 20, 35, 50, 75, 150, 300, 500, 750, 1000, 5000, 10000};
    double[] varResult;

    PrintWriter file = null;
    NumberFormat nf = new DecimalFormat("0.000E0");
    try {
      file = new PrintWriter(new FileOutputStream("var.out"), true);
    } catch (IOException e) {
      log.info("Caught IOException outputting List to file: " + e.getMessage());
      System.exit(1);
    }

    for (int bSize : batchSizes) {
      varResult = getVariance(x, bSize);
      file.println(bSize + "," + nf.format(varResult[0]) + "," + nf.format(varResult[1]) + ","
          + nf.format(varResult[2]) + "," + nf.format(varResult[3]));
      log.info("Batch size of: " + bSize + " " + varResult[0] + "," + nf.format(varResult[1]) + ","
          + nf.format(varResult[2]) + "," + nf.format(varResult[3]));
    }
    file.close();
  }

  /*
  public double getNormVariance(List<double[]> thisList) {
    double[] ratio = new double[thisList.size()];
    double[] mean = new double[thisList.get(0).length];
    double sizeInv = 1 / ((double) thisList.size());

    for (double[] arr : thisList) {
      for (int i = 0; i < arr.length; i++) {
        mean[i] += arr[i] * sizeInv;
      }
    }

    double meanNorm = ArrayMath.norm(mean);
    for (int i = 0; i < thisList.size(); i++) {
      ratio[i] = (ArrayMath.norm(thisList.get(i)) / meanNorm);
    }
    arrayToFile(ratio, "ratio.out");
    return ArrayMath.variance(ratio);
  }

  public double getSimVariance(List<double[]> thisList) {
    double[] ang = new double[thisList.size()];
    double[] mean = new double[thisList.get(0).length];
    double sizeInv = 1 / ((double) thisList.size());

    for (double[] arr : thisList) {
      for (int i = 0; i < arr.length; i++) {
        mean[i] += arr[i] * sizeInv;
      }
    }

    double meanNorm = ArrayMath.norm(mean);
    for (int i = 0; i < thisList.size(); i++) {
      ang[i] = ArrayMath.innerProduct(thisList.get(i), mean);
      ang[i] = ang[i] / (meanNorm * ArrayMath.norm(thisList.get(i)));
    }
    arrayToFile(ang, "angle.out");
    return ArrayMath.variance(ang);
  }
  */
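  /* Illustrative reading of getVariance's return value (inferred from the code
     above, not a documented contract):

       double[] stats = tester.getVariance(x, 100);
       // stats[0] = mean cosine similarity between batch and full H.x
       // stats[1] = sample variance of that similarity
       // stats[2] = mean of ||batch H.x|| / ||full H.x||
       // stats[3] = sample variance of that ratio
  */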
  public void listToFile(List<double[]> thisList, String fileName) {
    PrintWriter file = null;
    NumberFormat nf = new DecimalFormat("0.000E0");
    try {
      file = new PrintWriter(new FileOutputStream(fileName), true);
    } catch (IOException e) {
      log.info("Caught IOException outputting List to file: " + e.getMessage());
      System.exit(1);
    }

    for (double[] element : thisList) {
      for (double val : element) {
        file.print(nf.format(val) + " ");
      }
      file.println("");
    }
    file.close();
  }

  public void arrayToFile(double[] thisArray, String fileName) {
    PrintWriter file = null;
    NumberFormat nf = new DecimalFormat("0.000E0");
    try {
      file = new PrintWriter(new FileOutputStream(fileName), true);
    } catch (IOException e) {
      log.info("Caught IOException outputting array to file: " + e.getMessage());
      System.exit(1);
    }

    for (double element : thisArray) {
      file.print(nf.format(element) + " ");
    }
    file.close();
  }

  /**
   * testObjectiveFunction
   *
   * This function was written to provide a test for the accuracy of stochastic
   * objective functions. The test checks the following properties:
   *
   *  1) The sum of the value over each batch equals the full value
   *  2) The sum of the gradients over each batch equals the full gradient
   *  3) The gradient calculated using incorporated finite differences is never more
   *     than functionTolerance from the gradient using external finite differences
   *  4) The Hessian-vector product likewise does not vary between incorporated and
   *     external finite differences
   *
   * @param function The function to test
   * @param x The point to use for testing (v is generated randomly)
   * @param functionTolerance The tolerance
   */
  /*
  public boolean testObjectiveFunction(Function function, double[] x, double functionTolerance) {
    approxGrad = new double[x.length];
    curGrad = new double[x.length];
    approxValue = 0;

    // Generate the initial vectors
    for (int i = 0; i < x.length; i++) {
      approxGrad[i] = 0;
      v[i] = generator.nextDouble();
    }

    // Run through all the batches, summing the calculations to compare against the full gradient
    for (int i = 0; i < numBatches; i++) {
      // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      //  Perform calculation using IncorporatedFiniteDifference
      // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      dfunction.method = StochasticCalculateMethods.IncorporatedFiniteDifference;
      // update the value
      approxValue += dfunction.valueAt(x, v, testBatchSize);
      // update the gradient
      dfunction.returnPreviousValues = true;
      System.arraycopy(dfunction.derivativeAt(x, v, testBatchSize), 0, curGrad, 0, curGrad.length);
      // update the Hessian-vector product
      dfunction.returnPreviousValues = true;
      System.arraycopy(dfunction.HdotVAt(x, v, testBatchSize), 0, HvAD, 0, HvAD.length);
      // update the running sum of batch gradients
      approxGrad = ArrayMath.pairwiseAdd(approxGrad, curGrad);

      // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      //  Perform calculations using external finite differences
      // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
      dfunction.method = StochasticCalculateMethods.ExternalFiniteDifference;
      dfunction.recalculatePrevBatch = true;
      System.arraycopy(dfunction.derivativeAt(x, v, testBatchSize), 0, gradFD, 0, gradFD.length);
      dfunction.recalculatePrevBatch = true;
      System.arraycopy(dfunction.HdotVAt(x, v, gradFD, testBatchSize), 0, HvFD, 0, HvFD.length);

      double DiffGrad = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(gradFD, curGrad));
      // Keep track of the biggest gradient error
      if (DiffGrad > maxGradDiff) { maxGradDiff = DiffGrad; }
      double DiffHv = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(HvAD, HvFD));
      // Keep track of the biggest H.v error
      if (DiffHv > maxHvDiff) { maxHvDiff = DiffHv; }
    }

    // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    //  Get the full gradient and value; these should equal the approximations
    // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    System.arraycopy(dfunction.derivativeAt(x), 0, fullGrad, 0, fullGrad.length);
    fullValue = dfunction.valueAt(x);

    if (ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(fullGrad, approxGrad)) < functionTolerance) {
      sayln("");
      sayln("  Gradient is looking good");
    } else {
      diff = ArrayMath.pairwiseSubtract(approxGrad, fullGrad);
      diffNorm = ArrayMath.norm(diff);
      sayln("");
      sayln("  Seems there is a problem. Gradient is off by norm of " + diffNorm);
    }

    if (maxGradDiff < functionTolerance) {
      sayln("");
      sayln("  Both gradients are the same");
    } else {
      diffValue = approxValue - fullValue;
      sayln("");
      sayln("  Seems there is a problem. The two methods of calculating the gradient differ, with a max |AD-FD|_inf error of " + maxGradDiff);
    }

    if (Math.abs(fullValue - approxValue) < functionTolerance) {
      sayln("");
      sayln("  Value is looking good");
    } else {
      diffValue = approxValue - fullValue;
      sayln("");
      sayln("  Seems there is a problem. Value is off by " + diffValue);
    }

    if (maxHvDiff < functionTolerance) {
      sayln("");
      sayln("  Hv approximations line up well");
    } else {
      sayln("");
      sayln("  Seems there is a problem. Hv approximations aren't quite close enough -- max |AD-FD|_inf error of " + maxHvDiff);
    }

    return true;
  }
  */

}
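/* Usage sketch (illustrative only; "myObjective" stands in for any concrete
   AbstractStochasticCachingDiffFunction implementation):

     AbstractStochasticCachingDiffFunction myObjective = ...;
     double[] x = new double[myObjective.domainDimension()];   // test point
     StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(myObjective);
     boolean sumsMatch = tester.testSumOfBatches(x, 1e-4);     // batch sums vs. full value/gradient
     boolean hessOk = tester.testDerivatives(x, 1e-4);         // H.v vs. finite differences
*/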