package ch.unibe.scg.cells.benchmarks;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
/**
 * Support Vector Machine.
 *
 * <p>Holds a hyperplane weight vector and classifies an instance by the sign of the dot product
 * between the weights and the instance features. A model can be trained from data, parsed from
 * the text produced by {@link #toString()}, or built as the average of other models.
 */
public class SVM {
    /** Implements the simple vector logic where each component is a real number. */
    private static class RealVector {
        /** Components of the vector. */
        double[] w;

        /** Wraps the given array WITHOUT copying it; mutations of this vector write through. */
        RealVector(double[] w) {
            this.w = w;
        }

        /** Copy constructor; the new vector owns an independent component array. */
        RealVector(RealVector other) {
            this.w = other.w.clone();
        }

        /** Creates a zero vector with {@code dim} components. */
        RealVector(int dim) {
            this.w = new double[dim];
        }

        /** Creates a vector with {@code count} components, all equal to {@code fill}. */
        RealVector(double fill, int count) {
            this.w = new double[count];
            Arrays.fill(this.w, fill);
        }

        /** Adds a vector to this vector. Assumes equal dimensions. */
        void add(RealVector other) {
            double[] u = other.getFeatures();
            for (int i = 0; i < u.length; ++i) {
                w[i] += u[i];
            }
        }

        /** Adds {@code value} to every component and returns this vector for chaining. */
        RealVector add(double value) {
            for (int i = 0; i < w.length; ++i) {
                w[i] += value;
            }
            return this;
        }

        /** Subtracts a vector from this vector. Assumes equal dimensions. */
        void subtract(RealVector other) {
            double[] u = other.getFeatures();
            for (int i = 0; i < u.length; ++i) {
                w[i] -= u[i];
            }
        }

        /** Dot-product between two vectors. Assumes equal dimensions. */
        double dotProduct(RealVector other) {
            double result = 0.0;
            double[] u = other.getFeatures();
            for (int i = 0; i < u.length; ++i) {
                result += u[i] * this.w[i];
            }
            return result;
        }

        /** Scales the coefficients of this vector by some real factor. */
        RealVector scaleThis(double factor) {
            for (int i = 0; i < this.w.length; ++i) {
                this.w[i] *= factor;
            }
            return this;
        }

        /** L2 norm of the vector. */
        double getNorm() {
            return Math.sqrt(dotProduct(this));
        }

        /** Returns the backing component array itself (not a copy). */
        double[] getFeatures() {
            return this.w;
        }

        /** Number of components. */
        int getDimension() {
            return this.w.length;
        }

        /** Components in order, each followed by a single space. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (double component : this.w) {
                sb.append(component).append(' ');
            }
            return sb.toString();
        }

        /** Arithmetic mean of the components; 0 for an empty vector. */
        double average() {
            double sum = 0;
            for (double component : w) {
                sum += component;
            }
            // Guard the divisor so an empty vector yields 0 instead of NaN.
            int total = Math.max(1, w.length);
            return sum / total;
        }
    }

    /** Represents a training instance. */
    static class TrainingInstance {
        private final RealVector features;
        private final int label;

        public TrainingInstance(RealVector features, int label) {
            this.features = features;
            this.label = label;
        }

        /**
         * Instantiates the training instance from a string.
         * Supposes that the instance is given as a series of doubles and
         * that the last element is the label. To avoid precision problems,
         * the label is considered 1 if the last coefficient is > 0.5, -1 otherwise.
         *
         * <p>Numbers are read with {@link Locale#ROOT}, i.e. with '.' as the decimal separator,
         * independent of the platform default locale. Reading stops at the first token that is
         * not a number.
         *
         * @throws IllegalArgumentException if {@code s} does not start with at least one number
         */
        TrainingInstance(String s) {
            List<Double> parsedInput = new ArrayList<>();
            try (Scanner sc = new Scanner(s)) {
                sc.useLocale(Locale.ROOT);
                while (sc.hasNextDouble()) {
                    parsedInput.add(sc.nextDouble());
                }
            }
            if (parsedInput.isEmpty()) {
                throw new IllegalArgumentException(
                        "Training instance needs at least a label, got: \"" + s + "\"");
            }
            // Last element is always the label; everything before it is a feature.
            int n = parsedInput.size() - 1;
            double[] coef = new double[n];
            for (int i = 0; i < n; i++) {
                coef[i] = parsedInput.get(i);
            }
            this.label = parsedInput.get(n) > 0.5 ? 1 : -1;
            this.features = new RealVector(coef);
        }

