package edu.stanford.nlp.optimization;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;
import java.util.function.Function;
import edu.stanford.nlp.util.Pair;
/**
* Stochastic Gradient Descent To Quasi Newton Minimizer.
*
 * An experimental minimizer which takes a stochastic function (one implementing AbstractStochasticCachingDiffFunction)
 * and executes SGD for the first couple of passes. During the final iterations, a series of approximate Hessian-vector
 * products is built up; these are then passed to the QNMinimizer so that it can start right up without the typical
 * delay.
*
* @author <a href="mailto:akleeman@stanford.edu">Alex Kleeman</a>
* @version 1.0
* @since 1.0
*/
public class ScaledSGDMinimizer<Q extends AbstractStochasticCachingDiffFunction> extends StochasticMinimizer<Q> {
/** A logger for this class */
private static final Redwood.RedwoodChannels log = Redwood.channels(ScaledSGDMinimizer.class);
// Which diagonal-update rule updateDiag() dispatches to.
// NOTE(review): static, but assigned from an instance constructor below, so the
// most recently constructed instance configures the rule for ALL instances.
private static int method = 1; // 0=MinErr 1=Bradley
// Histories of gradient differences (y) and step differences (s), capped at pairMem pairs.
public List<double[]> yList = null;
public List<double[]> sList = null;
// Per-coordinate scaling: the step along coordinate i is divided by diag[i]
// (an online approximation of the Hessian diagonal, see updateDiag*).
public double[] diag;
// Extra multiplier applied on top of the scheduled gain in takeStep(); tunable.
private double fixedGain = 0.99;
// Maximum number of (s, y) curvature pairs retained in sList/yList.
private static int pairMem = 20;
// Cap on the implied per-coordinate gain: diag entries at or below 1/aMax are reset.
private double aMax = 1e6;
/**
 * Tunes the fixedGain multiplier by repeated time-boxed minimizations.
 * Starting from {@code fixedStart}, fixedGain is divided by a constant
 * factor (1.2) after each trial that improves the objective; the search
 * stops the first time a trial comes out worse than the previous one.
 *
 * @param function   objective; must implement AbstractStochasticCachingDiffFunction
 * @param initial    starting point (copied for each trial, never mutated)
 * @param msPerTest  wall-clock budget per trial minimization (sets this.maxTime)
 * @param fixedStart first fixedGain value to test
 * @return the fixedGain value that achieved the lowest objective seen
 * @throws UnsupportedOperationException if function is not stochastic
 */
public double tuneFixedGain(edu.stanford.nlp.optimization.Function function, double[] initial, long msPerTest,double fixedStart){
double[] xtest = new double[initial.length];
double fOpt = 0.0;
double factor = 1.2;
double min = Double.POSITIVE_INFINITY;
this.maxTime = msPerTest;
double prev = Double.POSITIVE_INFINITY;
// check for stochastic derivatives
if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
throw new UnsupportedOperationException();
}
AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;
int it = 1;
boolean toContinue = true;
double f = fixedStart;
do{
// Each trial starts from a fresh copy of the initial point.
System.arraycopy(initial, 0, xtest, 0, initial.length);
log.info("");
this.fixedGain = f;
log.info("Testing with batchsize: " + bSize + " gain: " + gain + " fixedGain: " + nf.format(fixedGain) );
// Effectively unbounded pass count; the run is cut off by maxTime instead.
this.numPasses = 10000;
this.minimize(function, 1e-100, xtest);
double result = dfunction.valueAt(xtest);
// NOTE(review): on the first iteration f is divided here AND (if result<min,
// which always holds for it==1) again below, so the second trial skips one
// factor step -- looks like a double decrement; confirm it is intentional.
if(it == 1){
f = f/factor;
}
if( result < min ){
// New best: remember this fixedGain and keep shrinking.
min = result;
fOpt = this.fixedGain;
f = f/factor;
prev = result;
}else if(result < prev){
// Worse than the best but still improving on the last trial: keep shrinking.
f =f/factor;
prev = result;
}else if(result > prev){
// Objective got worse than the previous trial: stop the search.
toContinue = false;
}
it += 1;
log.info("");
log.info("Final value is: " + nf.format(result));
log.info("Optimal so far is: fixedgain: " + fOpt);
} while(toContinue);
return fOpt;
}
private class setFixedGain implements PropertySetter<Double>{
ScaledSGDMinimizer parent = null;
public setFixedGain(ScaledSGDMinimizer min){parent = min;}
public void set(Double in){
parent.fixedGain = in ;
}
}
/**
 * Jointly tunes fixedGain, gain, and batch size with two rounds of
 * coordinate-wise search, each trial time-boxed to {@code msPerTest}.
 * The tuned fixedGain is retained on this instance; the returned pair
 * carries only (batch size, gain).
 */
