package edu.stanford.nlp.optimization;
/**
* Function for stochastic calculations that does update in place
* (instead of maintaining and returning the derivative).
*
* Weights are represented by an array of doubles and a scalar
* that indicates how much to scale all weights by.
* This allows all weights to be scaled by just modifying the scalar.
*
* @author Angel Chang
*/
public abstract class AbstractStochasticCachingDiffUpdateFunction
    extends AbstractStochasticCachingDiffFunction {

  /**
   * If true, subclasses may skip computing the objective value during
   * stochastic updates (only the gradient/update is needed).
   * NOTE(review): not read in this class; semantics assumed from the name —
   * confirm against concrete subclasses.
   */
  protected boolean skipValCalc = false;

  /**
   * Gets a random sample (this is sampling with replacement).
   *
   * @param sampleSize number of samples to generate
   * @return array of indices for random sample of sampleSize
   */
  public int[] getSample(int sampleSize) {
    int[] sample = new int[sampleSize];
    // Hoisted out of the loop: dataDimension() is loop-invariant and may be
    // a non-trivial virtual call on the subclass.
    int dataDim = this.dataDimension();
    for (int i = 0; i < sampleSize; i++) {
      sample[i] = randGenerator.nextInt(dataDim); // Just generate a random index
    }
    return sample;
  }

  /**
   * Computes value of function for specified value of x (scaled by xScale)
   * only over samples indexed by batch.
   *
   * @param x unscaled weights
   * @param xScale how much to scale x by when performing calculations
   * @param batch indices of which samples to compute function over
   * @return value of function at specified x (scaled by xScale) for samples
   */
  public abstract double valueAt(double[] x, double xScale, int[] batch);

  /**
   * Computes value of function for specified value of x (scaled by xScale)
   * over the next batch of batchSize samples (selected via getBatch).
   *
   * @param x unscaled weights
   * @param xScale how much to scale x by when performing calculations
   * @param batchSize number of samples to pick next
   * @return value of function at specified x (scaled by xScale) for samples
   */
  public double valueAt(double[] x, double xScale, int batchSize) {
    getBatch(batchSize); // side effect: advances and stores batch in thisBatch
    return valueAt(x, xScale, thisBatch);
  }

  /**
   * Performs stochastic update of weights x (scaled by xScale) based
   * on samples indexed by batch.
   *
   * @param x unscaled weights
   * @param xScale how much to scale x by when performing calculations
   * @param batch indices of which samples to compute function over
   * @param gain how much to scale adjustments to x
   * @return value of function at specified x (scaled by xScale) for samples
   */
  public abstract double calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gain);

  /**
   * Performs stochastic update of weights x (scaled by xScale) based
   * on next batch of batchSize.
   *
   * @param x unscaled weights
   * @param xScale how much to scale x by when performing calculations
   * @param batchSize number of samples to pick next
   * @param gain how much to scale adjustments to x
   * @return value of function at specified x (scaled by xScale) for samples
   */
  public double calculateStochasticUpdate(double[] x, double xScale, int batchSize, double gain) {
    getBatch(batchSize); // side effect: advances and stores batch in thisBatch
    return calculateStochasticUpdate(x, xScale, thisBatch, gain);
  }

  /**
   * Performs stochastic gradient calculation based
   * on samples indexed by batch and does not apply regularization.
   * Does not update the parameter values.
   * Typically stores derivative information for later access.
   *
   * @param x Unscaled weights
   * @param batch Indices of which samples to compute function over
   */
  public abstract void calculateStochasticGradient(double[] x, int[] batch);

  /**
   * Performs stochastic gradient calculation based
   * on the next batch of batchSize and does not apply regularization.
   * Does not update the parameter values.
   *
   * @param x unscaled weights
   * @param batchSize number of samples to pick next
   */
  public void calculateStochasticGradient(double[] x, int batchSize) {
    getBatch(batchSize); // side effect: advances and stores batch in thisBatch
    calculateStochasticGradient(x, thisBatch);
  }

}