        RealVector getFeatures() {
            return features;
        }

        int getLabel() {
            return label;
        }

        int getFeatureCount() {
            return features.getDimension();
        }
    }

    /**
     * Trainer that repeatedly draws a random subsample of the training set, augments it with
     * extra copies of some wrongly classified instances, and takes a full gradient step
     * (step size 1) on the averaged logistic gradient of that batch.
     *
     * <p>Training always starts from a fixed 36-component weight vector, so the training
     * instances are expected to have 36 features.
     */
    static class BatchSVM {
        // Some reasonable start points, gathered from simpler svm classifier.
        private static final double[] DEFAULT_WEIGHTS = {
            -1.1466, -2.2231, -0.2428, -0.8678, 2.0886, 0.0985,
            -2.0055, -0.0099, -0.0609, -0.9270, -1.5213, 0.5135,
            -0.2026, -0.4987, 0.5676, -0.8285, 0.6187, -0.0617,
            -0.2748, -2.0780, -1.1740, 2.2460, 0.8605, -0.1293,
            -0.0787, 1.1157, -0.4567, 1.1253, -2.4983, 0.6353,
            1.4149, -0.7522, 0.8811, -2.4835, 4.6290, -0.8594
        };

        private final List<TrainingInstance> trainingSet;
        /** Fraction of the training set drawn for each subsample. */
        double kSmall;
        /** Number of instances drawn per iteration. */
        int subsampleSize;
        /** Training stops once the gradient's L2 norm drops below this value. */
        double epsilon = 0.0001;
        Random rand = new Random();

        BatchSVM(List<TrainingInstance> trainingSet) {
            this.trainingSet = trainingSet;
            kSmall = 0.02;
            int size = trainingSet.size();
            // Small training sets would otherwise round down to an empty subsample.
            subsampleSize = size == 0 ? 0 : Math.max(1, (int) Math.round(kSmall * size));
        }

        /**
         * Runs at most {@code maxIter} gradient steps and returns the resulting weights.
         * Returns the unchanged start weights if no batch can be drawn.
         */
        RealVector train(int maxIter) {
            // Work on a copy so that training never alters the shared start weights.
            RealVector w = new RealVector(DEFAULT_WEIGHTS.clone());
            for (int i = 0; i < maxIter; i++) {
                TrainingInstance[] batch = createBatch(selectSubset(subsampleSize), w);
                if (batch.length == 0) {
                    break;
                }
                RealVector grad = batchGradient(logisticLoss(batch, w), batch);
                w.subtract(grad);
                if (grad.getNorm() < epsilon) {
                    break;
                }
            }
            return w;
        }

        /**
         * Returns the subset followed by three copies of each instance reported by
         * {@link #misclassified}, so those instances weigh more in the gradient.
         */
        private TrainingInstance[] createBatch(TrainingInstance[] subset, RealVector weights) {
            TrainingInstance[] hard = misclassified(subset, weights);
            TrainingInstance[] batch = Arrays.copyOf(subset, subset.length + 3 * hard.length);
            for (int i = subset.length; i < batch.length; i++) {
                batch[i] = hard[(i - subset.length) % hard.length];
            }
            return batch;
        }

        /**
         * Calculates, for every object in batch, the logistic function of its score
         * {@code features . weights}, i.e. a value in [0, 1].
         */
        private double[] logisticLoss(TrainingInstance[] batch, RealVector weights) {
            double[] result = new double[batch.length];
            for (int i = 0; i < batch.length; i++) {
                result[i] = sigmoid(batch[i].getFeatures().dotProduct(weights));
            }
            return result;
        }

        /**
         * Logistic function that never evaluates {@code exp} on a positive argument, so large
         * scores saturate at 0 or 1 instead of producing Infinity / Infinity = NaN.
         */
        private static double sigmoid(double x) {
            if (x >= 0) {
                return 1.0 / (1.0 + Math.exp(-x));
            }
            double e = Math.exp(x);
            return e / (1.0 + e);
        }

