/*
* File: SimplifiedSequentialMinimalOptimization.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright July 19, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.svm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.function.kernel.KernelContainer;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Random;
/**
* This is a simplified version of the Sequential Minimal Optimization (SMO)
* algorithm that was used as a stepping-stone in the full SMO implementation.
*
* @author Justin Basilico
* @since 3.1
* @see SequentialMinimalOptimization
*/
@PublicationReference(
    title="The Simplified SMO Algorithm",
    author="Andrew Ng",
    year=2009,
    type=PublicationType.WebPage,
    url="http://www.stanford.edu/class/cs229/materials/smo.pdf")
public class SimplifiedSequentialMinimalOptimization<InputType>
    extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>>
    implements KernelContainer<InputType>, Randomized
{

    /** The default maximum number of iterations is {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** The default maximum number of consecutive iterations in which no
     *  weight changes before learning stops is {@value}. */
    public static final int DEFAULT_MAX_STEPS_WITHOUT_CHANGE = 10;

    /** The default maximum penalty is infinite, which means that it is
     *  hard-assignment. */
    public static final double DEFAULT_MAX_PENALTY = Double.POSITIVE_INFINITY;

    /** The default error tolerance is 0.001, which is what was recommended in
     *  the original Sequential Minimal Optimization paper. */
    public static final double DEFAULT_ERROR_TOLERANCE = 0.001;

    /** The default effective value for zero is {@value}. */
    public static final double DEFAULT_EFFECTIVE_ZERO = 1.0e-10;

    /** The kernel to use. */
    private Kernel<? super InputType> kernel;

    /** The upper bound (C) on the weight of any single example. Must be
     *  positive; may be infinite. */
    private double maxPenalty;

    /** The tolerance used when deciding whether an example's error is large
     *  enough to attempt an update. Cannot be negative. */
    private double errorTolerance;

    /** The number of consecutive iterations without any weight change after
     *  which the algorithm stops. Must be positive. */
    private int maxStepsWithoutChange;

    /** The threshold below which a difference is treated as zero. Cannot be
     *  negative. */
    private double effectiveZero;

    /** The random number generator used to pick the second example of each
     *  pair. */
    private Random random;

    /** The result categorizer. */
    private transient KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result;

    /** The valid (non-null input and output) examples, indexable by
     *  position. */
    private transient ArrayList<InputOutputPair<? extends InputType, Boolean>> dataList;

    /** The number of entries in {@link #dataList}. */
    private transient int dataSize;

    /** The number of items changed on the most recent iteration. */
    private transient int changeCount;

    /** The number of consecutive iterations that have ended with no
     *  changes. */
    private transient int stepsWithoutChange;

    /** The mapping of example indices to non-zero weighted examples
     *  (support vectors). */
    private transient LinkedHashMap<Integer, DefaultWeightedValue<InputType>> supportsMap;

    /**
     * Creates a new {@code SimplifiedSequentialMinimalOptimization} with a
     * null kernel, the default parameter values, and a new {@link Random}.
     * A kernel must be set before learning.
     */
    public SimplifiedSequentialMinimalOptimization()
    {
        this(null, DEFAULT_MAX_PENALTY, DEFAULT_ERROR_TOLERANCE,
            DEFAULT_MAX_STEPS_WITHOUT_CHANGE, DEFAULT_EFFECTIVE_ZERO,
            DEFAULT_MAX_ITERATIONS, new Random());
    }

    /**
     * Creates a new {@code SimplifiedSequentialMinimalOptimization} with the
     * given parameters.
     *
     * @param   kernel
     *      The kernel to use.
     * @param   maxPenalty
     *      The maximum penalty (C). Must be positive.
     * @param   errorTolerance
     *      The error tolerance. Cannot be negative.
     * @param   maxStepsWithoutChange
     *      The number of consecutive iterations without change that ends
     *      learning. Must be positive.
     * @param   effectiveZero
     *      The effective value for zero. Cannot be negative.
     * @param   maxIterations
     *      The maximum number of iterations, passed to the super-class.
     * @param   random
     *      The random number generator.
     * @throws  IllegalArgumentException
     *      If one of the numeric parameters is rejected by its setter.
     */
    public SimplifiedSequentialMinimalOptimization(
        Kernel<? super InputType> kernel,
        final double maxPenalty,
        double errorTolerance,
        int maxStepsWithoutChange,
        double effectiveZero,
        final int maxIterations,
        Random random)
    {
        super(maxIterations);

        this.setKernel(kernel);
        this.setMaxPenalty(maxPenalty);
        this.setErrorTolerance(errorTolerance);
        this.setMaxStepsWithoutChange(maxStepsWithoutChange);
        this.setEffectiveZero(effectiveZero);
        this.setRandom(random);
    }

    /**
     * Prepares for learning: copies the valid examples into an indexable
     * list, verifies that both categories are present, and creates an empty
     * result categorizer with zero bias.
     *
     * @return
     *      True if learning can proceed; false if there is no data or no
     *      valid example.
     * @throws  IllegalArgumentException
     *      If every valid example has the same output category.
     */
    @Override
    protected boolean initializeAlgorithm()
    {
        this.result = null;

        if (this.getData() == null)
        {
            // Error: No data to learn on.
            return false;
        }

        // Keep only examples that have both an input and an output, and
        // count how many of them are positive.
        this.dataList = new ArrayList<InputOutputPair<? extends InputType, Boolean>>(
            this.getData().size());
        int positives = 0;
        for (InputOutputPair<? extends InputType, Boolean> example : this.getData())
        {
            if (example != null && example.getInput() != null && example.getOutput() != null)
            {
                this.dataList.add(example);

                if (example.getOutput())
                {
                    positives++;
                }
            }
        }

        this.dataSize = this.dataList.size();
        if (this.dataSize <= 0)
        {
            // Error: No valid data to learn from.
            this.dataList = null;
            return false;
        }
        else if (positives <= 0 || positives >= this.dataSize)
        {
            throw new IllegalArgumentException("Data is all one category");
        }

        // Use the filtered size so the counter agrees with the indices the
        // rest of the algorithm works with.
        this.changeCount = this.dataSize;
        this.stepsWithoutChange = 0;

        // An empty map means every alpha starts at zero. The categorizer is
        // given the map's values() view; the algorithm relies on the
        // categorizer reading through that view so that later weight updates
        // are reflected in its output. The threshold (bias) starts at zero.
        this.supportsMap = new LinkedHashMap<Integer, DefaultWeightedValue<InputType>>();
        this.result = new KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>(
            this.kernel, this.supportsMap.values(), 0.0);

        return true;
    }

    /**
     * Performs one pass over all of the examples. For each example whose
     * error and weight fail the tolerance check, a second example is chosen
     * uniformly at random from the others and a joint update is attempted.
     *
     * @return
     *      True while the number of consecutive passes without any change is
     *      below {@code maxStepsWithoutChange}; false to stop.
     */
    @Override
    protected boolean step()
    {
        final double tol = this.errorTolerance;
        final double C = this.maxPenalty;

        this.changeCount = 0;
        for (int i = 0; i < this.dataSize; i++)
        {
            final double yI = this.getTarget(i);
            final double eI = this.getSVMOutput(i) - yI;
            final double alphaI = this.getAlpha(i);
            final double yITimesEI = yI * eI;

            // Only try to update examples that violate the optimality (KKT)
            // check of the referenced algorithm by more than the tolerance
            // and whose weight still has room to move in the needed
            // direction.
            if (   ((yITimesEI < -tol) && (alphaI < C))
                || ((yITimesEI > +tol) && (alphaI > 0)))
            {
                // Select a random j != i: draw from the dataSize - 1 other
                // indices and skip over i. Initialization guarantees at
                // least two examples, so the bound is positive.
                int j = this.random.nextInt(this.dataSize - 1);
                if (j >= i)
                {
                    j += 1;
                }

                if (this.takeStep(i, j))
                {
                    this.changeCount++;
                }
            }
        }

        if (this.changeCount <= 0)
        {
            this.stepsWithoutChange++;
        }
        else
        {
            this.stepsWithoutChange = 0;
        }

        return this.stepsWithoutChange < this.maxStepsWithoutChange;
    }

    /**
     * Attempts to jointly update the weights of examples i and j, and the
     * bias, following the update equations of the referenced simplified SMO
     * algorithm.
     *
     * @param   i
     *      The index of the first example.
     * @param   j
     *      The index of the second example.
     * @return
     *      True if the weights and bias were changed; otherwise, false.
     */
    private boolean takeStep(
        final int i,
        final int j)
    {
        if (i == j)
        {
            // This is a sanity check. It cannot take a step if the two
            // examples are exactly the same.
            return false;
        }

        final double C = this.maxPenalty;
        final double epsilon = this.effectiveZero;
        final double CMinusEpsilon = C - epsilon;

        final double yI = this.getTarget(i);
        final double eI = this.getSVMOutput(i) - yI;
        final double oldAlphaI = this.getAlpha(i);

        final double yJ = this.getTarget(j);
        final double eJ = this.getSVMOutput(j) - yJ;
        final double oldAlphaJ = this.getAlpha(j);

        // Compute the lower and upper bounds to solve for new values of
        // alphaI and alphaJ.
        final double lowerBound;
        final double upperBound;
        if (yI != yJ)
        {
            final double alphaJMinusAlphaI = oldAlphaJ - oldAlphaI;
            lowerBound = Math.max(0, alphaJMinusAlphaI);
            upperBound = Math.min(C, alphaJMinusAlphaI + C);
        }
        else
        {
            final double alphaIPlusAlphaJ = oldAlphaI + oldAlphaJ;
            lowerBound = Math.max(0, alphaIPlusAlphaJ - C);
            upperBound = Math.min(C, alphaIPlusAlphaJ);
        }

        if (lowerBound >= upperBound)
        {
            // There is no room to move alphaJ.
            return false;
        }

        // Evaluate the kernels between the values, using the property that by
        // kernel symmetry: k(i,j) == k(j,i)
        final double kII = this.evaluateKernel(i, i);
        final double kIJ = this.evaluateKernel(i, j);
        final double kJI = kIJ;
        final double kJJ = this.evaluateKernel(j, j);

        final double eta = kIJ + kJI - kII - kJJ;
        if (eta >= 0.0)
        {
            // The update below divides by eta and is only applied when eta
            // is strictly negative.
            return false;
        }

        // Compute the unconstrained new value and clip it to the bounds.
        double newAlphaJ = oldAlphaJ - (yJ * (eI - eJ)) / eta;
        if (newAlphaJ <= lowerBound)
        {
            newAlphaJ = lowerBound;
        }
        else if (newAlphaJ >= upperBound)
        {
            newAlphaJ = upperBound;
        }

        // If the new alpha is close enough to 0.0 or the maximum alpha, just
        // set it to that value.
        if (newAlphaJ < epsilon)
        {
            newAlphaJ = 0.0;
        }
        else if (newAlphaJ > CMinusEpsilon)
        {
            newAlphaJ = C;
        }

        if (Math.abs(newAlphaJ - oldAlphaJ) < epsilon)
        {
            // The change is too small to count as progress.
            return false;
        }

        // Move alphaI by the same amount in the direction determined by the
        // two labels.
        double newAlphaI = oldAlphaI + yI * yJ * (oldAlphaJ - newAlphaJ);

        // If the new alpha is close enough to 0.0 or the maximum alpha, just
        // set it to that value.
        if (newAlphaI < epsilon)
        {
            newAlphaI = 0.0;
        }
        else if (newAlphaI > CMinusEpsilon)
        {
            newAlphaI = C;
        }

        // Compute the two candidate biases, one derived from each example.
        final double oldBias = this.getBias();
        final double b1 = oldBias - eI
            - yI * (newAlphaI - oldAlphaI) * kII
            - yJ * (newAlphaJ - oldAlphaJ) * kIJ;
        final double b2 = oldBias - eJ
            - yI * (newAlphaI - oldAlphaI) * kJI
            - yJ * (newAlphaJ - oldAlphaJ) * kJJ;

        // Prefer the candidate whose alpha lies strictly inside (0, C);
        // if neither does, use the midpoint of the two.
        final double newBias;
        if (newAlphaI > epsilon && newAlphaI < CMinusEpsilon)
        {
            newBias = b1;
        }
        else if (newAlphaJ > epsilon && newAlphaJ < CMinusEpsilon)
        {
            newBias = b2;
        }
        else
        {
            newBias = (b1 + b2) / 2.0;
        }

        this.setAlpha(i, newAlphaI);
        this.setAlpha(j, newAlphaJ);
        this.setBias(newBias);

        return true;
    }

    /**
     * Releases the working references to the data list and the support
     * map. The result categorizer is left untouched.
     */
    @Override
    protected void cleanupAlgorithm()
    {
        this.dataList = null;
        this.supportsMap = null;
    }

    /**
     * Evaluates the kernel between two examples.
     *
     * @param   i
     *      The index of the first example.
     * @param   j
     *      The index of the second example.
     * @return
     *      The kernel value for the inputs of examples i and j.
     */
    private double evaluateKernel(
        final int i,
        final int j)
    {
        return this.kernel.evaluate(this.getPoint(i), this.getPoint(j));
    }

    /**
     * Gets the real-valued output of the current result for an input.
     *
     * @param   input
     *      The input to evaluate.
     * @return
     *      The output of the current categorizer as a double.
     */
    private double getSVMOutput(
        final InputType input)
    {
        return this.result.evaluateAsDouble(input);
    }

    /**
     * Gets the real-valued output of the current result for an example.
     *
     * @param   i
     *      The index of the example.
     * @return
     *      The output of the current categorizer for the i-th input.
     */
    private double getSVMOutput(
        final int i)
    {
        return this.getSVMOutput(this.getPoint(i));
    }

    /**
     * Gets the input of an example.
     *
     * @param   i
     *      The index of the example.
     * @return
     *      The input of the i-th example.
     */
    private InputType getPoint(
        final int i)
    {
        return this.dataList.get(i).getInput();
    }

    /**
     * Gets the target label of an example as a number.
     *
     * @param   i
     *      The index of the example.
     * @return
     *      +1.0 if the i-th example's output is true; otherwise, -1.0.
     */
    private double getTarget(
        final int i)
    {
        return this.dataList.get(i).getOutput() ? +1.0 : -1.0;
    }

    /**
     * Gets the alpha (non-negative weight magnitude) of an example.
     *
     * @param   i
     *      The index of the example.
     * @return
     *      The alpha of the i-th example, which is 0.0 when it is not in the
     *      support map.
     */
    private double getAlpha(
        final int i)
    {
        final DefaultWeightedValue<InputType> support = this.supportsMap.get(i);
        if (support == null)
        {
            return 0.0;
        }
        else
        {
            // The weight is the label (+1 or -1) times alpha. Alpha is always
            // greater than zero, so we just take the absolute value of the
            // weight to get it.
            return Math.abs(support.getWeight());
        }
    }

    /**
     * Sets the alpha of an example, adding it to or removing it from the
     * support map as needed.
     *
     * @param   i
     *      The index of the example.
     * @param   alpha
     *      The new alpha. Exactly 0.0 removes the example from the support
     *      map.
     */
    private void setAlpha(
        int i,
        double alpha)
    {
        if (alpha == 0.0)
        {
            this.supportsMap.remove(i);
        }
        else
        {
            // The weight is the label times alpha.
            final double weight = this.getTarget(i) * alpha;
            DefaultWeightedValue<InputType> support = this.supportsMap.get(i);
            if (support == null)
            {
                support = new DefaultWeightedValue<InputType>(
                    this.getPoint(i), weight);
                this.supportsMap.put(i, support);
            }
            else
            {
                support.setWeight(weight);
            }
        }
    }

    /**
     * Gets the bias of the current result.
     *
     * @return
     *      The current bias.
     */
    private double getBias()
    {
        return this.result.getBias();
    }

    /**
     * Sets the bias of the current result.
     *
     * @param   b
     *      The new bias.
     */
    private void setBias(
        final double b)
    {
        this.result.setBias(b);
    }

    /**
     * Gets the categorizer being learned.
     *
     * @return
     *      The current result, which is null until initialization succeeds.
     */
    public KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> getResult()
    {
        return this.result;
    }

    /**
     * Gets the kernel.
     *
     * @return
     *      The kernel to use.
     */
    public Kernel<? super InputType> getKernel()
    {
        return this.kernel;
    }

    /**
     * Sets the kernel.
     *
     * @param   kernel
     *      The kernel to use.
     */
    public void setKernel(
        final Kernel<? super InputType> kernel)
    {
        this.kernel = kernel;
    }

    /**
     * Gets the maximum penalty (C).
     *
     * @return
     *      The maximum penalty.
     */
    public double getMaxPenalty()
    {
        return this.maxPenalty;
    }

    /**
     * Sets the maximum penalty (C).
     *
     * @param   maxPenalty
     *      The maximum penalty. Must be positive.
     * @throws  IllegalArgumentException
     *      If maxPenalty is less than or equal to zero.
     */
    public void setMaxPenalty(
        final double maxPenalty)
    {
        if (maxPenalty <= 0.0)
        {
            throw new IllegalArgumentException("maxPenalty must be positive.");
        }

        this.maxPenalty = maxPenalty;
    }

    /**
     * Gets the error tolerance.
     *
     * @return
     *      The error tolerance.
     */
    public double getErrorTolerance()
    {
        return this.errorTolerance;
    }

    /**
     * Sets the error tolerance.
     *
     * @param   errorTolerance
     *      The error tolerance. Cannot be negative.
     * @throws  IllegalArgumentException
     *      If errorTolerance is negative.
     */
    public void setErrorTolerance(
        final double errorTolerance)
    {
        if (errorTolerance < 0.0)
        {
            throw new IllegalArgumentException(
                "errorTolerance cannot be negative.");
        }

        this.errorTolerance = errorTolerance;
    }

    /**
     * Gets the number of consecutive iterations without change after which
     * learning stops.
     *
     * @return
     *      The maximum number of steps without change.
     */
    public int getMaxStepsWithoutChange()
    {
        return this.maxStepsWithoutChange;
    }

    /**
     * Sets the number of consecutive iterations without change after which
     * learning stops.
     *
     * @param   maxStepsWithoutChange
     *      The maximum number of steps without change. Must be positive.
     * @throws  IllegalArgumentException
     *      If maxStepsWithoutChange is less than or equal to zero.
     */
    public void setMaxStepsWithoutChange(
        final int maxStepsWithoutChange)
    {
        if (maxStepsWithoutChange <= 0)
        {
            throw new IllegalArgumentException(
                "maxStepsWithoutChange must be positive");
        }

        this.maxStepsWithoutChange = maxStepsWithoutChange;
    }

    /**
     * Gets the effective value for zero.
     *
     * @return
     *      The effective zero.
     */
    public double getEffectiveZero()
    {
        return this.effectiveZero;
    }

    /**
     * Sets the effective value for zero.
     *
     * @param   effectiveZero
     *      The effective zero. Cannot be negative.
     * @throws  IllegalArgumentException
     *      If effectiveZero is negative.
     */
    public void setEffectiveZero(
        final double effectiveZero)
    {
        if (effectiveZero < 0.0)
        {
            throw new IllegalArgumentException(
                "effectiveZero cannot be negative.");
        }

        this.effectiveZero = effectiveZero;
    }

    /**
     * Gets the random number generator.
     *
     * @return
     *      The random number generator.
     */
    public Random getRandom()
    {
        return this.random;
    }

    /**
     * Sets the random number generator.
     *
     * @param   random
     *      The random number generator.
     */
    public void setRandom(
        final Random random)
    {
        this.random = random;
    }

}