/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.classification;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DoubleArrayList;
import smile.math.Math;
import smile.math.SparseArray;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import smile.util.MulticoreExecutor;
/**
* Support vector machines for classification. The basic support vector machine
* is a binary linear classifier which chooses the hyperplane that represents
* the largest separation, or margin, between the two classes. If such a
* hyperplane exists, it is known as the maximum-margin hyperplane and the
* linear classifier it defines is known as a maximum margin classifier.
* <p>
* If there exists no hyperplane that can perfectly split the positive and
* negative instances, the soft margin method will choose a hyperplane
* that splits the instances as cleanly as possible, while still maximizing
* the distance to the nearest cleanly split instances.
* <p>
* The nonlinear SVMs are created by applying the kernel trick to
* maximum-margin hyperplanes. The resulting algorithm is formally similar,
* except that every dot product is replaced by a nonlinear kernel function.
* This allows the algorithm to fit the maximum-margin hyperplane in a
* transformed feature space. The transformation may be nonlinear and
 * the transformed space may be high dimensional. For example, the feature
 * space corresponding to the Gaussian kernel is a Hilbert space of infinite
 * dimension.
* Thus though the classifier is a hyperplane in the high-dimensional feature
* space, it may be nonlinear in the original input space. Maximum margin
* classifiers are well regularized, so the infinite dimension does not spoil
* the results.
* <p>
* The effectiveness of SVM depends on the selection of kernel, the kernel's
 * parameters, and the soft margin parameter C. Given a kernel, the best
 * combination of C and the kernel's parameters is often selected by a
 * grid search with cross validation.
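 * <p>
 * For example, a minimal sketch of such a grid search (the candidate values
 * and the data splits {@code trainx}, {@code trainy} and a held-out set are
 * illustrative assumptions; scoring each point by cross validation is
 * preferable in practice):
 * <pre>{@code
 * for (double sigma : new double[] {0.5, 1.0, 2.0, 4.0, 8.0}) {
 *     for (double C : new double[] {1.0, 5.0, 10.0, 50.0}) {
 *         SVM<double[]> svm = new SVM<>(new GaussianKernel(sigma), C);
 *         svm.learn(trainx, trainy);
 *         svm.finish();
 *         // score svm.predict() on the held-out split and keep the best (sigma, C)
 *     }
 * }
 * }</pre>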
* <p>
* The dominant approach for creating multi-class SVMs is to reduce the
* single multi-class problem into multiple binary classification problems.
 * Common methods for such reduction are to build binary classifiers that
 * distinguish (i) between one of the labels and the rest (one-versus-all)
 * or (ii) between every pair of classes (one-versus-one). Classification
 * of new instances in the one-versus-all case is done by a winner-takes-all
* strategy, in which the classifier with the highest output function assigns
* the class. For the one-versus-one approach, classification
 * is done by a max-wins voting strategy, in which every classifier assigns
 * the instance to one of its two classes, the vote count for the assigned
 * class is increased by one, and finally the class with the most votes
 * determines the instance classification.
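 * <p>
 * A minimal usage sketch for the binary case, assuming training data
 * {@code double[][] x} with labels {@code int[] y} in {0, 1}, and a
 * Gaussian kernel width chosen purely for illustration:
 * <pre>{@code
 * SVM<double[]> svm = new SVM<>(new GaussianKernel(8.0), 5.0);
 * svm.learn(x, y);
 * svm.finish();
 * int label = svm.predict(x[0]);
 * }</pre>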
*
* <h2>References</h2>
* <ol>
* <li> Christopher J. C. Burges. A Tutorial on Support Vector Machines for Pattern Recognition. Data Mining and Knowledge Discovery 2:121-167, 1998.</li>
* <li> John Platt. Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines.</li>
 * <li> Rong-En Fan, Pai-Hsuen Chen, and Chih-Jen Lin. Working Set Selection Using Second Order Information for Training Support Vector Machines. JMLR, 6:1889-1918, 2005.</li>
* <li> Antoine Bordes, Seyda Ertekin, Jason Weston and Leon Bottou. Fast Kernel Classifiers with Online and Active Learning, Journal of Machine Learning Research, 6:1579-1619, 2005.</li>
* <li> Tobias Glasmachers and Christian Igel. Second Order SMO Improves SVM Online and Active Learning.</li>
* <li> Chih-Chung Chang and Chih-Jen Lin. LIBSVM: a Library for Support Vector Machines.</li>
* </ol>
*
* @param <T> the type of input object.
*
* @author Haifeng Li
*/
public class SVM <T> implements OnlineClassifier<T>, SoftClassifier<T>, Serializable {
private static final long serialVersionUID = 1L;
private static final Logger logger = LoggerFactory.getLogger(SVM.class);
/**
* The type of multi-class SVMs.
*/
public enum Multiclass {
/**
* One vs one classification.
*/
ONE_VS_ONE,
/**
* One vs all classification.
*/
ONE_VS_ALL,
};
/**
 * The default value for the curvature k(t,t) + k(s,s) - 2 * k(t,s) if the kernel is not positive definite.
*/
private static final double TAU = 1E-12;
/**
* Learned two-class support vector machine.
*/
private LASVM svm;
/**
* Learned multi-class support vector machines.
*/
private List<LASVM> svms;
/**
* The kernel function.
*/
private MercerKernel<T> kernel;
/**
* The dimensionality of instances. Useful for sparse arrays.
*/
private int p;
/**
 * The number of classes.
*/
private int k;
/**
* The strategy for multi-class classification.
*/
private Multiclass strategy = Multiclass.ONE_VS_ONE;
/**
* The class weight.
*/
private double[] wi;
/**
* The tolerance of convergence test.
*/
private double tol = 1E-3;
/**
* Trainer for support vector machines.
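 * <p>
 * A minimal sketch of the fluent trainer API ({@code x}, {@code y} and the
 * Gaussian kernel width are assumed to be supplied by the caller):
 * <pre>{@code
 * SVM<double[]> model = new SVM.Trainer<double[]>(new GaussianKernel(8.0), 5.0)
 *     .setTolerance(1E-3)
 *     .setNumEpochs(2)
 *     .train(x, y);
 * }</pre>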
*/
public static class Trainer<T> extends ClassifierTrainer<T> {
/**
* The kernel function.
*/
private MercerKernel<T> kernel;
/**
 * The number of classes.
*/
private int k;
/**
* The class weight. Must be positive. Sets the parameter C of class i
* to weight[i] * C.
*/
private double[] weight;
/**
* The soft margin penalty parameter for positive samples.
*/
private double Cp = 1.0;
/**
* The soft margin penalty parameter for negative samples.
*/
private double Cn = 1.0;
/**
* The strategy for multi-class classification.
*/
private Multiclass strategy = Multiclass.ONE_VS_ONE;
/**
* The tolerance of convergence test.
*/
private double tol = 1E-3;
/**
* The number of epochs of stochastic learning.
*/
private int epochs = 2;
/**
* Constructor of trainer for binary SVMs.
* @param kernel the kernel function.
* @param C the soft margin penalty parameter.
*/
public Trainer(MercerKernel<T> kernel, double C) {
if (C < 0) {
throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
}
this.kernel = kernel;
this.Cp = C;
this.Cn = C;
this.k = 2;
}
/**
* Constructor of trainer for binary SVMs.
* @param kernel the kernel function.
* @param Cp the soft margin penalty parameter for positive instances.
* @param Cn the soft margin penalty parameter for negative instances.
*/
public Trainer(MercerKernel<T> kernel, double Cp, double Cn) {
if (Cp < 0) {
throw new IllegalArgumentException("Invalid postive instance soft margin penalty: " + Cp);
}
if (Cn < 0) {
throw new IllegalArgumentException("Invalid negative instance soft margin penalty: " + Cn);
}
this.kernel = kernel;
this.Cp = Cp;
this.Cn = Cn;
this.k = 2;
}
/**
* Constructor of trainer for multi-class SVMs.
* @param kernel the kernel function.
* @param C the soft margin penalty parameter.
 * @param k the number of classes.
 * @param strategy the strategy for multi-class classification.
*/
public Trainer(MercerKernel<T> kernel, double C, int k, Multiclass strategy) {
if (C < 0) {
throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
}
if (k < 3) {
throw new IllegalArgumentException("Invalid number of classes: " + k);
}
this.kernel = kernel;
this.Cp = C;
this.Cn = C;
this.k = k;
this.strategy = strategy;
}
/**
* Constructor of trainer for multi-class SVMs.
* @param kernel the kernel function.
* @param C the soft margin penalty parameter.
* @param weight class weight. Must be positive. The soft margin penalty
 * of class i will be weight[i] * C.
 * @param strategy the strategy for multi-class classification.
*/
public Trainer(MercerKernel<T> kernel, double C, double[] weight, Multiclass strategy) {
if (C < 0) {
throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
}
if (weight.length < 3) {
throw new IllegalArgumentException("Invalid number of classes: " + weight.length);
}
this.kernel = kernel;
this.Cp = C;
this.Cn = C;
this.k = weight.length;
this.weight = weight;
this.strategy = strategy;
}
/**
* Sets the tolerance of convergence test.
*
* @param tol the tolerance of convergence test.
*/
public Trainer<T> setTolerance(double tol) {
if (tol <= 0.0) {
throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
}
this.tol = tol;
return this;
}
/**
* Sets the number of epochs of stochastic learning.
* @param epochs the number of epochs of stochastic learning.
*/
public Trainer<T> setNumEpochs(int epochs) {
if (epochs < 1) {
throw new IllegalArgumentException("Invalid numer of epochs of stochastic learning:" + epochs);
}
this.epochs = epochs;
return this;
}
@Override
public SVM<T> train(T[] x, int[] y) {
return train(x, y, null);
}
/**
* Learns a SVM classifier with given training data.
* @param x training instances.
* @param y training labels in [0, k), where k is the number of classes.
* @param weight instance weight. Must be positive. The soft margin penalty
* for instance i will be weight[i] * C.
* @return trained SVM classifier
*/
public SVM<T> train(T[] x, int[] y, double[] weight) {
SVM<T> svm = null;
if (k == 2) {
svm = new SVM<>(kernel, Cp, Cn);
} else {
if (this.weight == null) {
svm = new SVM<>(kernel, Cp, k, strategy);
} else {
svm = new SVM<>(kernel, Cp, this.weight, strategy);
}
}
svm.setTolerance(tol);
for (int i = 1; i <= epochs; i++) {
svm.learn(x, y, weight);
}
svm.finish();
return svm;
}
}
/**
 * Online two-class SVM based on the LASVM algorithm of Bordes et al. (2005).
*/
final class LASVM implements Serializable {
private static final long serialVersionUID = 1L;
/**
* Support vector.
*/
class SupportVector implements Serializable {
private static final long serialVersionUID = 1L;
/**
* Support vector.
*/
T x;
/**
* Support vector label.
*/
int y;
/**
* Lagrangian multiplier of support vector.
*/
double alpha;
/**
* Gradient y - Kα.
*/
double g;
/**
* Lower bound of alpha.
*/
double cmin;
/**
* Upper bound of alpha.
*/
double cmax;
/**
* Kernel value k(x, x)
*/
double k;
/**
* Kernel value cache.
*/
DoubleArrayList kcache;
}
/**
* The soft margin penalty parameter for positive samples.
*/
private double Cp = 1.0;
/**
* The soft margin penalty parameter for negative samples.
*/
private double Cn = 1.0;
/**
* Support vectors.
*/
List<SupportVector> sv = new ArrayList<>();
/**
* Weight vector for linear SVM.
*/
double[] w;
/**
* Threshold of decision function.
*/
double b = 0.0;
/**
* The number of support vectors.
*/
int nsv = 0;
/**
* The number of bounded support vectors.
*/
int nbsv = 0;
/**
* Platt Scaling for estimating posterior probabilities.
*/
PlattScaling platt;
/**
 * True if minmax() has been called after the most recent update.
*/
transient boolean minmaxflag = false;
/**
* Most violating pair.
 * svmin = argmin g_i over i with alpha_i > cmin_i,
 * svmax = argmax g_i over i with alpha_i < cmax_i,
 * where cmin_i = min{0, y_i * C} and cmax_i = max{0, y_i * C}.
*/
transient SupportVector svmin = null;
transient SupportVector svmax = null;
transient double gmin = Double.MAX_VALUE;
transient double gmax = -Double.MAX_VALUE;
/**
* Constructor.
* @param Cp the soft margin penalty parameter for positive instances.
* @param Cn the soft margin penalty parameter for negative instances.
*/
LASVM(double Cp, double Cn) {
this.Cp = Cp;
this.Cn = Cn;
}
/**
* Trains the SVM with the given dataset for one epoch. The caller may
* call this method multiple times to obtain better accuracy although
 * one epoch is usually sufficient. After calling this method a sufficient
 * number of times (usually 1 or 2), the user should call {@link #finish()}
* to further process support vectors.
*/
void learn(T[] x, int[] y) {
learn(x, y, null);
}
/**
* Trains the SVM with the given dataset for one epoch. The caller may
* call this method multiple times to obtain better accuracy although
 * one epoch is usually sufficient. After calling this method a sufficient
 * number of times (usually 1 or 2), the user should call {@link #finish()}
* to further process support vectors.
*/
void learn(T[] x, int[] y, double[] weight) {
if (p == 0 && kernel instanceof LinearKernel) {
if (x instanceof double[][]) {
double[] x0 = (double[]) x[0];
p = x0.length;
} else if (x instanceof float[][]) {
float[] x0 = (float[]) x[0];
p = x0.length;
} else {
throw new UnsupportedOperationException("Unsupported data type for linear kernel.");
}
}
int c1 = 0, c2 = 0;
for (SupportVector v : sv) {
if (v != null) {
if (v.y > 0) c1++;
else if (v.y < 0) c2++;
}
}
// If the SVM is empty or has very few support vectors, use some
// instances as initial support vectors.
final int n = x.length;
if (c1 < 5 || c2 < 5) {
for (int i = 0; i < n; i++) {
if (y[i] == 1 && c1 < 5) {
if (weight == null) {
process(x[i], y[i]);
} else {
process(x[i], y[i], weight[i]);
}
c1++;
}
if (y[i] == -1 && c2 < 5) {
if (weight == null) {
process(x[i], y[i]);
} else {
process(x[i], y[i], weight[i]);
}
c2++;
}
if (c1 >= 5 && c2 >= 5) {
break;
}
}
}
// train SVM in a stochastic order.
int[] index = Math.permutate(n);
for (int i = 0; i < n; i++) {
if (weight == null) {
process(x[index[i]], y[index[i]]);
} else {
process(x[index[i]], y[index[i]], weight[index[i]]);
}
do {
reprocess(tol); // at least one call to reprocess
minmax();
} while (gmax - gmin > 1000);
}
}
/**
* Returns the function value after training.
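 * The decision function is f(x) = b + sum_i alpha_i * k(x_i, x) over the
 * support vectors; for a linear kernel the precomputed weight vector w is
 * used instead.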
*/
double predict(T x) {
double f = b;
if (kernel instanceof LinearKernel && w != null) {
if (x instanceof double[]) {
f += Math.dot(w, (double[]) x);
} else if (x instanceof SparseArray) {
for (SparseArray.Entry e : (SparseArray) x) {
f += w[e.i] * e.x;
}
} else {
throw new UnsupportedOperationException("Unsupported data type for linear kernel");
}
} else {
for (SupportVector v : sv) {
if (v != null) {
f += v.alpha * kernel.k(v.x, x);
}
}
}
return f;
}
/**
 * Finds the support vectors with the smallest (over I_down) and largest (over I_up) gradients.
*/
void minmax() {
if (!minmaxflag) {
gmin = Double.MAX_VALUE;
gmax = -Double.MAX_VALUE;
for (SupportVector v : sv) {
if (v != null) {
double gi = v.g;
double ai = v.alpha;
if (gi < gmin && ai > v.cmin) {
svmin = v;
gmin = gi;
}
if (gi > gmax && ai < v.cmax) {
svmax = v;
gmax = gi;
}
}
}
minmaxflag = true;
}
}
/**
* Sequential minimal optimization.
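 * <p>
 * With second order working set selection, the partner of a fixed vector t
 * is the vector s maximizing the gain Z * mu = Z^2 / curv, where Z is the
 * difference of the gradients and curv = k(t,t) + k(s,s) - 2 * k(t,s) is
 * the curvature along the pair (replaced by TAU when not positive).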
* @param v1 the first vector of working set.
* @param v2 the second vector of working set.
* @param epsgr the tolerance of convergence test.
*/
boolean smo(SupportVector v1, SupportVector v2, double epsgr) {
// Second order working set selection
// Determine coordinate to process
if (v1 == null || v2 == null) {
if (v1 == null && v2 == null) {
minmax();
if (gmax > -gmin) {
v2 = svmax;
} else {
v1 = svmin;
}
}
if (v2 == null) {
if (v1.kcache == null) {
v1.kcache = new DoubleArrayList(sv.size());
for (SupportVector v : sv) {
if (v != null) {
v1.kcache.add(kernel.k(v1.x, v.x));
} else {
v1.kcache.add(0.0);
}
}
}
// determine imax
double km = v1.k;
double gm = v1.g;
double best = 0.0;
for (int i = 0; i < sv.size(); i++) {
SupportVector v = sv.get(i);
if (v == null) {
continue;
}
double Z = v.g - gm;
double k = v1.kcache.get(i);
double curv = km + v.k - 2.0 * k;
// double curv = 2.0 - 2.0 * k; // for Gaussian kernel only
if (curv <= 0.0) curv = TAU;
double mu = Z / curv;
if ((mu > 0.0 && v.alpha < v.cmax) || (mu < 0.0 && v.alpha > v.cmin)) {
double gain = Z * mu;
if (gain > best) {
best = gain;
v2 = v;
}
}
}
} else {
if (v2.kcache == null) {
v2.kcache = new DoubleArrayList(sv.size());
for (SupportVector v : sv) {
if (v != null) {
v2.kcache.add(kernel.k(v2.x, v.x));
} else {
v2.kcache.add(0.0);
}
}
}
// determine imin
double km = v2.k;
double gm = v2.g;
double best = 0.0;
for (int i = 0; i < sv.size(); i++) {
SupportVector v = sv.get(i);
if (v == null) {
continue;
}
double Z = gm - v.g;
double k = v2.kcache.get(i);
double curv = km + v.k - 2.0 * k;
// double curv = 2.0 - 2.0 * k; // for Gaussian kernel only
if (curv <= 0.0) curv = TAU;
double mu = Z / curv;
if ((mu > 0.0 && v.alpha > v.cmin) || (mu < 0.0 && v.alpha < v.cmax)) {
double gain = Z * mu;
if (gain > best) {
best = gain;
v1 = v;
}
}
}
}
}
if (v1 == null || v2 == null) {
return false;
}
if (v1.kcache == null) {
v1.kcache = new DoubleArrayList(sv.size());
for (SupportVector v : sv) {
if (v != null) {
v1.kcache.add(kernel.k(v1.x, v.x));
} else {
v1.kcache.add(0.0);
}
}
}
if (v2.kcache == null) {
v2.kcache = new DoubleArrayList(sv.size());
for (SupportVector v : sv) {
if (v != null) {
v2.kcache.add(kernel.k(v2.x, v.x));
} else {
v2.kcache.add(0.0);
}
}
}
// Determine curvature
double curv = v1.k + v2.k - 2 * kernel.k(v1.x, v2.x);
if (curv <= 0.0) curv = TAU;
double step = (v2.g - v1.g) / curv;
// Determine maximal step
if (step >= 0.0) {
double ostep = v1.alpha - v1.cmin;
if (ostep < step) {
step = ostep;
}
ostep = v2.cmax - v2.alpha;
if (ostep < step) {
step = ostep;
}
} else {
double ostep = v2.cmin - v2.alpha;
if (ostep > step) {
step = ostep;
}
ostep = v1.alpha - v1.cmax;
if (ostep > step) {
step = ostep;
}
}
// Perform update
v1.alpha -= step;
v2.alpha += step;
for (int i = 0; i < sv.size(); i++) {
SupportVector v = sv.get(i);
if (v != null) {
v.g -= step * (v2.kcache.get(i) - v1.kcache.get(i));
}
}
minmaxflag = false;
// optimality test
minmax();
b = (gmax + gmin) / 2;
if (gmax - gmin < epsgr) {
return false;
}
return true;
}
/**
* Process a new sample.
*/
boolean process(T x, int y) {
return process(x, y, 1.0);
}
/**
* Process a new sample.
*/
boolean process(T x, int y, double weight) {
if (y != +1 && y != -1) {
throw new IllegalArgumentException("Invalid label: " + y);
}
if (weight <= 0.0) {
throw new IllegalArgumentException("Invalid instance weight: " + weight);
}
// Compute gradient
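// g = y - sum_j alpha_j * k(x_j, x), accumulated below over the current support vectors.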
double g = y;
DoubleArrayList kcache = new DoubleArrayList(sv.size() + 1);
if (!sv.isEmpty()) {
for (SupportVector v : sv) {
if (v != null) {
// Bail out if the instance is already in the expansion.
if (v.x == x) {
return true;
}
double k = kernel.k(v.x, x);
g -= v.alpha * k;
kcache.add(k);
} else {
kcache.add(0.0);
}
}
// Decide insertion
minmax();
if (gmin < gmax) {
if ((y > 0 && g < gmin) || (y < 0 && g > gmax)) {
return false;
}
}
}
// Insert
SupportVector v = new SupportVector();
v.x = x;
v.y = y;
v.alpha = 0.0;
v.g = g;
v.k = kernel.k(x, x);
v.kcache = kcache;
if (y > 0) {
v.cmin = 0;
v.cmax = weight * Cp;
} else {
v.cmin = -weight * Cn;
v.cmax = 0;
}
// Reuse an evicted (null) slot if one is available, so kernel cache indices stay aligned.
int i = 0;
for (; i < sv.size(); i++) {
if (sv.get(i) == null) {
sv.set(i, v);
kcache.set(i, v.k);
for (int j = 0; j < sv.size(); j++) {
SupportVector v1 = sv.get(j);
if (v1 != null && v1.kcache != null) {
v1.kcache.set(i, kcache.get(j));
}
}
break;
}
}
if (i >= sv.size()) {
for (int j = 0; j < sv.size(); j++) {
SupportVector v1 = sv.get(j);
if (v1 != null && v1.kcache != null) {
v1.kcache.add(kcache.get(j));
}
}
v.kcache.add(v.k);
sv.add(v);
}
// Process
if (y > 0) {
smo(null, v, 0.0);
} else {
smo(v, null, 0.0);
}
minmaxflag = false;
return true;
}
/**
* Reprocess support vectors.
* @param epsgr the tolerance of convergence test.
*/
boolean reprocess(double epsgr) {
boolean status = smo(null, null, epsgr);
evict();
return status;
}
/**
* Call reprocess until converge.
*/
void finish() {
finish(tol);
}
/**
 * Calls reprocess until convergence.
* @param epsgr the tolerance of convergence test.
*/
void finish(double epsgr) {
logger.info("SVM finializes the training by reprocess.");
for (int count = 1; smo(null, null, epsgr); count++) {
if (count % 1000 == 0) {
logger.info("finishing {} reprocess iterations.");
}
}
logger.info("SVM finished the reprocess.");
Iterator<SupportVector> iter = sv.iterator();
while (iter.hasNext()) {
SupportVector v = iter.next();
if (v == null) {
iter.remove();
} else if (v.alpha == 0) {
if ((v.g >= gmax && 0 >= v.cmax) || (v.g <= gmin && 0 <= v.cmin)) {
iter.remove();
}
}
}
cleanup();
if (kernel instanceof LinearKernel) {
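// With a linear kernel the decision function collapses to f(x) = w . x + b,
// so the weight vector w = sum_i alpha_i * x_i is precomputed from the support vectors.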
w = new double[p];
for (SupportVector v : sv) {
if (v.x instanceof double[]) {
double[] x = (double[]) v.x;
for (int i = 0; i < w.length; i++) {
w[i] += v.alpha * x[i];
}
} else if (v.x instanceof int[]) {
int[] x = (int[]) v.x;
for (int i = 0; i < x.length; i++) {
w[x[i]] += v.alpha;
}
} else if (v.x instanceof SparseArray) {
for (SparseArray.Entry e : (SparseArray) v.x) {
w[e.i] += v.alpha * e.x;
}
}
}
}
}
/**
 * After calling {@link #finish()}, the user should call this method
 * to train Platt scaling to estimate posterior probabilities.
*
* @param x training samples.
* @param y training labels.
*/
void trainPlattScaling(T[] x, int[] y) {
int l = y.length;
double[] scores = new double[l];
for (int i = 0; i < l; i++) {
scores[i] = predict(x[i]);
}
platt = new PlattScaling(scores, y);
}
void evict() {
minmax();
for (int i = 0; i < sv.size(); i++) {
SupportVector v = sv.get(i);
if (v != null && v.alpha == 0) {
if ((v.g >= gmax && 0 >= v.cmax) || (v.g <= gmin && 0 <= v.cmin)) {
sv.set(i, null);
}
}
}
}
/**
* Cleanup kernel cache to free memory.
*/
void cleanup() {
nsv = 0;
nbsv = 0;
for (SupportVector v : sv) {
if (v != null) {
nsv++;
v.kcache = null;
if (v.alpha == v.cmin || v.alpha == v.cmax) {
nbsv++;
}
}
}
logger.info("{} support vectors, {} bounded\n", nsv, nbsv);
}
}
/**
* Constructor of binary SVM.
* @param kernel the kernel function.
* @param C the soft margin penalty parameter.
*/
public SVM(MercerKernel<T> kernel, double C) {
this(kernel, C, C);
}
/**
* Constructor of binary SVM.
* @param kernel the kernel function.
* @param Cp the soft margin penalty parameter for positive instances.
* @param Cn the soft margin penalty parameter for negative instances.
*/
public SVM(MercerKernel<T> kernel, double Cp, double Cn) {
if (Cp < 0.0) {
throw new IllegalArgumentException("Invalid postive instance soft margin penalty: " + Cp);
}
if (Cn < 0.0) {
throw new IllegalArgumentException("Invalid negative instance soft margin penalty: " + Cn);
}
this.kernel = kernel;
this.k = 2;
svm = new LASVM(Cp, Cn);
}
/**
* Constructor of multi-class SVM.
* @param kernel the kernel function.
* @param C the soft margin penalty parameter.
 * @param k the number of classes.
 * @param strategy the strategy for multi-class classification.
*/
public SVM(MercerKernel<T> kernel, double C, int k, Multiclass strategy) {
if (C < 0.0) {
throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
}
if (k < 3) {
throw new IllegalArgumentException("Invalid number of classes: " + k);
}
this.kernel = kernel;
this.k = k;
this.strategy = strategy;
if (strategy == Multiclass.ONE_VS_ALL) {
svms = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
svms.add(new LASVM(C, C));
}
} else {
svms = new ArrayList<>(k * (k - 1) / 2);
for (int i = 0; i < k; i++) {
for (int j = i + 1; j < k; j++) {
svms.add(new LASVM(C, C));
}
}
}
}
/**
* Constructor of multi-class SVM.
* @param kernel the kernel function.
 * @param C the soft margin penalty parameter.
 * @param weight class weight. Must be positive. The soft margin penalty
 * of class i will be weight[i] * C.
 * @param strategy the strategy for multi-class classification.
*/
public SVM(MercerKernel<T> kernel, double C, double[] weight, Multiclass strategy) {
if (C < 0.0) {
throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
}
if (weight.length < 3) {
throw new IllegalArgumentException("Invalid number of classes: " + weight.length);
}
for (int i = 0; i < weight.length; i++) {
if (weight[i] <= 0.0) {
throw new IllegalArgumentException("Invalid class weight: " + weight[i]);
}
}
this.kernel = kernel;
this.k = weight.length;
this.strategy = strategy;
this.wi = weight;
if (strategy == Multiclass.ONE_VS_ALL) {
svms = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
svms.add(new LASVM(C, C));
}
} else {
svms = new ArrayList<>(k * (k - 1) / 2);
for (int i = 0; i < k; i++) {
for (int j = i + 1; j < k; j++) {
svms.add(new LASVM(weight[i]*C, weight[j]*C));
}
}
}
}
/**
* Sets the tolerance of convergence test.
*
* @param tol the tolerance of convergence test.
*/
public SVM<T> setTolerance(double tol) {
if (tol <= 0.0) {
throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
}
this.tol = tol;
return this;
}
@Override
public void learn(T x, int y) {
learn(x, y, 1.0);
}
/**
* Online update the classifier with a new training instance.
* Note that this method is NOT multi-thread safe.
*
* @param x training instance.
* @param y training label.
* @param weight instance weight. Must be positive. The soft margin penalty
* parameter for instance will be weight * C.
*/
public void learn(T x, int y, double weight) {
if (y < 0 || y >= k) {
throw new IllegalArgumentException("Invalid label");
}
if (weight <= 0.0) {
throw new IllegalArgumentException("Invalid instance weight: " + weight);
}
if (k == 2) {
if (y == 1) {
svm.process(x, +1, weight);
} else {
svm.process(x, -1, weight);
}
} else if (strategy == Multiclass.ONE_VS_ALL) {
if (wi != null) {
weight *= wi[y];
}
for (int i = 0; i < k; i++) {
if (y == i) {
svms.get(i).process(x, +1, weight);
} else {
svms.get(i).process(x, -1, weight);
}
}
} else {
for (int i = 0, m = 0; i < k; i++) {
for (int j = i + 1; j < k; j++, m++) {
if (y == i) {
svms.get(m).process(x, +1, weight);
} else if (y == j) {
svms.get(m).process(x, -1, weight);
}
}
}
}
}
/**
* Trains the SVM with the given dataset for one epoch. The caller may
* call this method multiple times to obtain better accuracy although
 * one epoch is usually sufficient. After calling this method a sufficient
 * number of times (usually 1 or 2), the user should call {@link #finish()}
* to further process support vectors.
*
* @param x training instances.
* @param y training labels in [0, k), where k is the number of classes.
*/
public void learn(T[] x, int[] y) {
learn(x, y, null);
}
/**
* Trains the SVM with the given dataset for one epoch. The caller may
* call this method multiple times to obtain better accuracy although
 * one epoch is usually sufficient. After calling this method a sufficient
 * number of times (usually 1 or 2), the user should call {@link #finish()}
* to further process support vectors.
*
* @param x training instances.
* @param y training labels in [0, k), where k is the number of classes.
* @param weight instance weight. Must be positive. The soft margin penalty
* parameter for instance i will be weight[i] * C.
*/
@SuppressWarnings("unchecked")
public void learn(T[] x, int[] y, double[] weight) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (weight != null && x.length != weight.length) {
throw new IllegalArgumentException(String.format("The sizes of X and instance weight don't match: %d != %d", x.length, weight.length));
}
int miny = Math.min(y);
if (miny < 0) {
throw new IllegalArgumentException("Negative class label:" + miny);
}
int maxy = Math.max(y);
if (maxy >= k) {
throw new IllegalArgumentException("Invalid class label:" + maxy);
}
if (k == 2) {
int[] yi = new int[y.length];
for (int i = 0; i < y.length; i++) {
if (y[i] == 1) {
yi[i] = +1;
} else {
yi[i] = -1;
}
}
if (weight == null) {
svm.learn(x, yi);
} else {
svm.learn(x, yi, weight);
}
} else if (strategy == Multiclass.ONE_VS_ALL) {
List<TrainingTask> tasks = new ArrayList<>(k);
for (int i = 0; i < k; i++) {
int[] yi = new int[y.length];
double[] w = wi == null ? weight : new double[y.length];
for (int l = 0; l < y.length; l++) {
if (y[l] == i) {
yi[l] = +1;
} else {
yi[l] = -1;
}
if (wi != null) {
w[l] = wi[y[l]];
if (weight != null) {
w[l] *= weight[l];
}
}
}
tasks.add(new TrainingTask(svms.get(i), x, yi, w));
}
try {
MulticoreExecutor.run(tasks);
} catch (Exception e) {
logger.error("Failed to train SVM on multi-core", e);
}
} else {
List<TrainingTask> tasks = new ArrayList<>(k * (k - 1) / 2);
for (int i = 0, m = 0; i < k; i++) {
for (int j = i + 1; j < k; j++, m++) {
int n = 0;
for (int l = 0; l < y.length; l++) {
if (y[l] == i || y[l] == j) {
n++;
}
}
T[] xij = (T[]) java.lang.reflect.Array.newInstance(x.getClass().getComponentType(), n);
int[] yij = new int[n];
double[] wij = weight == null ? null : new double[n];
for (int l = 0, q = 0; l < y.length; l++) {
if (y[l] == i) {
xij[q] = x[l];
yij[q] = +1;
if (weight != null) {
wij[q] = weight[l];
}
q++;
} else if (y[l] == j) {
xij[q] = x[l];
yij[q] = -1;
if (weight != null) {
wij[q] = weight[l];
}
q++;
}
}
tasks.add(new TrainingTask(svms.get(m), xij, yij, wij));
}
}
try {
MulticoreExecutor.run(tasks);
} catch (Exception e) {
logger.error("Failed to train SVM on multi-core", e);
}
}
}
/**
* Process support vectors until converge.
*/
public void finish() {
if (k == 2) {
svm.finish();
} else {
List<ProcessTask> tasks = new ArrayList<>(svms.size());
for (LASVM s : svms) {
tasks.add(new ProcessTask(s));
}
try {
MulticoreExecutor.run(tasks);
} catch (Exception e) {
logger.error("Failed to train SVM on multi-core", e);
}
}
}
/**
* Indicates if Platt scaling is available.
* @return true if Platt Scaling is available
*/
public boolean hasPlattScaling() {
if (k == 2) {
return svm.platt != null;
}
// For multi-class models, Platt scaling is trained for all binary machines together.
return svms.get(0).platt != null;
}
/**
 * After calling {@link #finish()}, the user should call this method
 * to train Platt scaling to estimate posterior probabilities.
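 * <p>
 * A minimal sketch ({@code x} and {@code y} are the training data used
 * above; the probability array length equals the number of classes k):
 * <pre>{@code
 * svm.learn(x, y);
 * svm.finish();
 * svm.trainPlattScaling(x, y);
 * double[] prob = new double[k];
 * int label = svm.predict(x[0], prob);
 * }</pre>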
*
* @param x training samples.
* @param y training labels.
*/
public void trainPlattScaling(T[] x, int[] y) {
if (k == 2) {
svm.trainPlattScaling(x, y);
} else if (strategy == Multiclass.ONE_VS_ALL) {
List<PlattScalingTask> tasks = new ArrayList<>(svms.size());
for (int m = 0; m < svms.size(); m++) {
LASVM s = svms.get(m);
int l = y.length;
int[] yi = new int[l];
for (int i = 0; i < l; i++) {
if (y[i] == m)
yi[i] = +1;
else
yi[i] = -1;
}
tasks.add(new PlattScalingTask(s, x, yi));
}
try {
MulticoreExecutor.run(tasks);
} catch (Exception e) {
logger.error("Failed to train Platt Scaling on multi-core", e);
}
} else {
List<PlattScalingTask> tasks = new ArrayList<>(svms.size());
for (int i = 0, m = 0; i < k; i++) {
for (int j = i + 1; j < k; j++, m++) {
LASVM s = svms.get(m);
int l = y.length;
int[] yi = new int[l];
for (int p = 0; p < l; p++) {
if (y[p] == i)
yi[p] = +1;
else
yi[p] = -1;
}
tasks.add(new PlattScalingTask(s, x, yi));
}
}
try {
MulticoreExecutor.run(tasks);
} catch (Exception e) {
logger.error("Failed to train Platt Scaling on multi-core", e);
}
}
}
/**
* Trains a LASVM.
*/
class TrainingTask implements Callable<LASVM> {
LASVM svm;
T[] x;
int[] y;
double[] weight; // instance weight
TrainingTask(LASVM svm, T[] x, int[] y, double[] weight) {
this.svm = svm;
this.x = x;
this.y = y;
this.weight = weight;
}
@Override
public LASVM call() {
svm.learn(x, y, weight);
return svm;
}
}
/**
* Reprocess a LASVM.
*/
class ProcessTask implements Callable<LASVM> {
LASVM svm;
ProcessTask(LASVM svm) {
this.svm = svm;
}
@Override
public LASVM call() {
svm.finish();
return svm;
}
}
/**
* Train Platt Scaling.
*/
class PlattScalingTask implements Callable<LASVM> {
LASVM svm;
T[] x;
int[] y;
PlattScalingTask(LASVM svm, T[] x, int[] y) {
this.svm = svm;
this.x = x;
this.y = y;
}
@Override
public LASVM call() {
svm.trainPlattScaling(x, y);
return svm;
}
}
@Override
public int predict(T x) {
if (k == 2) {
// two class
if (svm.predict(x) > 0) {
return 1;
} else {
return 0;
}
} else if (strategy == Multiclass.ONE_VS_ALL) {
// one-vs-all
int label = 0;
double maxf = Double.NEGATIVE_INFINITY;
for (int i = 0; i < svms.size(); i++) {
double f = svms.get(i).predict(x);
if (f > maxf) {
label = i;
maxf = f;
}
}
return label;
} else {
// one-vs-one
int[] count = new int[k];
for (int i = 0, m = 0; i < k; i++) {
for (int j = i + 1; j < k; j++, m++) {
double f = svms.get(m).predict(x);
if (f > 0) {
count[i]++;
} else {
count[j]++;
}
}
}
int max = 0;
int label = 0;
for (int i = 0; i < k; i++) {
if (count[i] > max) {
max = count[i];
label = i;
}
}
return label;
}
}
/**
 * Calculates the posterior probability from the fitted Platt scaling
 * sigmoid, clamped away from 0 and 1 for numerical stability.
 */
private double posterior(LASVM svm, double y) {
final double minProb = 1e-7;
final double maxProb = 1 - minProb;
return Math.min(Math.max(svm.platt.predict(y), minProb), maxProb);
}
@Override
public int predict(T x, double[] prob) {
if (k == 2) {
if (svm.platt == null) {
throw new UnsupportedOperationException("PlattScaling was not trained yet. Please call SVM.trainPlattScaling() first.");
}
// two class
double y = svm.predict(x);
prob[1] = posterior(svm, y);
prob[0] = 1.0 - prob[1];
if (y > 0) {
return 1;
} else {
return 0;
}
} else if (strategy == Multiclass.ONE_VS_ALL) {
// one-vs-all
int label = 0;
double maxf = Double.NEGATIVE_INFINITY;
for (int i = 0; i < svms.size(); i++) {
LASVM svm = svms.get(i);
if (svm.platt == null) {
throw new UnsupportedOperationException("PlattScaling was not trained yet. Please call SVM.trainPlattScaling() first.");
}
double f = svm.predict(x);
prob[i] = posterior(svm, f);
if (f > maxf) {
label = i;
maxf = f;
}
}
smile.math.Math.unitize1(prob);
return label;
} else {
// one-vs-one
int[] count = new int[k];
double[][] r = new double[k][k];
for (int i = 0, m = 0; i < k; i++) {
for (int j = i + 1; j < k; j++, m++) {
LASVM svm = svms.get(m);
if (svm.platt == null) {
throw new UnsupportedOperationException("PlattScaling was not trained yet. Please call SVM.trainPlattScaling() first.");
}
double f = svm.predict(x);
r[i][j] = posterior(svm, f);
r[j][i] = 1.0 - r[i][j];
if (f > 0) {
count[i]++;
} else {
count[j]++;
}
}
}
PlattScaling.multiclass(k, r, prob);
int max = 0;
int label = 0;
for (int i = 0; i < k; i++) {
if (count[i] > max) {
max = count[i];
label = i;
}
}
return label;
}
}
}