        /**
         * Averages, per feature, {@code (logloss[j] - y_j) * feature_j} over the batch, where
         * {@code y_j} is the label mapped from {-1, +1} to {0, 1}. Requires a non-empty batch.
         */
        private RealVector batchGradient(double[] logloss, TrainingInstance[] batch) {
            int dimensions = batch[0].getFeatureCount();
            int batchSize = batch.length;

            // Per-instance residual between the prediction and the 0/1 label.
            double[] residual = new double[batchSize];
            for (int j = 0; j < batchSize; j++) {
                residual[j] = logloss[j] - (batch[j].getLabel() + 1.0) * 0.5;
            }

            RealVector result = new RealVector(dimensions);
            for (int i = 0; i < dimensions; i++) {
                double sum = 0;
                for (int j = 0; j < batchSize; j++) {
                    sum += residual[j] * batch[j].getFeatures().w[i];
                }
                result.w[i] = sum / batchSize;
            }
            return result;
        }

        /**
         * Collects the instances that are predicted +1 but labelled -1. Instances predicted -1
         * but labelled +1 are deliberately left out here, matching the existing behaviour.
         */
        private TrainingInstance[] misclassified(TrainingInstance[] subset, RealVector weights) {
            List<TrainingInstance> result = new ArrayList<>();
            for (TrainingInstance ti : subset) {
                if (classify(ti, weights) == 1 && ti.getLabel() == -1) {
                    result.add(ti);
                }
            }
            return result.toArray(new TrainingInstance[result.size()]);
        }

        private int classify(TrainingInstance ti, RealVector w) {
            return sign(ti.getFeatures().dotProduct(w));
        }

        /**
         * Draws {@code k} distinct instances uniformly at random from the whole training set.
         * If {@code k} exceeds the training set size, every instance is returned once.
         */
        private TrainingInstance[] selectSubset(int k) {
            int population = trainingSet.size();
            int count = Math.max(0, Math.min(k, population));
            TrainingInstance[] result = new TrainingInstance[count];
            Set<Integer> visited = new HashSet<>();
            int filled = 0;
            while (filled < count) {
                // Index into the full training set, not just its first k entries.
                int nextIdx = rand.nextInt(population);
                if (visited.add(nextIdx)) {
                    result[filled++] = trainingSet.get(nextIdx);
                }
            }
            return result;
        }
    }

    // Hyperplane weights.
    RealVector weights;

    private SVM(RealVector weights) {
        this.weights = weights;
    }

    /** Trains SVM with a list of training instances, and with given maximum number of iterations. */
    static SVM trainSVM(List<TrainingInstance> trainingSet, int iterations) {
        return new SVM(new BatchSVM(trainingSet).train(iterations));
    }

    /**
     * Instantiates SVM from weights given as a string of whitespace-separated doubles, such as
     * the output of {@link #toString()}.
     *
     * <p>Numbers are read with {@link Locale#ROOT} so that the '.' decimal separator written by
     * {@link #toString()} is understood regardless of the platform default locale.
     *
     * @throws java.util.InputMismatchException if a token is not a number
     */
    SVM(String input) {
        List<Double> coefficients = new ArrayList<>();
        try (Scanner sc = new Scanner(input)) {
            sc.useLocale(Locale.ROOT);
            while (sc.hasNext()) {
                coefficients.add(sc.nextDouble());
            }
        }
        double[] w = new double[coefficients.size()];
        for (int i = 0; i < w.length; i++) {
            w[i] = coefficients.get(i);
        }
        this.weights = new RealVector(w);
    }

    /**
     * Instantiates the SVM model as the average model of the input SVMs.
     * The dimension is taken from the first model, so the list must not be empty.
     */
    SVM(List<SVM> svmList) {
        int dim = svmList.get(0).getDimension();
        RealVector w = new RealVector(dim);
        for (SVM svm : svmList) {
            w.add(svm.getWeights());
        }
        this.weights = w.scaleThis(1.0 / svmList.size());
    }

    int getDimension() {
        return weights.getDimension();
    }

    /** Given a training instance it returns the result of sign(weights'instanceFeatures). */
    int classify(TrainingInstance ti) {
        return sign(ti.getFeatures().dotProduct(weights));
    }

    RealVector getWeights() {
        return this.weights;
    }

    @Override
    public String toString() {
        return weights.toString();
    }

    /** Maps a score to a class: +1 for scores >= 0, -1 otherwise (including NaN). */
    private static int sign(double score) {
        return score >= 0 ? 1 : -1;
    }
}