@Override
public Pair<Integer,Double> tune( edu.stanford.nlp.optimization.Function function,double[] initial, long msPerTest){
this.quiet = true;
// Two rounds: fixedGain, then gain, then batch size, so later parameters are
// tuned against the already-improved earlier ones.
for(int i =0;i<2; i++){
this.fixedGain = tuneDouble(function,initial,msPerTest,new setFixedGain(this),0.1,1.0);
gain = tuneGain(function,initial,msPerTest,1e-7,1.0);
bSize = tuneBatch(function,initial,msPerTest,1);
log.info("Results: fixedGain: " + nf.format(this.fixedGain) + " gain: " + nf.format(gain) + " batch " + bSize );
}
return new Pair<>(bSize, gain);
}
/** Suppresses the minimizer's progress output. */
@Override
public void shutUp() {
  quiet = true;
}
/** Sets the number of examples drawn per stochastic batch. */
public void setBatchSize(int batchSize) {
  this.bSize = batchSize;
}
/**
 * Creates a minimizer using the BFGS-style diagonal update (method 1)
 * with no per-iteration file output.
 */
public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses){
this(SGDGain,batchSize,sgdPasses, 1,false);
}
/**
 * Creates a minimizer with an explicit diagonal-update rule.
 *
 * @param method 0 = minimum-error update, 1 = BFGS-style ("Bradley") update
 */
public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses, int method){
this(SGDGain,batchSize,sgdPasses, method,false);
}
/**
 * Primary constructor.
 *
 * @param SGDGain      base SGD learning-rate gain
 * @param batchSize    number of examples per stochastic batch
 * @param sgdPasses    number of passes over the data
 * @param method       0 = minimum-error update, 1 = BFGS-style update
 * @param outputToFile whether to dump per-iteration results to a file
 */
public ScaledSGDMinimizer(double SGDGain, int batchSize, int sgdPasses, int method, boolean outputToFile) {
bSize = batchSize;
gain = SGDGain;
this.numPasses = sgdPasses;
// NOTE(review): assigns a STATIC field from an instance constructor, so the
// most recently constructed instance silently reconfigures the update rule
// for ALL instances.  Left unchanged to preserve behavior.
ScaledSGDMinimizer.method = method;
this.outputIterationsToFile = outputToFile;
}
/** Convenience constructor: 50 SGD passes, default update rule. */
public ScaledSGDMinimizer(double SGDGain, int batchSize){
this(SGDGain,batchSize,50);
}
/** Sets the wall-clock budget (milliseconds) used by the tuning routines. */
public void setMaxTime(Long max) {
  this.maxTime = max;
}
/**
 * Returns an identifier encoding the configuration, e.g. "ScaledSGD15_g100_f990"
 * for batch 15, gain 0.1, fixedGain 0.99 (gains encoded as integer per-mille).
 */
@Override
public String getName() {
  int gainPerMille = (int) (gain * 1000.0);
  int fixedPerMille = (int) (fixedGain * 1000.0);
  return "ScaledSGD" + bSize + "_g" + gainPerMille + "_f" + fixedPerMille;
}
/**
 * Takes one scaled SGD step and records a new (s, y) curvature pair.
 * Each coordinate's step is divided by diag[i], the online diagonal
 * Hessian approximation maintained by updateDiag().
 */
@Override
protected void takeStep(AbstractStochasticCachingDiffFunction dfunction){
  for (int i = 0; i < x.length; i++) {
    double thisGain = fixedGain * gainSchedule(k, 5 * numBatches) / (diag[i]);
    newX[i] = x[i] - thisGain * grad[i];
  }
  // Get a new pair...
  say(" A ");
  // Keep at most pairMem pairs: evict the oldest and recycle y's buffer.
  // (The original guard "pairMem > 0 && size == pairMem || size == pairMem"
  // reduces to "size == pairMem"; the evicted s buffer was never reused
  // because pairwiseSubtract allocates a fresh array.)
  double[] y;
  if (sList.size() == pairMem) {
    sList.remove(0);
    y = yList.remove(0);
  } else {
    y = new double[x.length];
  }
  // s = newX - x, the step just taken.
  double[] s = ArrayMath.pairwiseSubtract(newX, x);
  // y = grad(newX) - newGrad over the same batch (recalculatePrevBatch forces
  // the derivative to be evaluated on the previous batch for a matched pair).
  dfunction.recalculatePrevBatch = true;
  System.arraycopy(dfunction.derivativeAt(newX, bSize), 0, y, 0, grad.length);
  ArrayMath.pairwiseSubtractInPlace(y, newGrad); // newY = newY-newGrad
  sList.add(s);
  yList.add(y);
  updateDiag(diag, s, y);
}
/**
 * Initializes the diagonal scaling and curvature history before minimization.
 * Every diag entry starts at fixedGain/gain, so the first steps reduce to
 * plain SGD at the configured gains.
 */
