/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.classification.trainer;
import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.DSUtils;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
 * AdaGrad trainer driven by a hinge loss: the gold label must beat every
 * other label by a margin of 1, otherwise the instance triggers an update.
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class AdaGradSVM extends AbstractAdaGrad
{
	/**
	 * Builds a trainer on top of a sparse model; all arguments are forwarded
	 * unchanged to {@link AbstractAdaGrad}.
	 * @param model the sparse model to train.
	 * @param average the averaging flag handed to the superclass.
	 * @param alpha the learning rate.
	 * @param rho the smoothing denominator.
	 * @param bias the bias value handed to the superclass.
	 */
	public AdaGradSVM(SparseModel model, boolean average, double alpha, double rho, double bias)
	{
		super(model, average, alpha, rho, bias);
	}

	/**
	 * Builds a trainer on top of a string model; all arguments are forwarded
	 * unchanged to {@link AbstractAdaGrad}.
	 * @param model the string model to train.
	 * @param labelCutoff the label cutoff handed to the superclass.
	 * @param featureCutoff the feature cutoff handed to the superclass.
	 * @param average the averaging flag handed to the superclass.
	 * @param alpha the learning rate.
	 * @param rho the smoothing denominator.
	 * @param bias the bias value handed to the superclass.
	 */
	public AdaGradSVM(StringModel model, int labelCutoff, int featureCutoff, boolean average, double alpha, double rho, double bias)
	{
		super(model, labelCutoff, featureCutoff, average, alpha, rho, bias);
	}

	/**
	 * Performs one training step on the given instance.
	 * @return {@code true} if the margin-adjusted prediction differed from the
	 * instance's label and the gradients and weights were updated;
	 * {@code false} if nothing was changed.
	 */
	@Override
	protected boolean update(IntInstance instance, int averageCount)
	{
		final int predicted = getBestLabel(instance);

		// Margin satisfied: the gold label still wins, so there is nothing to learn.
		if (instance.isLabel(predicted))
			return false;

		final int gold = instance.getLabel();
		updateGradients(instance, gold, predicted);
		updateWeights(instance, gold, predicted, averageCount);
		return true;
	}

	/**
	 * Scores the instance, penalizes the gold label's score by the margin (1),
	 * and returns the index of the highest remaining score.
	 */
	private int getBestLabel(IntInstance instance)
	{
		SparseFeatureVector features = instance.getFeatureVector();
		double[] margins = w_vector.getScores(features);
		int gold = instance.getLabel();

		margins[gold] -= 1d;
		return DSUtils.maxIndex(margins);
	}

	/**
	 * Accumulates squared feature values into the gradient history for the
	 * bias term (feature index 0) and for every feature of the instance.
	 * @param yp the gold label.
	 * @param yn the wrongly predicted label.
	 */
	private void updateGradients(IntInstance instance, int yp, int yn)
	{
		final SparseFeatureVector features = instance.getFeatureVector();
		final int size = features.size();

		// The bias occupies feature index 0.
		updateGradients(yp, yn, 0, MathUtils.sq(d_bias));

		for (int k = 0; k < size; k++)
		{
			double squared = MathUtils.sq(features.getWeight(k));
			updateGradients(yp, yn, features.getIndex(k), squared);
		}
	}

	/**
	 * Adds {@code vi} to the gradient history of feature {@code xi}: a single
	 * slot in the binary case, otherwise one slot for each of the two labels.
	 */
	private void updateGradients(int yp, int yn, int xi, double vi)
	{
		if (!w_vector.isBinaryLabel())
		{
			d_gradients[w_vector.getWeightIndex(yp, xi)] += vi;
			d_gradients[w_vector.getWeightIndex(yn, xi)] += vi;
			return;
		}

		d_gradients[xi] += vi;
	}

	/**
	 * Applies the weight update for the bias term (feature index 0) and for
	 * every feature of the instance.
	 * @param yp the gold label.
	 * @param yn the wrongly predicted label.
	 */
	private void updateWeights(IntInstance instance, int yp, int yn, int averageCount)
	{
		final SparseFeatureVector features = instance.getFeatureVector();
		final int size = features.size();

		// The bias occupies feature index 0.
		updateWeights(yp, yn, averageCount, 0, d_bias);

		for (int k = 0; k < size; k++)
		{
			int index = features.getIndex(k);
			double value = features.getWeight(k);
			updateWeights(yp, yn, averageCount, index, value);
		}
	}

	/**
	 * Updates the weight(s) of feature {@code xi} by {@code vi}. In the binary
	 * case a single weight is touched and the sign is flipped when the gold
	 * label is 1; otherwise the gold label's weight receives {@code +vi} and
	 * the predicted label's weight receives {@code -vi}.
	 */
	private void updateWeights(int yp, int yn, int averageCount, int xi, double vi)
	{
		if (!w_vector.isBinaryLabel())
		{
			updateWeight(w_vector.getWeightIndex(yp, xi), vi, averageCount);
			updateWeight(w_vector.getWeightIndex(yn, xi), -vi, averageCount);
			return;
		}

		double signed = (yp == 1) ? -vi : vi;
		updateWeight(xi, signed, averageCount);
	}

	/** Returns the trainer description produced by the superclass for the tag "SVM". */
	@Override
	public String trainerInfo()
	{
		return getTrainerInfo("SVM");
	}
}