/*******************************************************************************
* Copyright 2012 University of Southern California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* This code was developed by the Information Integration Group as part
* of the Karma project at the Information Sciences Institute of the
* University of Southern California. For more information, publications,
* and related projects, please see: http://www.isi.edu/integration
******************************************************************************/
package edu.isi.karma.modeling.semantictypes.mycrf.optimization ;
import edu.isi.karma.modeling.semantictypes.mycrf.common.Constants;
import edu.isi.karma.modeling.semantictypes.mycrf.crfmodel.CRFModelAbstract;
import edu.isi.karma.modeling.semantictypes.mycrf.globaldata.GlobalDataAbstract;
import edu.isi.karma.modeling.semantictypes.mycrf.graph.GraphInterface;
import edu.isi.karma.modeling.semantictypes.mycrf.math.Matrix;
import edu.isi.karma.modeling.semantictypes.myutils.Prnt;
/**
 * Finds the step size to take along a search direction during the optimization process.
 *
 * The search starts with a full step (lambda = 1) and backtracks to smaller steps until the
 * error reported by {@code globalData.errorValue()} satisfies the sufficient-decrease test
 * {@code f(lambda) < currError + Constants.ALPHA * lambda * slope}, or until the step becomes
 * too small to make progress.
 *
 * @author amangoel
 *
 */
public class BacktrackingLineSearch {

    GlobalDataAbstract globalData ;
    CRFModelAbstract crfModel ;
    int dim ;

    /**
     * @param crfModel model whose {@code weights} array is read and overwritten by the search;
     *                 its length fixes the dimension used by this instance
     * @param globalData supplies the training graphs and the error value evaluated at each trial step
     */
    public BacktrackingLineSearch(CRFModelAbstract crfModel, GlobalDataAbstract globalData) {
        this.crfModel = crfModel ;
        this.globalData = globalData ;
        dim = crfModel.weights.length ;
    }

    /**
     * Searches along {@code searchDir} for a step that sufficiently decreases the error.
     *
     * Side effects: if the norm of {@code searchDir} exceeds
     * {@code Constants.BACKTRACKINGLINESEARCH_MAX_STEP}, the array is scaled down in place, so the
     * returned step is relative to the (possibly rescaled) contents of {@code searchDir}.
     * {@code crfModel.weights} is overwritten with every trial point. On success it is left at
     * the accepted point; when 0.0 is returned it holds the values it had on entry.
     *
     * @param searchDir direction to move in; may be rescaled in place
     * @param gradient gradient of the error at the current weights
     * @param currError error value at the current weights
     * @return the accepted step size, or 0.0 when no acceptable step could be found
     */
    double findStep(double[] searchDir, double[] gradient, double currError) {
        double[] currWeights = new double[dim] ;
        System.arraycopy(crfModel.weights, 0, currWeights, 0, dim) ;

        // Cap the length of the step so the first trial does not jump too far.
        double searchDirNorm = Matrix.norm(searchDir) ;
        if (searchDirNorm > Constants.BACKTRACKINGLINESEARCH_MAX_STEP) {
            double ratio = Constants.BACKTRACKINGLINESEARCH_MAX_STEP / searchDirNorm ;
            for (int i = 0; i < dim; i++) {
                searchDir[i] *= ratio ;
            }
        }

        double slope = Matrix.dotProduct(gradient, searchDir) ;

        // Largest component of the direction relative to the magnitude of the matching weight.
        double test = 0.0 ;
        for (int i = 0; i < dim; i++) {
            double tmp = Math.abs(searchDir[i] / Math.max(Math.abs(currWeights[i]), 1.0)) ;
            if (tmp > test) {
                test = tmp ;
            }
        }
        if (test == 0.0) {
            // Every component is zero (or NaN): no step can change the error, and the
            // interpolation below would evaluate 0/0. The weights have not been touched yet.
            Prnt.prn("Returning because the search direction is zero or not a number") ;
            return 0.0 ;
        }
        double lammin = Constants.TOLX / test ;

        double lam1 = 1.0 ;
        double lam2 = 0.0 ;
        double f2 = 0.0 ;
        boolean firstTrial = true ;
        while (true) {
            double f1 = errorAt(currWeights, searchDir, lam1) ;
            if (f1 < currError + Constants.ALPHA * lam1 * slope) {
                return lam1 ;
            }

            double tmplam ;
            if (firstTrial) {
                // Minimizer of the quadratic through f(0), f'(0) and f(1); lam1 is 1 here.
                tmplam = -slope / (2 * (f1 - currError - slope)) ;
            }
            else {
                tmplam = cubicTrialStep(lam1, f1, lam2, f2, currError, slope) ;
                if (tmplam > 0.5 * lam1) {
                    tmplam = 0.5 * lam1 ;
                }
            }
            if (Double.isNaN(tmplam)) {
                // A NaN error value or a 0/0 in the interpolation would otherwise propagate
                // into lam1 and the weights and keep the loop running forever; halve instead.
                tmplam = 0.5 * lam1 ;
            }

            if (tmplam < lammin) { // lambda too small. can't move forward
                System.arraycopy(currWeights, 0, crfModel.weights, 0, dim) ;
                Prnt.prn("Returning because tmplam = " + tmplam + " < lammin = "+ lammin) ;
                return 0.0 ;
            }

            f2 = f1 ;
            lam2 = lam1 ;
            lam1 = Math.max(tmplam, 0.1 * lam1) ; // never shrink by more than a factor of ten
            firstTrial = false ;
        }
    }

    /**
     * Sets the model weights to {@code base + lambda * searchDir}, calls
     * {@code computeGraphPotentialAndZ()} on every training graph and returns the resulting
     * error value.
     */
    private double errorAt(double[] base, double[] searchDir, double lambda) {
        for (int i = 0; i < dim; i++) {
            crfModel.weights[i] = base[i] + lambda * searchDir[i] ;
        }
        for (GraphInterface graph : globalData.trainingGraphs) {
            graph.computeGraphPotentialAndZ() ;
        }
        return globalData.errorValue() ;
    }

    /**
     * Proposes the next step from the two most recent trials (lam1, f1) and (lam2, f2) by
     * minimizing the cubic that matches both trials, the starting error and the starting slope.
     */
    private static double cubicTrialStep(double lam1, double f1, double lam2, double f2,
                                         double currError, double slope) {
        double rhs1 = f1 - currError - slope*lam1 ;
        double rhs2 = f2 - currError - slope*lam2 ;
        double a = (rhs1/(lam1 * lam1) - rhs2/(lam2 * lam2)) / (lam1 - lam2) ;
        double b = ((-rhs1 * lam2 / (lam1 * lam1)) + rhs2 * lam1 / (lam2 * lam2)) / (lam1 - lam2) ;
        if (a == 0.0) {
            // The cubic term vanished; minimize the remaining quadratic.
            return -slope / (2.0 * b) ;
        }
        double disc = b*b - 3*a*slope ;
        if (disc < 0) {
            // The cubic has no stationary point, so there is nothing to interpolate to.
            // Halve the step and keep searching rather than giving up.
            Prnt.prn("Backtracking line search: discriminant is negative, halving the step") ;
            return 0.5 * lam1 ;
        }
        return (-b + Math.sqrt(disc)) / (3*a) ;
    }
}