@Override
protected void init(AbstractStochasticCachingDiffFunction func) {
  diag = new double[x.length];
  double seed = fixedGain / gain;
  for (int i = 0; i < diag.length; i++) {
    diag[i] = seed;
  }
  memory = 1;
  sList = new ArrayList<>();
  yList = new ArrayList<>();
}
/** Dispatches to the configured diagonal-update rule; unknown codes are a no-op. */
private void updateDiag(double[] diag, double[] s, double[] y) {
  switch (method) {
    case 0:
      updateDiagMinErr(diag, s, y);
      break;
    case 1:
      updateDiagBFGS(diag, s, y);
      break;
    default:
      // Unrecognized method code: leave diag untouched (matches the original if/else chain).
      break;
  }
}
/**
 * BFGS-style diagonal update: applies the diagonal part of the BFGS formula,
 * diag[i] <- (1 - diag[i]*s[i]^2 / s'Ds) * diag[i] + y[i]^2 / s'y.
 * The update is applied all-or-nothing: if any new entry would be negative
 * (or non-finite), the whole update is rejected and diag is left unchanged.
 *
 * <p>Fix: previously a degenerate pair (s'Ds == 0 or s'y == 0, e.g. a zero
 * step) divided by zero and produced NaN entries, which passed the
 * {@code < 0} test and silently poisoned diag; such pairs are now skipped.
 *
 * @param diag diagonal scaling to update in place
 * @param s    step difference newX - x
 * @param y    gradient difference over the same batch
 */
public void updateDiagBFGS(double[] diag, double[] s, double[] y) {
  double sDs = 0.0;
  double sy = 0.0;
  for (int i = 0; i < s.length; i++) {
    sDs += s[i] * diag[i] * s[i];
    sy += s[i] * y[i];
  }
  say("B");
  // Degenerate pair: both denominators below are used, so bail out before
  // dividing by zero and corrupting diag with NaN/Infinity.
  if (sDs == 0.0 || sy == 0.0) {
    say("!");
    return;
  }
  double[] newDiag = new double[s.length];
  boolean updateDiag = true;
  for (int i = 0; i < s.length; i++) {
    newDiag[i] = (1 - diag[i] * s[i] * s[i] / sDs) * diag[i] + y[i] * y[i] / sy;
    // diag must remain a valid positive scaling; also reject NaN, which the
    // "< 0" comparison alone would let through.
    if (newDiag[i] < 0 || Double.isNaN(newDiag[i])) {
      updateDiag = false;
      break;
    }
  }
  if (updateDiag) {
    System.arraycopy(newDiag, 0, diag, 0, s.length);
  } else {
    say("!");
  }
}
/**
 * "Minimum error" diagonal update: moves each diag entry toward satisfying
 * the secant condition diag[i]*s[i] ~ y[i] while penalizing deviation from
 * the current diag, with the trade-off controlled by a Lagrange multiplier
 * lamStar found by root-finding on the constraint residual (see lagrange).
 */
private void updateDiagMinErr(double[] diag,double[] s,double[] y){
double low = 0.0;
double high = 0.0;
// high accumulates the squared residual of the secant condition, used as an
// upper bracket (after scaling) for the multiplier search.
for(int i=0;i<s.length;i++){
double tmp = s[i] * (y[i] - diag[i]);
high += tmp*tmp;
}
say("M");
// Trust-region-like radius: shrinks as iteration count k grows, scaled by
// the current average diag magnitude.
double alpha = Math.sqrt((ArrayMath.norm(y)/ArrayMath.norm(s))) *Math.sqrt(( 50.0/ (50.0 + k) ));
alpha = alpha*Math.sqrt(ArrayMath.average(diag));
say(" alpha " + nf.format(alpha));
high = Math.sqrt(high)/(2*alpha);
Function<Double,Double> func = new lagrange(s,y,diag,alpha);
double lamStar;
// If the residual at lambda=0 already exceeds alpha^2 (func > 0), search for
// the multiplier that brings it onto the constraint; otherwise no damping is
// needed and lamStar stays 0.
if( func.apply(low) > 0 ){
lamStar = getRoot(func,low,high);
} else{
lamStar = 0.0;
say(" * ");
}
for(int i=0;i<s.length;i++){
// Damped secant solution; |y*s| keeps the entry positive, 1e-8 guards the
// denominator when both s[i] and lamStar are ~0.
diag[i] = ( Math.abs(y[i]*s[i]) + 2*lamStar*diag[i])/(s[i]*s[i] + 1e-8 + 2*lamStar);
//diag[i] = (y[i]*s[i] + 2*lamStar*diag[i])/(s[i]*s[i] + 2*lamStar);
// Entries implying a per-coordinate gain above aMax are reset to the base gain.
if (diag[i] <= 1.0/aMax) {
diag[i] = 1.0/gain;
}
}
}
/**
 * Finds a root of {@code func} in [lower, upper] by skewed bisection,
 * assuming func is positive at {@code lower} and negative at {@code upper}
 * (a warning is printed if the bracket looks wrong).  Each step replaces the
 * bracket end matching fval's sign and re-evaluates at a point biased 60%
 * toward the upper end; gives up after 100 iterations.
 */
