/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.classification.trainer;
import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
* AdaGrad algorithm using hinge loss.
* @since 3.0.0
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/
public class AdaGradLR extends AbstractAdaGrad
{
/**
* @param alpha the learning rate.
* @param rho the smoothing denominator.
*/
public AdaGradLR(SparseModel model, boolean average, double alpha, double rho, double bias)
{
super(model, average, alpha, rho, bias);
}
/**
* @param alpha the learning rate.
* @param rho the smoothing denominator.
*/
public AdaGradLR(StringModel model, int labelCutoff, int featureCutoff, boolean average, double alpha, double rho, double bias)
{
super(model, labelCutoff, featureCutoff, average, alpha, rho, bias);
}
@Override
protected boolean update(IntInstance instance, int averageCount)
{
double[] gradients = getGradients(instance);
if (gradients[instance.getLabel()] > 0.01)
{
updateGradients(instance, gradients);
updateWeights (instance, gradients, averageCount);
return true;
}
return false;
}
private double[] getGradients(IntInstance instance)
{
double[] scores = w_vector.getScores(instance.getFeatureVector(), true);
int i, size = scores.length;
for (i=0; i<size; i++) scores[i] *= -1;
scores[instance.getLabel()] += 1;
return scores;
}
private void updateGradients(IntInstance instance, double[] gradidents)
{
SparseFeatureVector x = instance.getFeatureVector();
int i, j, xi, len = x.size(), lsize = w_vector.getLabelSize();
double[] g = new double[lsize];
double vi;
for (j=0; j<lsize; j++)
g[j] = MathUtils.sq(gradidents[j]);
updateGradients(g, 0, d_bias);
for (i=0; i<len; i++)
{
xi = x.getIndex(i);
vi = MathUtils.sq(x.getWeight(i));
updateGradients(g, xi, vi);
}
}
private void updateGradients(double[] g, int xi, double vi)
{
int j, lsize = w_vector.getLabelSize();
for (j=0; j<lsize; j++)
d_gradients[w_vector.getWeightIndex(j, xi)] += vi * g[j];
}
private void updateWeights(IntInstance instance, double[] gradients, int averageCount)
{
SparseFeatureVector x = instance.getFeatureVector();
int i, xi, len = x.size();
double vi;
updateWeights(gradients, 0, d_bias, averageCount);
for (i=0; i<len; i++)
{
xi = x.getIndex(i);
vi = x.getWeight(i);
updateWeights(gradients, xi, vi, averageCount);
}
}
private void updateWeights(double[] gradients, int xi, double vi, int averageCount)
{
int j, lsize = w_vector.getLabelSize();
for (j=0; j<lsize; j++)
updateWeight(w_vector.getWeightIndex(j, xi), vi*gradients[j], averageCount);
}
@Override
public String trainerInfo()
{
return getTrainerInfo("LR");
}
}