package func.svm;
import util.linalg.DenseVector;
import util.linalg.Vector;
import shared.DataSet;
import shared.Instance;
import shared.Trainer;
/**
* An implementation of the SMO algorithm.
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class SequentialMinimalOptimization implements Trainer {
    /**
     * The tolerance used when checking the (loose) KKT conditions.
     */
    private static final double TOLERANCE = 1e-4;
    /**
     * The minimum meaningful change in an alpha value; smaller
     * steps are treated as no progress.
     */
    private static final double EPS = 1e-4;
    /**
     * An about-zero value used to snap alphas to exactly 0 or c.
     */
    private static final double ZERO = 1e-8;
    /**
     * The number of iterations of the outer training loop performed.
     */
    private int iterations;
    /**
     * The training instances.
     */
    private DataSet examples;
    /**
     * The kernel function.
     */
    private Kernel kernel;
    /**
     * The slack value; all alpha weights
     * must be between 0 and c inclusive.
     */
    private double c;
    /**
     * The alpha weights on the support vectors.
     */
    private double[] a;
    /**
     * The threshold subtracted when
     * evaluating the support vector machine.
     */
    private double b;
    /**
     * The error cache, which is the real output
     * minus the expected output for non bound examples
     * (examples whose alpha value is not 0 or c).
     */
    private double[] e;
    /**
     * The weight vector (only maintained for linear kernels,
     * otherwise null).
     */
    private Vector w;

    /**
     * Make a new SMO trainer.
     * @param examples the instances to train on
     * @param kernel the kernel to use
     * @param c the slack value
     */
    public SequentialMinimalOptimization(DataSet examples,
            Kernel kernel, double c) {
        this.c = c;
        this.kernel = kernel;
        this.examples = examples;
        // alpha values are initially zero
        a = new double[examples.size()];
        // as is the threshold
        b = 0;
        // all the instances are initially bound,
        // so the error cache is all zero as well
        e = new double[examples.size()];
        // set up the kernel
        kernel.clear();
        kernel.setExamples(examples);
        // set up the weight vector (if linear)
        if (kernel instanceof LinearKernel) {
            w = new DenseVector(
                new double[examples.get(0).size()]);
        }
    }

    /**
     * Run the SMO main loop until no alpha value changes during a
     * full pass over all examples.
     * @see shared.Trainer#train()
     * @return always 0 (the interface requires a return value)
     */
    public double train() {
        // number of alpha values changed this iteration
        int numChanged = 0;
        // whether or not to loop through all examples
        boolean examineAll = true;
        // the main training loop
        while (numChanged > 0 || examineAll) {
            iterations++;
            numChanged = 0;
            if (examineAll) {
                // loop through all the examples
                for (int i = 0; i < a.length; i++) {
                    if (examine(i)) {
                        numChanged++;
                    }
                }
            } else {
                // loop through all non bounded examples
                for (int i = 0; i < a.length; i++) {
                    if (!isBound(i) && examine(i)) {
                        numChanged++;
                    }
                }
            }
            // if we just examined all,
            // we're either done or can
            // go back to only looking at non bounded examples;
            // else if we didn't change anything
            // we should check everything
            if (examineAll) {
                examineAll = false;
            } else if (numChanged == 0) {
                examineAll = true;
            }
        }
        return 0;
    }

    /**
     * Get the created support vector machine, built from the
     * examples whose alpha weight is non zero.
     * @return the support vector machine
     */
    public SupportVectorMachine getSupportVectorMachine() {
        int supportVectorCount = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] != 0) {
                supportVectorCount++;
            }
        }
        Instance[] support =
            new Instance[supportVectorCount];
        double[] supporta = new double[supportVectorCount];
        int j = 0;
        for (int i = 0; i < a.length; i++) {
            if (a[i] != 0) {
                support[j] = examples.get(i);
                supporta[j] = a[i];
                j++;
            }
        }
        DataSet supportSet = new DataSet(support);
        supportSet.setDescription(examples.getDescription());
        return new SupportVectorMachine(supportSet, supporta, kernel, b);
    }

    /**
     * Get the number of iterations performed.
     * @return the number of iterations
     */
    public int getNumberOfIterations() {
        return iterations;
    }

    /**
     * Examine an example and, if it violates the loose KKT
     * conditions, try to find a partner index to optimize with.
     * @param j the index of the example to examine
     * @return true if the example was changed
     */
    private final boolean examine(int j) {
        // we first check the loose KKT conditions for the example
        double ej = error(j);
        double rj = ej * examples.get(j).getLabel().getPlusMinus();
        // if it doesn't violate the loose KKT conditions,
        // just return
        if (!((rj < -TOLERANCE && a[j] < c) || (rj > TOLERANCE && a[j] > 0))) {
            return false;
        }
        // first we look for a second choice index, i, to take a step with;
        // if ej is positive we look for the smallest error ei
        if (ej > 0) {
            int i = -1; double ei = ej;
            for (int k = 0; k < a.length; k++) {
                if (!isBound(k) && e[k] < ei) {
                    ei = e[k]; i = k;
                }
            }
            // and try to take an optimization step
            if (i != -1 && takeStep(i, j, ej)) {
                return true;
            }
        }
        // if ej is negative we look for the largest ei
        if (ej < 0) {
            int i = -1; double ei = ej;
            for (int k = 0; k < a.length; k++) {
                if (!isBound(k) && e[k] > ei) {
                    ei = e[k]; i = k;
                }
            }
            // and try to take an optimization step
            if (i != -1 && takeStep(i, j, ej)) {
                return true;
            }
        }
        // if the second choice heuristic fails we look
        // at all non bound indices, starting from a random point.
        // Note the parentheses: (int) Math.random() * a.length would
        // cast first and always yield 0, defeating the random start.
        int startI = (int) (Math.random() * a.length);
        int i = startI;
        do {
            if (!isBound(i) && takeStep(i, j, ej)) {
                return true;
            }
            i = (i + 1) % a.length;
        } while (i != startI);
        // if that fails we look at all of the indices, starting from
        // a random point
        startI = (int) (Math.random() * a.length);
        i = startI;
        do {
            if (takeStep(i, j, ej)) {
                return true;
            }
            i = (i + 1) % a.length;
        } while (i != startI);
        // we have failed to make progress
        return false;
    }

    /**
     * Perform the joint optimization on two indices.
     * @param i the first index
     * @param j the second index
     * @param ej the error for the second index
     * @return true if we make progress
     */
    private final boolean takeStep(int i, int j, double ej) {
        // the indices must be different
        if (i == j) {
            return false;
        }
        // the target values
        double yi = examples.get(i).getLabel().getPlusMinus(),
            yj = examples.get(j).getLabel().getPlusMinus();
        // the new alpha values being computed
        double ai, aj;
        // the new threshold
        double bnew;
        // the upper and lower bounds of the line for aj
        double l, h;
        // the two target values multiplied together
        double s = yi * yj;
        // the error for the first index
        double ei = error(i);
        // depending on whether or not the target values are equal,
        // compute l and h with the appropriate formulas
        if (s < 0) {
            l = Math.max(0, a[j] - a[i]);
            h = Math.min(c, c + a[j] - a[i]);
        } else {
            l = Math.max(0, a[i] + a[j] - c);
            h = Math.min(c, a[i] + a[j]);
        }
        // no progress can be made
        if (l == h) {
            return false;
        }
        // compute the kernel values
        double kii = kernel.value(i, i);
        double kij = kernel.value(i, j);
        double kjj = kernel.value(j, j);
        // the second derivative of the objective function
        double eta = 2*kij - kii - kjj;
        // calculate the new aj
        // the normal case
        if (eta < 0) {
            // unconstrained max
            aj = a[j] - yj * (ei - ej) / eta;
            // clip it to [l, h]
            if (aj < l) {
                aj = l;
            } else if (aj > h) {
                aj = h;
            }
        } else {
            // the abnormal case:
            // actually calculate the objective function
            // at aj = l and aj = h, and pick the better end point
            double fiold = ei + yi;
            double fjold = ej + yj;
            double vi = fiold + b - yi*a[i]*kii - yj*a[j]*kij;
            double vj = fjold + b - yi*a[i]*kij - yj*a[j]*kjj;
            double fl = a[i] + s*a[j] - s*l;
            double fh = a[i] + s*a[j] - s*h;
            double objl = fl + l - .5*kii*fl*fl - .5*kjj*l*l - s*kij*fl*l - yi*fl*vi - yj*l*vj;
            double objh = fh + h - .5*kii*fh*fh - .5*kjj*h*h - s*kij*fh*h - yi*fh*vi - yj*h*vj;
            if (objl > objh + EPS) {
                aj = l;
            } else if (objl < objh - EPS) {
                aj = h;
            } else {
                aj = a[j];
            }
        }
        // make aj zero or c if it is close to it
        if (aj < ZERO) {
            aj = 0;
        } else if (aj > c - ZERO) {
            aj = c;
        }
        // if there's no progress
        if (Math.abs(aj - a[j]) < EPS*(aj + a[j] + EPS)) {
            return false;
        }
        // set the ai value
        ai = a[i] + s*(a[j] - aj);
        // make ai zero or c if it is close to it
        if (ai < ZERO) {
            ai = 0;
        } else if (ai > c - ZERO) {
            ai = c;
        }
        // calculate the new threshold
        if (ai > 0 && ai < c) {
            // ai is not bounded
            bnew = ei + yi*(ai - a[i])*kii + yj*(aj - a[j])*kij + b;
        } else if (aj > 0 && aj < c) {
            // aj is not bounded
            bnew = ej + yi*(ai - a[i])*kij + yj*(aj - a[j])*kjj + b;
        } else {
            // all values in the range are valid, use the middle
            double bi = ei + yi*(ai - a[i])*kii + yj*(aj - a[j])*kij + b;
            double bj = ej + yi*(ai - a[i])*kij + yj*(aj - a[j])*kjj + b;
            bnew = (bi + bj) / 2;
        }
        // i and j are either bound now or
        // should have their error set to zero in the cache
        if (ai > 0 && ai < c) {
            e[i] = 0;
        }
        if (aj > 0 && aj < c) {
            e[j] = 0;
        }
        // the deltas
        double ti = yi*(ai - a[i]);
        double tj = yj*(aj - a[j]);
        double tb = b - bnew;
        // update the linear weight vector if needed
        if (w != null) {
            w = examples.get(i).getData().times(ti).plus(w);
            w = examples.get(j).getData().times(tj).plus(w);
        }
        // update the error cache
        // for non bound examples not in the cache
        for (int k = 0; k < e.length; k++) {
            if (k != i && k != j && !isBound(k)) {
                e[k] += ti*kernel.value(i,k) + tj*kernel.value(j,k) + tb;
            }
        }
        // finally, set the alpha values and the threshold
        b = bnew;
        a[i] = ai;
        a[j] = aj;
        return true;
    }

    /**
     * Check if an index is bound (its alpha is at 0 or c).
     * @param i the index to check
     * @return true if it is
     */
    private final boolean isBound(int i) {
        return a[i] <= 0 || a[i] >= c;
    }

    /**
     * Calculate (or look up in the cache)
     * the error for an example.
     * @param i the example to look up the error for
     * @return the error
     */
    private final double error(int i) {
        // if it's not bound we use the error cache
        if (!isBound(i)) {
            return e[i];
        } else {
            return evaluate(i) - examples.get(i).getLabel().getPlusMinus();
        }
    }

    /**
     * Evaluate the support vector machine for an example.
     * @param i the example to evaluate for
     * @return the evaluated value
     */
    private final double evaluate(int i) {
        // quick linear case
        if (w != null) {
            return examples.get(i).getData().dotProduct(w) - b;
        }
        // non linear slow case
        double result = 0;
        for (int j = 0; j < a.length; j++) {
            if (a[j] != 0) {
                result += examples.get(j).getLabel().getPlusMinus()
                    * a[j] * kernel.value(i, j);
            }
        }
        result -= b;
        return result;
    }

    /**
     * @see java.lang.Object#toString()
     */
    public String toString() {
        String ret = "b = " + b + "\n";
        ret += "kernel = " + kernel + "\n";
        ret += examples.toString();
        return ret;
    }
}