private double getRoot(Function<Double, Double> func, double lower, double upper) {
  double mid = 0.5 * (lower + upper);
  final double TOL = 1e-8;
  final double skew = 0.4;
  int iterations = 0;
  if (func.apply(upper) > 0 || func.apply(lower) < 0) {
    say("LOWER AND UPPER SUPPLIED TO GET ROOT DO NOT BOUND THE ROOT.");
  }
  double fval = func.apply(mid);
  while (Math.abs(fval) > TOL) {
    iterations += 1;
    if (fval > 0) {
      lower = mid;
    } else if (fval < 0) {
      upper = mid;
    }
    mid = skew * lower + (1 - skew) * upper;
    fval = func.apply(mid);
    if (iterations > 100) {
      break;
    }
  }
  say(" " + nf.format(mid) + " f" + nf.format(fval));
  return mid;
}
/**
 * Constraint residual for the minimum-error diagonal update:
 * apply(lam) = sum_i ((y_i*s_i + 2*lam*d_i)/(s_i^2 + 2*lam) - d_i)^2 - a^2.
 * getRoot() drives this to zero to find the Lagrange multiplier lamStar.
 */
static class lagrange implements Function<Double, Double> {
  private final double[] s;
  private final double[] y;
  private final double[] d;
  private final double a;

  public lagrange(double[] s, double[] y, double[] d, double a) {
    this.s = s;
    this.y = y;
    this.d = d;
    this.a = a;
  }

  @Override
  public Double apply(Double lam) {
    double twoLam = 2 * lam;
    double acc = 0.0;
    for (int i = 0; i < s.length; i++) {
      double resid = (y[i] * s[i] + twoLam * d[i]) / (s[i] * s[i] + twoLam) - d[i];
      acc += resid * resid;
    }
    return acc - a * a;
  }
} // end static class lagrange
/** Serializable container pairing a weight vector with an optional diagonal scaling. */
public static class Weights implements Serializable {
  private static final long serialVersionUID = 814182172645533781L;
  /** Weight vector. */
  public double[] w;
  /** Diagonal scaling; null when only weights were stored. */
  public double[] d;

  public Weights(double[] wt) {
    this(wt, null);
  }

  public Weights(double[] wt, double[] di) {
    w = wt;
    d = di;
  }
}
/** Serializes just the weight vector (no diagonal scaling) to {@code serializePath}. */
public static void serializeWeights(String serializePath,double[] weights) {
serializeWeights(serializePath,weights,null);
}
/**
 * Serializes the weights plus an optional diagonal scaling as a {@link Weights}
 * object.  Best-effort: any failure is logged and swallowed rather than thrown.
 *
 * @param diag may be null (weights-only serialization)
 */
public static void serializeWeights(String serializePath,double[] weights,double[] diag) {
log.info("Serializing weights to " + serializePath + "...");
try {
Weights out = new Weights(weights,diag);
IOUtils.writeObjectToFile(out, serializePath);
} catch (Exception e) {
// NOTE(review): the broad catch deliberately keeps training alive on I/O
// failure, but the stack trace goes to stderr instead of the Redwood log.
log.info("Error serializing to " + serializePath);
e.printStackTrace();
}
}
/**
 * Deserializes a {@link Weights} object from {@code loadPath} and returns its
 * weight vector.
 *
 * @throws IOException            on read failure
 * @throws ClassNotFoundException if the serialized class is unavailable
 */
public static double[] getWeights(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
  log.info("Loading weights from " + loadPath + "...");
  Weights stored = IOUtils.readObjectFromFile(loadPath);
  return stored.w;
}
/**
 * Deserializes a {@link Weights} object from {@code loadPath} and returns its
 * diagonal scaling vector (may be null if only weights were serialized).
 *
 * <p>Fix: the log message previously said "Loading weights" — a copy-paste
 * from getWeights() — which mislabeled what this method loads.
 *
 * @throws IOException            on read failure
 * @throws ClassNotFoundException if the serialized class is unavailable
 */
public static double[] getDiag(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
  log.info("Loading diag from " + loadPath + "...");
  Weights stored = IOUtils.readObjectFromFile(loadPath);
  return stored.d;
}
}